diff --git a/ufl/action.py b/ufl/action.py index 470e85016..b2f0a7ce3 100644 --- a/ufl/action.py +++ b/ufl/action.py @@ -9,15 +9,14 @@ from itertools import chain +from ufl import matrix # noqa 401 from ufl.algebra import Sum from ufl.argument import Argument, Coargument -from ufl.coefficient import BaseCoefficient, Coefficient, Cofunction +from ufl.coefficient import BaseCoefficient, Coefficient from ufl.constantvalue import Zero -from ufl.core.base_form_operator import BaseFormOperator from ufl.core.ufl_type import ufl_type from ufl.differentiation import CoefficientDerivative from ufl.form import BaseForm, Form, FormSum, ZeroBaseForm -from ufl.matrix import Matrix # --- The Action class represents the action of a numerical object that needs # to be computed at assembly time --- @@ -159,20 +158,13 @@ def __hash__(self): def _check_function_spaces(left, right): """Check if the function spaces of left and right match.""" + # Action differentiation pushes differentiation through + # right as a consequence of Leibniz formula. if isinstance(right, CoefficientDerivative): - # Action differentiation pushes differentiation through - # right as a consequence of Leibniz formula. right, *_ = right.ufl_operands + if isinstance(left, CoefficientDerivative): + left, *_ = left.ufl_operands - # `left` can also be a Coefficient in V (= V**), e.g. - # `action(Coefficient(V), Cofunction(V.dual()))`. - left_arg = left.arguments()[-1] if not isinstance(left, Coefficient) else left - if isinstance(right, (Form, Action, Matrix, ZeroBaseForm)): - if left_arg.ufl_function_space().dual() != right.arguments()[0].ufl_function_space(): - raise TypeError("Incompatible function spaces in Action") - elif isinstance(right, (Coefficient, Cofunction, Argument, BaseFormOperator)): - if left_arg.ufl_function_space() != right.ufl_function_space(): - raise TypeError("Incompatible function spaces in Action") # `Zero` doesn't contain any information about the function space. # -> Not a problem since Action will get simplified with a # `ZeroBaseForm` which won't take into account the arguments on @@ -180,8 +172,22 @@ def _check_function_spaces(left, right): # This occurs for: # `derivative(Action(A, B), u)` with B is an `Expr` such that dB/du == 0 # -> `derivative(B, u)` becomes `Zero` when expanding derivatives since B is an Expr. - elif not isinstance(right, Zero): - raise TypeError("Incompatible argument in Action: %s" % type(right)) + if isinstance(left, Zero) or isinstance(right, Zero): + return + + # `left` can also be a Coefficient in V (= V**), e.g. + # `action(Coefficient(V), Cofunction(V.dual()))`. + if isinstance(left, Coefficient): + V_left = left.ufl_function_space() + else: + V_left = left.arguments()[-1].ufl_function_space().dual() + if isinstance(right, Coefficient): + V_right = right.ufl_function_space() + else: + V_right = right.arguments()[0].ufl_function_space().dual() + + if V_left.dual() != V_right: + raise TypeError("Incompatible function spaces in Action") def _get_action_form_arguments(left, right):