Skip to content

Commit

Permalink
introduce MeshSequence
Browse files Browse the repository at this point in the history
Co-authored-by: Jørgen Schartum Dokken <[email protected]>
  • Loading branch information
ksagiyam and jorgensd committed Jan 15, 2025
1 parent cb9052d commit 3fe26ca
Show file tree
Hide file tree
Showing 17 changed files with 592 additions and 141 deletions.
14 changes: 7 additions & 7 deletions test/test_external_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ def test_differentiation_procedure_action(V1, V2):
def test_extractions(domain_2d, V1):
from ufl.algorithms.analysis import (
extract_arguments,
extract_arguments_and_coefficients,
extract_base_form_operators,
extract_coefficients,
extract_constants,
extract_terminals_with_domain,
)

u = Coefficient(V1)
Expand All @@ -219,15 +219,15 @@ def test_extractions(domain_2d, V1):

assert extract_coefficients(e) == [u]
assert extract_arguments(e) == [vstar_e]
assert extract_arguments_and_coefficients(e) == ([vstar_e], [u])
assert extract_terminals_with_domain(e) == ([vstar_e], [u], [])
assert extract_constants(e) == [c]
assert extract_base_form_operators(e) == [e]

F = e * dx

assert extract_coefficients(F) == [u]
assert extract_arguments(e) == [vstar_e]
assert extract_arguments_and_coefficients(e) == ([vstar_e], [u])
assert extract_terminals_with_domain(e) == ([vstar_e], [u], [])
assert extract_constants(F) == [c]
assert F.base_form_operators() == (e,)

Expand All @@ -236,14 +236,14 @@ def test_extractions(domain_2d, V1):

assert extract_coefficients(e) == [u]
assert extract_arguments(e) == [vstar_e, u_hat]
assert extract_arguments_and_coefficients(e) == ([vstar_e, u_hat], [u])
assert extract_terminals_with_domain(e) == ([vstar_e, u_hat], [u], [])
assert extract_base_form_operators(e) == [e]

F = e * dx

assert extract_coefficients(F) == [u]
assert extract_arguments(e) == [vstar_e, u_hat]
assert extract_arguments_and_coefficients(e) == ([vstar_e, u_hat], [u])
assert extract_terminals_with_domain(e) == ([vstar_e, u_hat], [u], [])
assert F.base_form_operators() == (e,)

w = Coefficient(V1)
Expand All @@ -252,14 +252,14 @@ def test_extractions(domain_2d, V1):

assert extract_coefficients(e2) == [u, w]
assert extract_arguments(e2) == [vstar_e2, u_hat]
assert extract_arguments_and_coefficients(e2) == ([vstar_e2, u_hat], [u, w])
assert extract_terminals_with_domain(e2) == ([vstar_e2, u_hat], [u, w], [])
assert extract_base_form_operators(e2) == [e, e2]

F = e2 * dx

assert extract_coefficients(e2) == [u, w]
assert extract_arguments(e2) == [vstar_e2, u_hat]
assert extract_arguments_and_coefficients(e2) == ([vstar_e2, u_hat], [u, w])
assert extract_terminals_with_domain(e2) == ([vstar_e2, u_hat], [u, w], [])
assert F.base_form_operators() == (e, e2)


Expand Down
8 changes: 4 additions & 4 deletions test/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from ufl.algorithms.ad import expand_derivatives
from ufl.algorithms.analysis import (
extract_arguments,
extract_arguments_and_coefficients,
extract_base_form_operators,
extract_coefficients,
extract_terminals_with_domain,
)
from ufl.algorithms.expand_indices import expand_indices
from ufl.core.interpolate import Interpolate
Expand Down Expand Up @@ -157,12 +157,12 @@ def test_extract_base_form_operators(V1, V2):
# -- Interpolate(u, V2) -- #
Iu = Interpolate(u, V2)
assert extract_arguments(Iu) == [vstar]
assert extract_arguments_and_coefficients(Iu) == ([vstar], [u])
assert extract_terminals_with_domain(Iu) == ([vstar], [u], [])

F = Iu * dx
# Form composition: Iu * dx <=> Action(v * dx, Iu(u; v*))
assert extract_arguments(F) == []
assert extract_arguments_and_coefficients(F) == ([], [u])
assert extract_terminals_with_domain(F) == ([], [u], [])

for e in [Iu, F]:
assert extract_coefficients(e) == [u]
Expand All @@ -171,7 +171,7 @@ def test_extract_base_form_operators(V1, V2):
# -- Interpolate(u, V2) -- #
Iv = Interpolate(uhat, V2)
assert extract_arguments(Iv) == [vstar, uhat]
assert extract_arguments_and_coefficients(Iv) == ([vstar, uhat], [])
assert extract_terminals_with_domain(Iv) == ([vstar, uhat], [], [])
assert extract_coefficients(Iv) == []
assert extract_base_form_operators(Iv) == [Iv]

