Skip to content

Commit

Permalink
Relax assumption on BaseFormOperator's dual argument slot (FEniCS#283)
Browse files Browse the repository at this point in the history
* Update type check to get shape in replace

* Add test

* Fix ruff

* Relax base form operators type check

* Update BFO methods relying on the dual space argument slot

* Fix Interpolate's function space

* Fix typo: function_space -> ufl_function_space

* Update adjoint numbering

* Fix ruff

* Fix rugg

---------

Co-authored-by: David A. Ham <[email protected]>
  • Loading branch information
nbouziani and dham authored Jul 17, 2024
1 parent e69de17 commit b507f2f
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 18 deletions.
42 changes: 36 additions & 6 deletions test/test_external_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from ufl import (
Action,
Argument,
Coargument,
Coefficient,
Constant,
Form,
FunctionSpace,
Matrix,
Mesh,
TestFunction,
TrialFunction,
Expand All @@ -21,6 +23,7 @@
derivative,
dx,
inner,
replace,
sin,
triangle,
)
Expand Down Expand Up @@ -266,7 +269,7 @@ def get_external_operators(form_base):
elif isinstance(form_base, BaseForm):
return form_base.base_form_operators()
else:
raise ValueError("Expecting FormBase argument!")
raise ValueError("Expecting BaseForm argument!")


def test_adjoint_action_jacobian(V1, V2, V3):
Expand Down Expand Up @@ -339,16 +342,17 @@ def vstar_N(number):
dFdu_adj = adjoint(dFdu)
dFdm_adj = adjoint(dFdm)

assert dFdu_adj.arguments() == (u_hat(n_arg),) + v_F
assert dFdm_adj.arguments() == (m_hat(n_arg),) + v_F
V = v_F[0].ufl_function_space()
assert dFdu_adj.arguments() == (TestFunction(V1), TrialFunction(V))
assert dFdm_adj.arguments() == (TestFunction(V2), TrialFunction(V))

# Action of the adjoint
q = Coefficient(v_F[0].ufl_function_space())
q = Coefficient(V)
action_dFdu_adj = action(dFdu_adj, q)
action_dFdm_adj = action(dFdm_adj, q)

assert action_dFdu_adj.arguments() == (u_hat(n_arg),)
assert action_dFdm_adj.arguments() == (m_hat(n_arg),)
assert action_dFdu_adj.arguments() == (TestFunction(V1),)
assert action_dFdm_adj.arguments() == (TestFunction(V2),)


def test_multiple_external_operators(V1, V2):
Expand Down Expand Up @@ -486,3 +490,29 @@ def test_multiple_external_operators(V1, V2):

dFdu = expand_derivatives(derivative(F, u))
assert dFdu == dFdu_partial + Action(dFdN1_partial, dN1du) + Action(dFdN5_partial, dN5du)


def test_replace(V1):
u = Coefficient(V1, count=0)
N = ExternalOperator(u, function_space=V1)

# dN(u; uhat, v*)
dN = expand_derivatives(derivative(N, u))
vstar, uhat = dN.arguments()
assert isinstance(vstar, Coargument)

# Replace v* by a Form
v = TestFunction(V1)
F = inner(u, v) * dx
G = replace(dN, {vstar: F})

dN_replaced = dN._ufl_expr_reconstruct_(u, argument_slots=(F, uhat))
assert G == dN_replaced

# Replace v* by an Action
M = Matrix(V1, V1)
A = Action(M, u)
G = replace(dN, {vstar: A})

dN_replaced = dN._ufl_expr_reconstruct_(u, argument_slots=(A, uhat))
assert G == dN_replaced
4 changes: 1 addition & 3 deletions ufl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ def _analyze_domains(self):
from ufl.domain import join_domains

# Collect domains
self._domains = join_domains(
chain.from_iterable(e.ufl_domains() for e in self.ufl_operands)
)
self._domains = join_domains(chain.from_iterable(e.ufl_domain() for e in self.ufl_operands))

