diff --git a/test/test_change_to_local.py b/test/test_change_to_local.py index 973195bc7..807bf1645 100755 --- a/test/test_change_to_local.py +++ b/test/test_change_to_local.py @@ -12,8 +12,8 @@ def test_change_to_reference_grad(): cell = triangle domain = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1)) - U = FunctionSpace(domain, FiniteElement("Lagrange", cell, 1, (), identity_pullback, H1)) - V = FunctionSpace(domain, FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1)) + U = FunctionSpace(domain, FiniteElement("Lagrange", cell, 3, (), identity_pullback, H1)) + V = FunctionSpace(domain, FiniteElement("Lagrange", cell, 3, (2,), identity_pullback, H1)) u = Coefficient(U) v = Coefficient(V) Jinv = JacobianInverse(domain) diff --git a/ufl/checks.py b/ufl/checks.py index 09ae4453b..31dfe18b5 100644 --- a/ufl/checks.py +++ b/ufl/checks.py @@ -11,8 +11,9 @@ from ufl.core.expr import Expr from ufl.core.terminal import FormArgument from ufl.corealg.traversal import traverse_unique_terminals -from ufl.geometry import GeometricQuantity +from ufl.geometry import GeometricQuantity, SpatialCoordinate from ufl.sobolevspace import H1 +from ufl.domain import extract_unique_domain def is_python_scalar(expression): @@ -34,7 +35,20 @@ def is_true_ufl_scalar(expression): def is_cellwise_constant(expr): """Return whether expression is constant over a single cell.""" - # TODO: Implement more accurately considering e.g. derivatives? + from ufl.coefficient import Coefficient + from ufl.differentiation import ReferenceGrad + + if isinstance(expr, ReferenceGrad): + (expr,) = expr.ufl_operands + if is_cellwise_constant(expr): + return True + elif isinstance(expr, SpatialCoordinate): + domain = extract_unique_domain(expr) + return domain.is_piecewise_linear_simplex_domain() + elif isinstance(expr, Coefficient): + element = expr.ufl_element() + return element.embedded_superdegree <= 1 + return all(e.is_cellwise_constant() for e in traverse_unique_terminals(expr)) diff --git a/ufl/core/expr.py b/ufl/core/expr.py index 41b6e55a6..368c6cba3 100644 --- a/ufl/core/expr.py +++ b/ufl/core/expr.py @@ -266,7 +266,6 @@ def ufl_domain(self): DeprecationWarning, ) from ufl.domain import extract_unique_domain - return extract_unique_domain(self) # --- Functions for float evaluation ---