Expand Down
104 changes: 104 additions & 0 deletions test/test_mixed_function_space_with_mesh_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from ufl import (
CellVolume,
Coefficient,
FacetArea,
FacetNormal,
FunctionSpace,
Measure,
Mesh,
MeshSequence,
SpatialCoordinate,
TestFunction,
TrialFunction,
grad,
inner,
split,
triangle,
)
from ufl.algorithms import compute_form_data
from ufl.domain import extract_domains
from ufl.finiteelement import FiniteElement, MixedElement
from ufl.pullback import contravariant_piola, identity_pullback
from ufl.sobolevspace import H1, L2, HDiv


def test_mixed_function_space_with_mesh_sequence_basic():
cell = triangle
elem0 = FiniteElement("Lagrange", cell, 1, (), identity_pullback, H1)
elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 1, (2,), contravariant_piola, HDiv)
elem2 = FiniteElement("Discontinuous Lagrange", cell, 0, (), identity_pullback, L2)
elem = MixedElement([elem0, elem1, elem2])
mesh0 = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1), ufl_id=100)
mesh1 = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1), ufl_id=101)
mesh2 = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1), ufl_id=102)
domain = MeshSequence([mesh0, mesh1, mesh2])
V = FunctionSpace(domain, elem)
u = TrialFunction(V)
v = TestFunction(V)
f = Coefficient(V, count=1000)
g = Coefficient(V, count=2000)
u0, u1, u2 = split(u)
v0, v1, v2 = split(v)
f0, f1, f2 = split(f)
g0, g1, g2 = split(g)
dx1 = Measure("dx", mesh1)
x = SpatialCoordinate(mesh1)
form = x[1] * f0 * inner(grad(u0), v1) * dx1(999)
fd = compute_form_data(
form,
do_apply_function_pullbacks=True,
do_apply_integral_scaling=True,
do_apply_geometry_lowering=True,
preserve_geometry_types=(CellVolume, FacetArea),
do_apply_restrictions=True,
do_estimate_degrees=True,
complex_mode=False,
)
(id0,) = fd.integral_data
assert fd.preprocessed_form.arguments() == (v, u)
assert fd.reduced_coefficients == [f]
assert form.coefficients()[fd.original_coefficient_positions[0]] is f
assert id0.domain is mesh1
assert id0.integral_type == "cell"
assert id0.subdomain_id == (999,)
assert fd.original_form.domain_numbering()[id0.domain] == 0
assert id0.integral_coefficients == set([f])
assert id0.enabled_coefficients == [True]


def test_mixed_function_space_with_mesh_sequence_signature():
cell = triangle
mesh0 = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1), ufl_id=100)
mesh1 = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1), ufl_id=101)
dx0 = Measure("dx", mesh0)
dx1 = Measure("dx", mesh1)
n0 = FacetNormal(mesh0)
n1 = FacetNormal(mesh1)
form_a = inner(n1, n1) * dx0(999)
form_b = inner(n0, n0) * dx1(999)
assert form_a.signature() == form_b.signature()
assert extract_domains(form_a) == (mesh0, mesh1)
assert extract_domains(form_b) == (mesh1, mesh0)


