diff --git a/test/test_external_operator.py b/test/test_external_operator.py index 51f9a7e71..ab996cce1 100644 --- a/test/test_external_operator.py +++ b/test/test_external_operator.py @@ -8,10 +8,12 @@ from ufl import ( Action, Argument, + Coargument, Coefficient, Constant, Form, FunctionSpace, + Matrix, Mesh, TestFunction, TrialFunction, @@ -21,6 +23,7 @@ derivative, dx, inner, + replace, sin, triangle, ) @@ -266,7 +269,7 @@ def get_external_operators(form_base): elif isinstance(form_base, BaseForm): return form_base.base_form_operators() else: - raise ValueError("Expecting FormBase argument!") + raise ValueError("Expecting BaseForm argument!") def test_adjoint_action_jacobian(V1, V2, V3): @@ -339,16 +342,17 @@ def vstar_N(number): dFdu_adj = adjoint(dFdu) dFdm_adj = adjoint(dFdm) - assert dFdu_adj.arguments() == (u_hat(n_arg),) + v_F - assert dFdm_adj.arguments() == (m_hat(n_arg),) + v_F + V = v_F[0].ufl_function_space() + assert dFdu_adj.arguments() == (TestFunction(V1), TrialFunction(V)) + assert dFdm_adj.arguments() == (TestFunction(V2), TrialFunction(V)) # Action of the adjoint - q = Coefficient(v_F[0].ufl_function_space()) + q = Coefficient(V) action_dFdu_adj = action(dFdu_adj, q) action_dFdm_adj = action(dFdm_adj, q) - assert action_dFdu_adj.arguments() == (u_hat(n_arg),) - assert action_dFdm_adj.arguments() == (m_hat(n_arg),) + assert action_dFdu_adj.arguments() == (TestFunction(V1),) + assert action_dFdm_adj.arguments() == (TestFunction(V2),) def test_multiple_external_operators(V1, V2): @@ -486,3 +490,29 @@ def test_multiple_external_operators(V1, V2): dFdu = expand_derivatives(derivative(F, u)) assert dFdu == dFdu_partial + Action(dFdN1_partial, dN1du) + Action(dFdN5_partial, dN5du) + + +def test_replace(V1): + u = Coefficient(V1, count=0) + N = ExternalOperator(u, function_space=V1) + + # dN(u; uhat, v*) + dN = expand_derivatives(derivative(N, u)) + vstar, uhat = dN.arguments() + assert isinstance(vstar, Coargument) + + # Replace v* by a Form + v = TestFunction(V1) + F = inner(u, v) * dx + G = replace(dN, {vstar: F}) + + dN_replaced = dN._ufl_expr_reconstruct_(u, argument_slots=(F, uhat)) + assert G == dN_replaced + + # Replace v* by an Action + M = Matrix(V1, V1) + A = Action(M, u) + G = replace(dN, {vstar: A}) + + dN_replaced = dN._ufl_expr_reconstruct_(u, argument_slots=(A, uhat)) + assert G == dN_replaced diff --git a/ufl/action.py b/ufl/action.py index 3bb9fd533..78489ff6b 100644 --- a/ufl/action.py +++ b/ufl/action.py @@ -122,9 +122,7 @@ def _analyze_domains(self): from ufl.domain import join_domains # Collect domains - self._domains = join_domains( - chain.from_iterable(e.ufl_domains() for e in self.ufl_operands) - ) + self._domains = join_domains(chain.from_iterable(e.ufl_domain() for e in self.ufl_operands)) def equals(self, other): """Check if two Actions are equal.""" diff --git a/ufl/adjoint.py b/ufl/adjoint.py index 92447b985..7c1d5c63f 100644 --- a/ufl/adjoint.py +++ b/ufl/adjoint.py @@ -85,7 +85,11 @@ def form(self): def _analyze_form_arguments(self): """The arguments of adjoint are the reverse of the form arguments.""" - self._arguments = self._form.arguments()[::-1] + reversed_args = self._form.arguments()[::-1] + # Canonical numbering for arguments that is consistent with other BaseForm objects. + self._arguments = tuple( + type(arg)(arg.ufl_function_space(), number=i) for i, arg in enumerate(reversed_args) + ) self._coefficients = self._form.coefficients() def _analyze_domains(self): diff --git a/ufl/algorithms/replace.py b/ufl/algorithms/replace.py index 5f3616949..fbe9cecd0 100644 --- a/ufl/algorithms/replace.py +++ b/ufl/algorithms/replace.py @@ -10,7 +10,7 @@ from ufl.algorithms.analysis import has_exact_type from ufl.algorithms.map_integrands import map_integrand_dags -from ufl.classes import CoefficientDerivative, Form +from ufl.classes import BaseForm, CoefficientDerivative from ufl.constantvalue import as_ufl from ufl.core.external_operator import ExternalOperator from ufl.core.interpolate import Interpolate @@ -28,7 +28,7 @@ def __init__(self, mapping): # One can replace Coarguments by 1-Forms def get_shape(x): """Get the shape of an object.""" - if isinstance(x, Form): + if isinstance(x, BaseForm): return x.arguments()[0].ufl_shape return x.ufl_shape diff --git a/ufl/core/base_form_operator.py b/ufl/core/base_form_operator.py index 90bc23dd0..aae943f75 100644 --- a/ufl/core/base_form_operator.py +++ b/ufl/core/base_form_operator.py @@ -133,14 +133,21 @@ def count(self): @property def ufl_shape(self): """Return the UFL shape of the coefficient.produced by the operator.""" - return self.arguments()[0]._ufl_shape + arg, *_ = self.argument_slots() + if isinstance(arg, BaseForm): + arg, *_ = arg.arguments() + return arg._ufl_shape def ufl_function_space(self): """Return the function space associated to the operator. I.e. return the dual of the base form operator's Coargument. """ - return self.arguments()[0]._ufl_function_space.dual() + arg, *_ = self.argument_slots() + if isinstance(arg, BaseForm): + arg, *_ = arg.arguments() + return arg._ufl_function_space + return arg._ufl_function_space.dual() def _ufl_expr_reconstruct_( self, *operands, function_space=None, derivatives=None, argument_slots=None diff --git a/ufl/core/interpolate.py b/ufl/core/interpolate.py index 57e1506fd..3b5e3ba4d 100644 --- a/ufl/core/interpolate.py +++ b/ufl/core/interpolate.py @@ -8,13 +8,14 @@ # # Modified by Nacime Bouziani, 2021-2022 +from ufl.action import Action from ufl.argument import Argument, Coargument from ufl.coefficient import Cofunction from ufl.constantvalue import as_ufl from ufl.core.base_form_operator import BaseFormOperator from ufl.core.ufl_type import ufl_type from ufl.duals import is_dual -from ufl.form import Form +from ufl.form import BaseForm, Form from ufl.functionspace import AbstractFunctionSpace @@ -35,7 +36,7 @@ def __init__(self, expr, v): defined on the dual of the FunctionSpace to interpolate into. """ # This check could be more rigorous. - dual_args = (Coargument, Cofunction, Form) + dual_args = (Coargument, Cofunction, Form, Action, BaseFormOperator) if isinstance(v, AbstractFunctionSpace): if is_dual(v): @@ -53,8 +54,11 @@ def __init__(self, expr, v): # Reversed order convention argument_slots = (v, expr) # Get the primal space (V** = V) - vv = v if not isinstance(v, Form) else v.arguments()[0] - function_space = vv.ufl_function_space().dual() + if isinstance(v, BaseForm): + arg, *_ = v.arguments() + function_space = arg.ufl_function_space() + else: + function_space = v.ufl_function_space().dual() # Set the operand as `expr` for DAG traversal purpose. operand = expr BaseFormOperator.__init__(