Skip to content

Commit

Permalink
Simplify conditional
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 20, 2025
1 parent cb9052d commit 4c2c4db
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
22 changes: 22 additions & 0 deletions test/test_check_arities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
TestFunction,
TrialFunction,
adjoint,
as_tensor,
cofac,
conditional,
conj,
derivative,
ds,
Expand Down Expand Up @@ -84,3 +86,23 @@ def test_product_arity():
with pytest.raises(ArityMismatch):
L = inner(v, v) * dx
compute_form_data(L, complex_mode=False)


def test_zero_simplify_arity():
cell = tetrahedron
D = Mesh(FiniteElement("Lagrange", cell, 1, (3,), identity_pullback, H1))
V = FunctionSpace(D, FiniteElement("Lagrange", cell, 2, (), identity_pullback, H1))
v = TestFunction(V)
u = Coefficient(V)

zero = as_tensor([0, u])[0]
F = inner(u, v + zero) * dx
compute_form_data(F, complex_mode=False)

zero = conditional(u < 0, 0, 0)
F = inner(u, v + zero) * dx
compute_form_data(F, complex_mode=False)

zero = conditional(u < 0, 0, conditional(u == 0, 0, 0))
F = inner(u, v + zero) * dx
compute_form_data(F, complex_mode=False)
17 changes: 15 additions & 2 deletions ufl/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,23 @@ class Conditional(Operator):
In C++ these take the format `(condition ? true_value : false_value)`.
"""

__slots__ = ()
__slots__ = ("_initialised",)

def __new__(cls, condition, true_value, false_value):
"""Create a new Conditional."""
# Simplify
if true_value == false_value:
return true_value
# Construct a new instance to be initialised
self = Operator.__new__(cls)
self._initialised = False
return self

def __init__(self, condition, true_value, false_value):
"""Initialise."""
if self._initialised:
return
# Checks
if not isinstance(condition, Condition):
raise ValueError("Expecting condition as first argument.")
true_value = as_ufl(true_value)
Expand All @@ -290,8 +303,8 @@ def __init__(self, condition, true_value, false_value):
)
):
raise ValueError("Non-scalar == or != is not allowed.")

Operator.__init__(self, (condition, true_value, false_value))
self._initialised = True

def evaluate(self, x, mapping, component, index_values):
"""Evaluate."""
Expand Down

0 comments on commit 4c2c4db

Please sign in to comment.