Skip to content

Commit

Permalink
collect domains in GeometricQuantities
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Mar 31, 2024
1 parent 2e89778 commit 0b1b007
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 25 deletions.
14 changes: 7 additions & 7 deletions test/test_external_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_differentiation_procedure_action(V1, V2):


def test_extractions(domain_2d, V1):
from ufl.algorithms.analysis import (extract_arguments, extract_arguments_and_coefficients,
from ufl.algorithms.analysis import (extract_arguments, extract_arguments_and_coefficients_and_geometric_quantities,
extract_base_form_operators, extract_coefficients, extract_constants)

u = Coefficient(V1)
Expand All @@ -192,15 +192,15 @@ def test_extractions(domain_2d, V1):

assert extract_coefficients(e) == [u]
assert extract_arguments(e) == [vstar_e]
assert extract_arguments_and_coefficients(e) == ([vstar_e], [u])
assert extract_arguments_and_coefficients_and_geometric_quantities(e) == ([vstar_e], [u], [])
assert extract_constants(e) == [c]
assert extract_base_form_operators(e) == [e]

F = e * dx

assert extract_coefficients(F) == [u]
assert extract_arguments(e) == [vstar_e]
assert extract_arguments_and_coefficients(e) == ([vstar_e], [u])
assert extract_arguments_and_coefficients_and_geometric_quantities(e) == ([vstar_e], [u], [])
assert extract_constants(F) == [c]
assert F.base_form_operators() == (e,)

Expand All @@ -209,14 +209,14 @@ def test_extractions(domain_2d, V1):

assert extract_coefficients(e) == [u]
assert extract_arguments(e) == [vstar_e, u_hat]
assert extract_arguments_and_coefficients(e) == ([vstar_e, u_hat], [u])
assert extract_arguments_and_coefficients_and_geometric_quantities(e) == ([vstar_e, u_hat], [u], [])
assert extract_base_form_operators(e) == [e]

F = e * dx

assert extract_coefficients(F) == [u]
assert extract_arguments(e) == [vstar_e, u_hat]
assert extract_arguments_and_coefficients(e) == ([vstar_e, u_hat], [u])
assert extract_arguments_and_coefficients_and_geometric_quantities(e) == ([vstar_e, u_hat], [u], [])
assert F.base_form_operators() == (e,)

w = Coefficient(V1)
Expand All @@ -225,14 +225,14 @@ def test_extractions(domain_2d, V1):

assert extract_coefficients(e2) == [u, w]
assert extract_arguments(e2) == [vstar_e2, u_hat]
assert extract_arguments_and_coefficients(e2) == ([vstar_e2, u_hat], [u, w])
assert extract_arguments_and_coefficients_and_geometric_quantities(e2) == ([vstar_e2, u_hat], [u, w], [])
assert extract_base_form_operators(e2) == [e, e2]

F = e2 * dx

assert extract_coefficients(e2) == [u, w]
assert extract_arguments(e2) == [vstar_e2, u_hat]
assert extract_arguments_and_coefficients(e2) == ([vstar_e2, u_hat], [u, w])
assert extract_arguments_and_coefficients_and_geometric_quantities(e2) == ([vstar_e2, u_hat], [u, w], [])
assert F.base_form_operators() == (e, e2)


Expand Down
8 changes: 4 additions & 4 deletions test/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ufl import (Action, Adjoint, Argument, Coefficient, FunctionSpace, Mesh, TestFunction, TrialFunction, action,
adjoint, derivative, dx, grad, inner, replace, triangle)
from ufl.algorithms.ad import expand_derivatives
from ufl.algorithms.analysis import (extract_arguments, extract_arguments_and_coefficients, extract_base_form_operators,
from ufl.algorithms.analysis import (extract_arguments, extract_arguments_and_coefficients_and_geometric_quantities, extract_base_form_operators,
extract_coefficients)
from ufl.algorithms.expand_indices import expand_indices
from ufl.core.interpolate import Interpolate
Expand Down Expand Up @@ -141,12 +141,12 @@ def test_extract_base_form_operators(V1, V2):
# -- Interpolate(u, V2) -- #
Iu = Interpolate(u, V2)
assert extract_arguments(Iu) == [vstar]
assert extract_arguments_and_coefficients(Iu) == ([vstar], [u])
assert extract_arguments_and_coefficients_and_geometric_quantities(Iu) == ([vstar], [u], [])

F = Iu * dx
# Form composition: Iu * dx <=> Action(v * dx, Iu(u; v*))
assert extract_arguments(F) == []
assert extract_arguments_and_coefficients(F) == ([], [u])
assert extract_arguments_and_coefficients_and_geometric_quantities(F) == ([], [u], [])

for e in [Iu, F]:
assert extract_coefficients(e) == [u]
Expand All @@ -155,7 +155,7 @@ def test_extract_base_form_operators(V1, V2):
# -- Interpolate(u, V2) -- #
Iv = Interpolate(uhat, V2)
assert extract_arguments(Iv) == [vstar, uhat]
assert extract_arguments_and_coefficients(Iv) == ([vstar, uhat], [])
assert extract_arguments_and_coefficients_and_geometric_quantities(Iv) == ([vstar, uhat], [], [])
assert extract_coefficients(Iv) == []
assert extract_base_form_operators(Iv) == [Iv]

Expand Down
4 changes: 2 additions & 2 deletions ufl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def _get_action_form_arguments(left, right):
elif isinstance(right, CoefficientDerivative):
# Action differentiation pushes differentiation through
# right as a consequence of Leibniz formula.
from ufl.algorithms.analysis import extract_arguments_and_coefficients
right_args, right_coeffs = extract_arguments_and_coefficients(right)
from ufl.algorithms.analysis import extract_arguments_and_coefficients_and_geometric_quantities
right_args, right_coeffs, _ = extract_arguments_and_coefficients_and_geometric_quantities(right)
arguments = left_args + tuple(right_args)
coefficients += tuple(right_coeffs)
elif isinstance(right, (BaseCoefficient, Zero)):
Expand Down
34 changes: 24 additions & 10 deletions ufl/algorithms/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from itertools import chain

from ufl.algorithms.traversal import iter_expressions
from ufl.domain import Mesh
from ufl.argument import BaseArgument, Coargument
from ufl.coefficient import BaseCoefficient
from ufl.constant import Constant
from ufl.geometry import GeometricQuantity
from ufl.core.base_form_operator import BaseFormOperator
from ufl.core.terminal import Terminal
from ufl.corealg.traversal import traverse_unique_terminals, unique_pre_traversal
Expand Down Expand Up @@ -187,19 +189,20 @@ def extract_base_form_operators(a):
return sorted_by_count(extract_type(a, BaseFormOperator))


def extract_arguments_and_coefficients(a):
"""Build two sorted lists of all arguments and coefficients in a.
def extract_arguments_and_coefficients_and_geometric_quantities(a):
"""Build three sorted lists of all arguments, coefficients, and geometric quantities in a.
This function is faster than extract_arguments + extract_coefficients
This function is faster than extract_arguments + extract_coefficients + extract_geometric_quantities
for large forms, and has more validation built in.
Args:
a: A BaseForm, Integral or Expr
"""
# Extract lists of all BaseArgument and BaseCoefficient instances
base_coeff_and_args = extract_type(a, (BaseArgument, BaseCoefficient))
arguments = [f for f in base_coeff_and_args if isinstance(f, BaseArgument)]
coefficients = [f for f in base_coeff_and_args if isinstance(f, BaseCoefficient)]
base_coeff_and_args_and_gq = extract_type(a, (BaseArgument, BaseCoefficient, GeometricQuantity))
arguments = [f for f in base_coeff_and_args_and_gq if isinstance(f, BaseArgument)]
coefficients = [f for f in base_coeff_and_args_and_gq if isinstance(f, BaseCoefficient)]
geometric_quantities = [f for f in base_coeff_and_args_and_gq if isinstance(f, GeometricQuantity)]

# Build number,part: instance mappings, should be one to one
bfnp = dict((f, (f.number(), f.part())) for f in arguments)
Expand All @@ -214,19 +217,30 @@ def extract_arguments_and_coefficients(a):
if len(fcounts) != len(set(fcounts.values())):
raise ValueError(
"Found different coefficients with same counts.\n"
"The arguments found are:\n" + "\n".join(f" {c}" for c in coefficients))
"The Coefficients found are:\n" + "\n".join(f" {c}" for c in coefficients))

# Build count: instance mappings, should be one to one
gqcounts = {}
for gq in geometric_quantities:
assert isinstance(gq._domain, Mesh), f"Found that {gq}._domain is {gq._domain}"
gqcounts[gq] = (type(gq).name, gq._domain._ufl_id)
if len(gqcounts) != len(set(gqcounts.values())):
raise ValueError(
"Found different geometric quantities with same (geometric_quantity_type, domain).\n"
"The GeometricQuantities found are:\n" + "\n".join(f" {gq}" for gq in geometric_quantities))

# Passed checks, so we can safely sort the instances by count
arguments = _sorted_by_number_and_part(arguments)
coefficients = sorted_by_count(coefficients)
geometric_quantities = list(sorted(geometric_quantities, key=lambda gq: (type(gq).name, gq._domain._ufl_id)))

return arguments, coefficients
return arguments, coefficients, geometric_quantities


def extract_elements(form):
"""Build sorted tuple of all elements used in form."""
args = chain(*extract_arguments_and_coefficients(form))
return tuple(f.ufl_element() for f in args)
arguments, coefficients, _ = extract_arguments_and_coefficients_and_geometric_quantities(form)
return tuple(f.ufl_element() for f in arguments + coefficients)


def extract_unique_elements(form):
Expand Down
12 changes: 10 additions & 2 deletions ufl/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ class Form(BaseForm):
"_coefficient_numbering",
"_constants",
"_constant_numbering",
"_geometric_quantities",
"_terminal_numbering",
"_hash",
"_signature",
Expand Down Expand Up @@ -625,14 +626,15 @@ def _analyze_subdomain_data(self):

def _analyze_form_arguments(self):
"""Analyze which Argument and Coefficient objects can be found in the form."""
from ufl.algorithms.analysis import extract_arguments_and_coefficients
arguments, coefficients = extract_arguments_and_coefficients(self)
from ufl.algorithms.analysis import extract_arguments_and_coefficients_and_geometric_quantities
arguments, coefficients, geometric_quantities = extract_arguments_and_coefficients_and_geometric_quantities(self)

# Define canonical numbering of arguments and coefficients
self._arguments = tuple(
sorted(set(arguments), key=lambda x: x.number()))
self._coefficients = tuple(
sorted(set(coefficients), key=lambda x: x.count()))
self._geometric_quantities = geometric_quantities # sorted by (type, domain)

def _analyze_base_form_operators(self):
"""Analyze which BaseFormOperator objects can be found in the form."""
Expand Down Expand Up @@ -674,6 +676,12 @@ def _compute_renumbering(self):
renumbering[d] = k
k += 1

for gq in self._geometric_quantities:
d = gq._domain
if d not in renumbering:
renumbering[d] = k
k += 1

return renumbering

def _compute_signature(self):
Expand Down

0 comments on commit 0b1b007

Please sign in to comment.