Skip to content

Commit

Permalink
introduce MeshSequence
Browse files Browse the repository at this point in the history
Co-authored-by: Jørgen Schartum Dokken <[email protected]>
  • Loading branch information
ksagiyam and jorgensd committed Jan 10, 2025
1 parent 6051fab commit 057ef1b
Show file tree
Hide file tree
Showing 16 changed files with 488 additions and 141 deletions.
14 changes: 7 additions & 7 deletions test/test_external_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ def test_differentiation_procedure_action(V1, V2):
def test_extractions(domain_2d, V1):
from ufl.algorithms.analysis import (
extract_arguments,
extract_arguments_and_coefficients,
extract_base_form_operators,
extract_coefficients,
extract_constants,
extract_terminals_with_domain,
)

u = Coefficient(V1)
Expand All @@ -219,15 +219,15 @@ def test_extractions(domain_2d, V1):

assert extract_coefficients(e) == [u]
assert extract_arguments(e) == [vstar_e]
assert extract_arguments_and_coefficients(e) == ([vstar_e], [u])
assert extract_terminals_with_domain(e) == ([vstar_e], [u], [])
assert extract_constants(e) == [c]
assert extract_base_form_operators(e) == [e]

F = e * dx

assert extract_coefficients(F) == [u]
assert extract_arguments(e) == [vstar_e]
assert extract_arguments_and_coefficients(e) == ([vstar_e], [u])
assert extract_terminals_with_domain(e) == ([vstar_e], [u], [])
assert extract_constants(F) == [c]
assert F.base_form_operators() == (e,)

Expand All @@ -236,14 +236,14 @@ def test_extractions(domain_2d, V1):

assert extract_coefficients(e) == [u]
assert extract_arguments(e) == [vstar_e, u_hat]
assert extract_arguments_and_coefficients(e) == ([vstar_e, u_hat], [u])
assert extract_terminals_with_domain(e) == ([vstar_e, u_hat], [u], [])
assert extract_base_form_operators(e) == [e]

F = e * dx

assert extract_coefficients(F) == [u]
assert extract_arguments(e) == [vstar_e, u_hat]
assert extract_arguments_and_coefficients(e) == ([vstar_e, u_hat], [u])
assert extract_terminals_with_domain(e) == ([vstar_e, u_hat], [u], [])
assert F.base_form_operators() == (e,)

w = Coefficient(V1)
Expand All @@ -252,14 +252,14 @@ def test_extractions(domain_2d, V1):

assert extract_coefficients(e2) == [u, w]
assert extract_arguments(e2) == [vstar_e2, u_hat]
assert extract_arguments_and_coefficients(e2) == ([vstar_e2, u_hat], [u, w])
assert extract_terminals_with_domain(e2) == ([vstar_e2, u_hat], [u, w], [])
assert extract_base_form_operators(e2) == [e, e2]

F = e2 * dx

assert extract_coefficients(e2) == [u, w]
assert extract_arguments(e2) == [vstar_e2, u_hat]
assert extract_arguments_and_coefficients(e2) == ([vstar_e2, u_hat], [u, w])
assert extract_terminals_with_domain(e2) == ([vstar_e2, u_hat], [u, w], [])
assert F.base_form_operators() == (e, e2)


Expand Down
8 changes: 4 additions & 4 deletions test/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from ufl.algorithms.ad import expand_derivatives
from ufl.algorithms.analysis import (
extract_arguments,
extract_arguments_and_coefficients,
extract_base_form_operators,
extract_coefficients,
extract_terminals_with_domain,
)
from ufl.algorithms.expand_indices import expand_indices
from ufl.core.interpolate import Interpolate
Expand Down Expand Up @@ -157,12 +157,12 @@ def test_extract_base_form_operators(V1, V2):
# -- Interpolate(u, V2) -- #
Iu = Interpolate(u, V2)
assert extract_arguments(Iu) == [vstar]
assert extract_arguments_and_coefficients(Iu) == ([vstar], [u])
assert extract_terminals_with_domain(Iu) == ([vstar], [u], [])

F = Iu * dx
# Form composition: Iu * dx <=> Action(v * dx, Iu(u; v*))
assert extract_arguments(F) == []
assert extract_arguments_and_coefficients(F) == ([], [u])
assert extract_terminals_with_domain(F) == ([], [u], [])