def test_mixed_function_space_with_mesh_sequence_hash():
cell = triangle
elem0 = FiniteElement("Lagrange", cell, 1, (), identity_pullback, H1)
elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 1, (2,), contravariant_piola, HDiv)
elem2 = FiniteElement("Discontinuous Lagrange", cell, 0, (), identity_pullback, L2)
elem = MixedElement([elem0, elem1, elem2])
mesh0 = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1), ufl_id=100)
mesh1 = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1), ufl_id=101)
mesh2 = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1), ufl_id=102)
domain = MeshSequence([mesh0, mesh1, mesh2])
domain_ = MeshSequence([mesh0, mesh1, mesh2])
V = FunctionSpace(domain, elem)
V_ = FunctionSpace(domain_, elem)
u = TrialFunction(V)
u_ = TrialFunction(V_)
assert hash(domain_) == hash(domain)
assert domain_ == domain
assert hash(V_) == hash(V)
assert V_ == V
assert hash(u_) == hash(u)
assert u_ == u
4 changes: 3 additions & 1 deletion ufl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
-AbstractDomain
-Mesh
-MeshSequence
-MeshView
* Sobolev spaces::
Expand Down Expand Up @@ -265,7 +266,7 @@
from ufl.core.external_operator import ExternalOperator
from ufl.core.interpolate import Interpolate, interpolate
from ufl.core.multiindex import Index, indices
from ufl.domain import AbstractDomain, Mesh, MeshView
from ufl.domain import AbstractDomain, Mesh, MeshSequence, MeshView
from ufl.finiteelement import AbstractFiniteElement
from ufl.form import BaseForm, Form, FormSum, ZeroBaseForm
from ufl.formoperators import (
Expand Down Expand Up @@ -484,6 +485,7 @@
"MaxFacetEdgeLength",
"Measure",
"Mesh",
"MeshSequence",
"MeshView",
"MinCellEdgeLength",
"MinFacetEdgeLength",
Expand Down
4 changes: 2 additions & 2 deletions ufl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ def _get_action_form_arguments(left, right):
elif isinstance(right, CoefficientDerivative):
# Action differentiation pushes differentiation through
# right as a consequence of Leibniz formula.
from ufl.algorithms.analysis import extract_arguments_and_coefficients
from ufl.algorithms.analysis import extract_terminals_with_domain

right_args, right_coeffs = extract_arguments_and_coefficients(right)
right_args, right_coeffs, _ = extract_terminals_with_domain(right)
arguments = left_args + tuple(right_args)
coefficients += tuple(right_coeffs)
elif isinstance(right, (BaseCoefficient, Zero)):
Expand Down
47 changes: 35 additions & 12 deletions ufl/algorithms/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from ufl.core.base_form_operator import BaseFormOperator
from ufl.core.terminal import Terminal
from ufl.corealg.traversal import traverse_unique_terminals, unique_pre_traversal
from ufl.domain import Mesh
from ufl.form import BaseForm, Form
from ufl.geometry import GeometricQuantity
from ufl.utils.sorting import sorted_by_count, topological_sorting

# TODO: Some of these can possibly be optimised by implementing
Expand Down Expand Up @@ -190,19 +192,24 @@ def extract_base_form_operators(a):
return sorted_by_count(extract_type(a, BaseFormOperator))


def extract_arguments_and_coefficients(a):
"""Build two sorted lists of all arguments and coefficients in a.
def extract_terminals_with_domain(a):
"""Build three sorted lists of all arguments, coefficients, and geometric quantities in `a`.
This function is faster than extract_arguments + extract_coefficients
for large forms, and has more validation built in.
This function is faster than extracting each type of terminal
separately for large forms, and has more validation built in.
Args:
a: A BaseForm, Integral or Expr
Returns:
Tuples of extracted `Argument`s, `Coefficient`s, and `GeometricQuantity`s.
"""
# Extract lists of all BaseArgument and BaseCoefficient instances
base_coeff_and_args = extract_type(a, (BaseArgument, BaseCoefficient))
arguments = [f for f in base_coeff_and_args if isinstance(f, BaseArgument)]
coefficients = [f for f in base_coeff_and_args if isinstance(f, BaseCoefficient)]
# Extract lists of all BaseArgument, BaseCoefficient, and GeometricQuantity instances
terminals = extract_type(a, (BaseArgument, BaseCoefficient, GeometricQuantity))
arguments = [f for f in terminals if isinstance(f, BaseArgument)]
coefficients = [f for f in terminals if isinstance(f, BaseCoefficient)]
geometric_quantities = [f for f in terminals if isinstance(f, GeometricQuantity)]

# Build number,part: instance mappings, should be one to one
bfnp = {f: (f.number(), f.part()) for f in arguments}
Expand All @@ -218,20 +225,36 @@ def extract_arguments_and_coefficients(a):
if len(fcounts) != len(set(fcounts.values())):
raise ValueError(
"Found different coefficients with same counts.\n"
"The arguments found are:\n" + "\n".join(f" {c}" for c in coefficients)
"The Coefficients found are:\n" + "\n".join(f" {c}" for c in coefficients)
)

# Build count: instance mappings, should be one to one
gqcounts = {}
for gq in geometric_quantities:
if not isinstance(gq._domain, Mesh):
raise TypeError(f"{gq}._domain must be a Mesh: got {gq._domain}")
gqcounts[gq] = (type(gq).name, gq._domain._ufl_id)
if len(gqcounts) != len(set(gqcounts.values())):
raise ValueError(
"Found different geometric quantities with same (geometric_quantity_type, domain).\n"
"The GeometricQuantities found are:\n"
"\n".join(f" {gq}" for gq in geometric_quantities)
)

# Passed checks, so we can safely sort the instances by count
arguments = _sorted_by_number_and_part(arguments)
coefficients = sorted_by_count(coefficients)
geometric_quantities = list(
sorted(geometric_quantities, key=lambda gq: (type(gq).name, gq._domain._ufl_id))
)

return arguments, coefficients
return arguments, coefficients, geometric_quantities


def extract_elements(form):
"""Build sorted tuple of all elements used in form."""
args = chain(*extract_arguments_and_coefficients(form))
return tuple(f.ufl_element() for f in args)
arguments, coefficients, _ = extract_terminals_with_domain(form)
return tuple(f.ufl_element() for f in arguments + coefficients)


def extract_unique_elements(form):
Expand Down
Loading

0 comments on commit 3fe26ca

Please sign in to comment.