diff --git a/test/test_check_arities.py b/test/test_check_arities.py index e2c32f5f5..5508c02b1 100755 --- a/test/test_check_arities.py +++ b/test/test_check_arities.py @@ -9,7 +9,9 @@ TestFunction, TrialFunction, adjoint, + as_tensor, cofac, + conditional, conj, derivative, ds, @@ -84,3 +86,41 @@ def test_product_arity(): with pytest.raises(ArityMismatch): L = inner(v, v) * dx compute_form_data(L, complex_mode=False) + + +def test_zero_simplify_arity(): + """ + Test that adding verious zero-like expressions to a form is simplified, + such that one can compute form data for the integral. + """ + cell = tetrahedron + D = Mesh(FiniteElement("Lagrange", cell, 1, (3,), identity_pullback, H1)) + V = FunctionSpace(D, FiniteElement("Lagrange", cell, 2, (), identity_pullback, H1)) + v = TestFunction(V) + u = Coefficient(V) + + nonzero = 1 + with pytest.raises(ArityMismatch): + F = inner(u, v + nonzero) * dx + compute_form_data(F) + z = Coefficient(V) + + # Add a Zero-component (rank-0) of a tensor to a rank-1 tensor + zero = as_tensor([0, z])[0] + F = inner(u, v + zero) * dx + fd = compute_form_data(F) + assert fd.num_coefficients == 1 + + # Add a conditional that should have been simplified to zero (rank-0) + # to a rank-1 tensor + zero = conditional(z < 0, 0, 0) + F = inner(u, v + zero) * dx + fd = compute_form_data(F) + assert fd.num_coefficients == 1 + + # Check that nested zero conditionals are simplifed to zero (rank-0) + # and can be added to a rank-1 tensor + zero = conditional(z < 0, 0, conditional(z == 0, 0, 0)) + F = inner(u, v + zero) * dx + fd = compute_form_data(F) + assert fd.num_coefficients == 1 diff --git a/ufl/algorithms/check_arities.py b/ufl/algorithms/check_arities.py index e1d9b366b..d38f53165 100644 --- a/ufl/algorithms/check_arities.py +++ b/ufl/algorithms/check_arities.py @@ -57,7 +57,8 @@ def sum(self, o, a, b): """Apply to sum.""" if a != b: raise ArityMismatch( - f"Adding expressions with non-matching form arguments {_afmt(a)} vs {_afmt(b)}." + f"Adding expressions with non-matching form arguments " + f"{tuple(map(_afmt, a))} vs {tuple(map(_afmt, b))}." ) return a @@ -86,7 +87,7 @@ def product(self, o, a, b): if len(c) != len(a) + len(b) or len(c) != len({x[0] for x in c}): raise ArityMismatch( "Multiplying expressions with overlapping form arguments " - f"{_afmt(a)} vs {_afmt(b)}." + f"{tuple(map(_afmt, a))} vs {tuple(map(_afmt, b))}." ) # It's fine for argument parts to overlap return c @@ -138,7 +139,7 @@ def variable(self, o, f, a): def conditional(self, o, c, a, b): """Apply to conditional.""" if c: - raise ArityMismatch(f"Condition cannot depend on form arguments ({_afmt(a)}).") + raise ArityMismatch("Condition cannot depend on form arguments.") if a and isinstance(o.ufl_operands[2], Zero): # Allow conditional(c, arg, 0) return a @@ -153,7 +154,7 @@ def conditional(self, o, c, a, b): # conditional(c, test, nonzeroconstant) raise ArityMismatch( "Conditional subexpressions with non-matching form arguments " - f"{_afmt(a)} vs {_afmt(b)}." + f"{tuple(map(_afmt, a))} vs {tuple(map(_afmt, b))}." ) def linear_indexed_type(self, o, a, i): diff --git a/ufl/conditional.py b/ufl/conditional.py index 117808c2b..5f802943a 100644 --- a/ufl/conditional.py +++ b/ufl/conditional.py @@ -264,10 +264,23 @@ class Conditional(Operator): In C++ these take the format `(condition ? true_value : false_value)`. """ - __slots__ = () + __slots__ = ("_initialised",) + + def __new__(cls, condition, true_value, false_value): + """Create a new Conditional.""" + # Simplify + if bool(true_value == false_value): + return true_value + # Construct a new instance to be initialised + self = Operator.__new__(cls) + self._initialised = False + return self def __init__(self, condition, true_value, false_value): """Initialise.""" + if self._initialised: + return + # Checks if not isinstance(condition, Condition): raise ValueError("Expecting condition as first argument.") true_value = as_ufl(true_value) @@ -290,8 +303,8 @@ def __init__(self, condition, true_value, false_value): ) ): raise ValueError("Non-scalar == or != is not allowed.") - Operator.__init__(self, (condition, true_value, false_value)) + self._initialised = True def evaluate(self, x, mapping, component, index_values): """Evaluate."""