for e in [Iu, F]:
assert extract_coefficients(e) == [u]
Expand All @@ -171,7 +171,7 @@ def test_extract_base_form_operators(V1, V2):
# -- Interpolate(u, V2) -- #
Iv = Interpolate(uhat, V2)
assert extract_arguments(Iv) == [vstar, uhat]
assert extract_arguments_and_coefficients(Iv) == ([vstar, uhat], [])
assert extract_terminals_with_domain(Iv) == ([vstar, uhat], [], [])
assert extract_coefficients(Iv) == []
assert extract_base_form_operators(Iv) == [Iv]

Expand Down
4 changes: 3 additions & 1 deletion ufl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
-AbstractDomain
-Mesh
-MeshSequence
-MeshView
* Sobolev spaces::
Expand Down Expand Up @@ -265,7 +266,7 @@
from ufl.core.external_operator import ExternalOperator
from ufl.core.interpolate import Interpolate, interpolate
from ufl.core.multiindex import Index, indices
from ufl.domain import AbstractDomain, Mesh, MeshView
from ufl.domain import AbstractDomain, Mesh, MeshSequence, MeshView
from ufl.finiteelement import AbstractFiniteElement
from ufl.form import BaseForm, Form, FormSum, ZeroBaseForm
from ufl.formoperators import (
Expand Down Expand Up @@ -484,6 +485,7 @@
"MaxFacetEdgeLength",
"Measure",
"Mesh",
"MeshSequence",
"MeshView",
"MinCellEdgeLength",
"MinFacetEdgeLength",
Expand Down
4 changes: 2 additions & 2 deletions ufl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def _get_action_form_arguments(left, right):
elif isinstance(right, CoefficientDerivative):
# Action differentiation pushes differentiation through
# right as a consequence of Leibniz formula.
from ufl.algorithms.analysis import extract_arguments_and_coefficients
from ufl.algorithms.analysis import extract_terminals_with_domain

right_args, right_coeffs = extract_arguments_and_coefficients(right)
right_args, right_coeffs, _ = extract_terminals_with_domain(right)
arguments = left_args + tuple(right_args)
coefficients += tuple(right_coeffs)
elif isinstance(right, (BaseCoefficient, Zero)):
Expand Down
47 changes: 35 additions & 12 deletions ufl/algorithms/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from ufl.core.base_form_operator import BaseFormOperator
from ufl.core.terminal import Terminal
from ufl.corealg.traversal import traverse_unique_terminals, unique_pre_traversal
from ufl.domain import Mesh
from ufl.form import BaseForm, Form
from ufl.geometry import GeometricQuantity
from ufl.utils.sorting import sorted_by_count, topological_sorting

# TODO: Some of these can possibly be optimised by implementing
Expand Down Expand Up @@ -198,19 +200,24 @@ def extract_base_form_operators(a):
return sorted_by_count(extract_type(a, BaseFormOperator))


def extract_arguments_and_coefficients(a):
"""Build two sorted lists of all arguments and coefficients in a.
def extract_terminals_with_domain(a):
"""Build three sorted lists of all arguments, coefficients, and geometric quantities in `a`.
This function is faster than extract_arguments + extract_coefficients
for large forms, and has more validation built in.
This function is faster than extracting each type of terminal
separately for large forms, and has more validation built in.
Args:
a: A BaseForm, Integral or Expr
Returns:
Tuples of extracted `Argument`s, `Coefficient`s, and `GeometricQuantity`s.
"""
# Extract lists of all BaseArgument and BaseCoefficient instances
base_coeff_and_args = extract_type(a, (BaseArgument, BaseCoefficient))
arguments = [f for f in base_coeff_and_args if isinstance(f, BaseArgument)]
coefficients = [f for f in base_coeff_and_args if isinstance(f, BaseCoefficient)]
# Extract lists of all BaseArgument, BaseCoefficient, and GeometricQuantity instances
terminals = extract_type(a, (BaseArgument, BaseCoefficient, GeometricQuantity))
arguments = [f for f in terminals if isinstance(f, BaseArgument)]
coefficients = [f for f in terminals if isinstance(f, BaseCoefficient)]
geometric_quantities = [f for f in terminals if isinstance(f, GeometricQuantity)]

# Build number,part: instance mappings, should be one to one
bfnp = dict((f, (f.number(), f.part())) for f in arguments)
Expand All @@ -226,20 +233,36 @@ def extract_arguments_and_coefficients(a):
if len(fcounts) != len(set(fcounts.values())):
raise ValueError(
"Found different coefficients with same counts.\n"
"The arguments found are:\n" + "\n".join(f" {c}" for c in coefficients)
"The Coefficients found are:\n" + "\n".join(f" {c}" for c in coefficients)
)

# Build count: instance mappings, should be one to one
gqcounts = {}
for gq in geometric_quantities:
if not isinstance(gq._domain, Mesh):
raise TypeError(f"{gq}._domain must be a Mesh: got {gq._domain}")
gqcounts[gq] = (type(gq).name, gq._domain._ufl_id)
if len(gqcounts) != len(set(gqcounts.values())):
raise ValueError(
"Found different geometric quantities with same (geometric_quantity_type, domain).\n"
"The GeometricQuantities found are:\n"
"\n".join(f" {gq}" for gq in geometric_quantities)
)

# Passed checks, so we can safely sort the instances by count
arguments = _sorted_by_number_and_part(arguments)
coefficients = sorted_by_count(coefficients)
geometric_quantities = list(
sorted(geometric_quantities, key=lambda gq: (type(gq).name, gq._domain._ufl_id))
)

return arguments, coefficients
return arguments, coefficients, geometric_quantities


def extract_elements(form):
"""Build sorted tuple of all elements used in form."""
args = chain(*extract_arguments_and_coefficients(form))
return tuple(f.ufl_element() for f in args)
arguments, coefficients, _ = extract_terminals_with_domain(form)
return tuple(f.ufl_element() for f in arguments + coefficients)


def extract_unique_elements(form):
Expand Down
131 changes: 118 additions & 13 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from collections import defaultdict
from math import pi

import numpy as np

from ufl.action import Action
from ufl.algorithms.analysis import extract_arguments
from ufl.algorithms.map_integrands import map_integrand_dags
Expand Down Expand Up @@ -55,7 +57,7 @@
BaseFormOperatorDerivative,
CoordinateDerivative,
)
from ufl.domain import extract_unique_domain
from ufl.domain import MeshSequence, extract_unique_domain
from ufl.form import Form, ZeroBaseForm
from ufl.operators import (
bessel_I,
Expand Down Expand Up @@ -84,6 +86,27 @@
# - ReferenceDivRuleset


def flatten_domain_element(domain, element):
"""Return the flattened (domain, element) pairs for mixed domain problems.
Args:
domain: `Mesh` or `MeshSequence`.
element: `FiniteElement`.
Returns:
Nested tuples of (domain, element) pairs; just ((domain, element),)
if domain is a `Mesh` (and not a `MeshSequence`).
"""
if not isinstance(domain, MeshSequence):
return ((domain, element),)
flattened = ()
assert len(domain) == len(element.sub_elements)
for d, e in zip(domain, element.sub_elements):
flattened += flatten_domain_element(d, e)
return flattened


class GenericDerivativeRuleset(MultiFunction):
"""A generic derivative."""

Expand Down Expand Up @@ -657,16 +680,58 @@ def reference_value(self, o):
"""Differentiate a reference_value."""
# grad(o) == grad(rv(f)) -> K_ji*rgrad(rv(f))_rj
f = o.ufl_operands[0]
if isinstance(f.ufl_element().pullback, PhysicalPullback):
# TODO: Do we need to be more careful for immersed things?
return ReferenceGrad(o)

if not f._ufl_is_terminal_:
raise ValueError("ReferenceValue can only wrap a terminal")
domain = extract_unique_domain(f)
K = JacobianInverse(domain)
Do = grad_to_reference_grad(o, K)
return Do
domain = extract_unique_domain(f, expand_mixed_mesh=False)
if isinstance(domain, MeshSequence):
element = f.ufl_function_space().ufl_element()
if element.num_sub_elements != len(domain):
raise RuntimeError(f"{element.num_sub_elements} != {len(domain)}")
# Get monolithic representation of rgrad(o); o might live in a mixed space.
rgrad = ReferenceGrad(o)
ref_dim = rgrad.ufl_shape[-1]
# Apply K_ji(d) to the corresponding components of rgrad, store them in a list,
# and put them back together at the end using as_tensor().
components = []
dofoffset = 0
for d, e in flatten_domain_element(domain, element):
esh = e.reference_value_shape
ndof = int(np.prod(esh))
assert ndof > 0
if isinstance(e.pullback, PhysicalPullback):
if ref_dim != self._var_shape[0]:
raise NotImplementedError("""
PhysicalPullback not handled for immersed domain :
reference dim ({ref_dim}) != physical dim (self._var_shape[0])""")
for idx in range(ndof):
for i in range(ref_dim):
components.append(rgrad[(dofoffset + idx,) + (i,)])
else:
K = JacobianInverse(d)
rdim, gdim = K.ufl_shape
if rdim != ref_dim:
raise RuntimeError(f"{rdim} != {ref_dim}")
if gdim != self._var_shape[0]:
raise RuntimeError(f"{gdim} != {self._var_shape[0]}")
# Note that rgrad[dofoffset + [0,ndof), [0,rdim)] are the components
# corresponding to (d, e).
# For each row, rgrad[dofoffset + idx, [0,rdim)], we apply
# K_ji(d)[[0,rdim), [0,gdim)].
for idx in range(ndof):
for i in range(gdim):
temp = Zero()
for j in range(rdim):
temp += rgrad[(dofoffset + idx,) + (j,)] * K[j, i]
components.append(temp)
dofoffset += ndof
return as_tensor(np.asarray(components).reshape(rgrad.ufl_shape[:-1] + self._var_shape))
else:
if isinstance(f.ufl_element().pullback, PhysicalPullback):
# TODO: Do we need to be more careful for immersed things?
return ReferenceGrad(o)
else:
K = JacobianInverse(domain)
return grad_to_reference_grad(o, K)

def reference_grad(self, o):
"""Differentiate a reference_grad."""
Expand All @@ -678,10 +743,50 @@ def reference_grad(self, o):
)
if not valid_operand:
raise ValueError("ReferenceGrad can only wrap a reference frame type!")
domain = extract_unique_domain(f)
K = JacobianInverse(domain)
Do = grad_to_reference_grad(o, K)
return Do
domain = extract_unique_domain(f, expand_mixed_mesh=False)
if isinstance(domain, MeshSequence):
if not f._ufl_is_in_reference_frame_:
raise RuntimeError("Expecting a reference frame type")
while not f._ufl_is_terminal_:
(f,) = f.ufl_operands
element = f.ufl_function_space().ufl_element()
if element.num_sub_elements != len(domain):
raise RuntimeError(f"{element.num_sub_elements} != {len(domain)}")
# Get monolithic representation of rgrad(o); o might live in a mixed space.
rgrad = ReferenceGrad(o)
ref_dim = rgrad.ufl_shape[-1]
# Apply K_ji(d) to the corresponding components of rgrad, store them in a list,
# and put them back together at the end using as_tensor().
components = []
dofoffset = 0
for d, e in flatten_domain_element(domain, element):
esh = e.reference_value_shape
ndof = int(np.prod(esh))
assert ndof > 0
K = JacobianInverse(d)
rdim, gdim = K.ufl_shape
if rdim != ref_dim:
raise RuntimeError(f"{rdim} != {ref_dim}")
if gdim != self._var_shape[0]:
raise RuntimeError(f"{gdim} != {self._var_shape[0]}")
# Note that rgrad[dofoffset + [0,ndof), [0,rdim), [0,rdim)] are the components
# corresponding to (d, e).
# For each row, rgrad[dofoffset + idx, [0,rdim), [0,rdim)], we apply
# K_ji(d)[[0,rdim), [0,gdim)].
for idx in range(ndof):
for midx in np.ndindex(rgrad.ufl_shape[1:-1]):
for i in range(gdim):
temp = Zero()
for j in range(rdim):
temp += rgrad[(dofoffset + idx,) + midx + (j,)] * K[j, i]
components.append(temp)
dofoffset += ndof
if rgrad.ufl_shape[0] != dofoffset:
raise RuntimeError(f"{rgrad.ufl_shape[0]} != {dofoffset}")
return as_tensor(np.asarray(components).reshape(rgrad.ufl_shape[:-1] + self._var_shape))
else:
K = JacobianInverse(domain)
return grad_to_reference_grad(o, K)

# --- Nesting of gradients

Expand Down
Loading

0 comments on commit 057ef1b

Please sign in to comment.