def equals(self, other):
"""Check if two Actions are equal."""
Expand Down
6 changes: 5 additions & 1 deletion ufl/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ def form(self):

def _analyze_form_arguments(self):
"""The arguments of adjoint are the reverse of the form arguments."""
self._arguments = self._form.arguments()[::-1]
reversed_args = self._form.arguments()[::-1]
# Canonical numbering for arguments that is consistent with other BaseForm objects.
self._arguments = tuple(
type(arg)(arg.ufl_function_space(), number=i) for i, arg in enumerate(reversed_args)
)
self._coefficients = self._form.coefficients()

def _analyze_domains(self):
Expand Down
4 changes: 2 additions & 2 deletions ufl/algorithms/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ufl.algorithms.analysis import has_exact_type
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.classes import CoefficientDerivative, Form
from ufl.classes import BaseForm, CoefficientDerivative
from ufl.constantvalue import as_ufl
from ufl.core.external_operator import ExternalOperator
from ufl.core.interpolate import Interpolate
Expand All @@ -28,7 +28,7 @@ def __init__(self, mapping):
# One can replace Coarguments by 1-Forms
def get_shape(x):
"""Get the shape of an object."""
if isinstance(x, Form):
if isinstance(x, BaseForm):
return x.arguments()[0].ufl_shape
return x.ufl_shape

Expand Down
11 changes: 9 additions & 2 deletions ufl/core/base_form_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,21 @@ def count(self):
@property
def ufl_shape(self):
"""Return the UFL shape of the coefficient.produced by the operator."""
return self.arguments()[0]._ufl_shape
arg, *_ = self.argument_slots()
if isinstance(arg, BaseForm):
arg, *_ = arg.arguments()
return arg._ufl_shape

def ufl_function_space(self):
"""Return the function space associated to the operator.
I.e. return the dual of the base form operator's Coargument.
"""
return self.arguments()[0]._ufl_function_space.dual()
arg, *_ = self.argument_slots()
if isinstance(arg, BaseForm):
arg, *_ = arg.arguments()
return arg._ufl_function_space
return arg._ufl_function_space.dual()

def _ufl_expr_reconstruct_(
self, *operands, function_space=None, derivatives=None, argument_slots=None
Expand Down
12 changes: 8 additions & 4 deletions ufl/core/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
#
# Modified by Nacime Bouziani, 2021-2022

from ufl.action import Action
from ufl.argument import Argument, Coargument
from ufl.coefficient import Cofunction
from ufl.constantvalue import as_ufl
from ufl.core.base_form_operator import BaseFormOperator
from ufl.core.ufl_type import ufl_type
from ufl.duals import is_dual
from ufl.form import Form
from ufl.form import BaseForm, Form
from ufl.functionspace import AbstractFunctionSpace


Expand All @@ -35,7 +36,7 @@ def __init__(self, expr, v):
defined on the dual of the FunctionSpace to interpolate into.
"""
# This check could be more rigorous.
dual_args = (Coargument, Cofunction, Form)
dual_args = (Coargument, Cofunction, Form, Action, BaseFormOperator)

if isinstance(v, AbstractFunctionSpace):
if is_dual(v):
Expand All @@ -53,8 +54,11 @@ def __init__(self, expr, v):
# Reversed order convention
argument_slots = (v, expr)
# Get the primal space (V** = V)
vv = v if not isinstance(v, Form) else v.arguments()[0]
function_space = vv.ufl_function_space().dual()
if isinstance(v, BaseForm):
arg, *_ = v.arguments()
function_space = arg.ufl_function_space()
else:
function_space = v.ufl_function_space().dual()
# Set the operand as `expr` for DAG traversal purpose.
operand = expr
BaseFormOperator.__init__(
Expand Down

0 comments on commit b507f2f

Please sign in to comment.