Skip to content

Commit

Permalink
Simplify Grad(CellwiseConstant)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 22, 2025
1 parent 768f403 commit 83609f8
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 27 deletions.
17 changes: 17 additions & 0 deletions test/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,20 @@ def test_remove_component_tensors(domain):
fd = compute_form_data(form)

assert "ComponentTensor" not in repr(fd.preprocessed_form)


def test_grad_cellwise_constant(domain):
element = FiniteElement("Lagrange", triangle, 3, (), identity_pullback, H1)
space = FunctionSpace(domain, element)
u = Coefficient(space)

# Applying four derivatives to a cubic should simplify to zero
f = div(grad(div(grad(u))))
form = f * dx

fd = compute_form_data(
form,
do_apply_function_pullbacks=True,
)
assert fd.preprocessed_form.empty()
assert fd.num_coefficients == 0
17 changes: 13 additions & 4 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ufl.action import Action
from ufl.algorithms.analysis import extract_arguments
from ufl.algorithms.estimate_degrees import SumDegreeEstimator
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms.replace_derivative_nodes import replace_derivative_nodes
from ufl.argument import BaseArgument
Expand Down Expand Up @@ -562,6 +563,14 @@ def __init__(self, geometric_dimension):
"""Initialise."""
GenericDerivativeRuleset.__init__(self, var_shape=(geometric_dimension,))
self._Id = Identity(geometric_dimension)
self.degree_estimator = SumDegreeEstimator(1, {})

def is_cellwise_constant(self, o):
"""More precise checks for cellwise constants."""
if is_cellwise_constant(o):
return True
degree = map_expr_dag(self.degree_estimator, o)
return degree == 0

# --- Specialized rules for geometric quantities

Expand All @@ -572,7 +581,7 @@ def geometric_quantity(self, o):
otherwise transform derivatives to reference derivatives.
Override for specific types if other behaviour is needed.
"""
if is_cellwise_constant(o):
if self.is_cellwise_constant(o):
return self.independent_terminal(o)
else:
domain = extract_unique_domain(o)
Expand All @@ -583,7 +592,7 @@ def geometric_quantity(self, o):
def jacobian_inverse(self, o):
"""Differentiate a jacobian_inverse."""
# grad(K) == K_ji rgrad(K)_rj
if is_cellwise_constant(o):
if self.is_cellwise_constant(o):
return self.independent_terminal(o)
if not o._ufl_is_terminal_:
raise ValueError("ReferenceValue can only wrap a terminal")
Expand Down Expand Up @@ -653,9 +662,10 @@ def reference_value(self, o):

def reference_grad(self, o):
"""Differentiate a reference_grad."""
if self.is_cellwise_constant(o):
return self.independent_terminal(o)
# grad(o) == grad(rgrad(rv(f))) -> K_ji*rgrad(rgrad(rv(f)))_rj
f = o.ufl_operands[0]

valid_operand = f._ufl_is_in_reference_frame_ or isinstance(
f, (JacobianInverse, SpatialCoordinate, Jacobian, JacobianDeterminant)
)
Expand All @@ -676,7 +686,6 @@ def grad(self, o):
# Check that o is a "differential terminal"
if not isinstance(o.ufl_operands[0], (Grad, Terminal)):
raise ValueError("Expecting only grads applied to a terminal.")

return Grad(o)

def _grad(self, o):
Expand Down
37 changes: 14 additions & 23 deletions ufl/algorithms/remove_component_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from ufl.algorithms.estimate_degrees import SumDegreeEstimator
from ufl.classes import (
ComponentTensor,
Form,
Index,
MultiIndex,
Zero,
)
from ufl.classes import ComponentTensor, Form, Index, MultiIndex, Zero
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.multifunction import MultiFunction, memoized_handler

Expand All @@ -42,13 +35,10 @@ def zero(self, o):
free_indices = []
index_dimensions = []
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions):
if Index(i) in self.fimap:
ind_j = self.fimap[Index(i)]
if isinstance(ind_j, Index):
free_indices.append(ind_j.count())
index_dimensions.append(d)
else:
free_indices.append(i)
k = Index(i)
j = self.fimap.get(k, k)
if isinstance(j, Index):
free_indices.append(j.count())
index_dimensions.append(d)
return Zero(
shape=o.ufl_shape,
Expand All @@ -69,26 +59,24 @@ def __init__(self):
"""Initialise."""
MultiFunction.__init__(self)
self._object_cache = {}
self.degree_estimator = SumDegreeEstimator(1, {})

expr = MultiFunction.reuse_if_untouched

@memoized_handler
def reference_grad(self, o):
"""Simplify ReferenceGrad(Constant)."""
def _unary_operator(self, o):
"""Simplify UnaryOperator(Zero)."""
(operand,) = o.ufl_operands
operand = map_expr_dag(self, operand)
degree = map_expr_dag(self.degree_estimator, operand)
if degree == 0:
f = map_expr_dag(self, operand)
if isinstance(f, Zero):
return Zero(
shape=o.ufl_shape,
free_indices=o.ufl_free_indices,
index_dimensions=o.ufl_index_dimensions,
)
if operand is o.ufl_operands[0]:
if f is operand:
# Reuse if untouched
return o
return o._ufl_expr_reconstruct_(operand)
return o._ufl_expr_reconstruct_(f)

@memoized_handler
def indexed(self, o):
Expand All @@ -111,6 +99,9 @@ def indexed(self, o):
return o
return o._ufl_expr_reconstruct_(expr, i1)

reference_grad = _unary_operator
reference_value = _unary_operator


def remove_component_tensors(o):
"""Remove component tensors."""
Expand Down

0 comments on commit 83609f8

Please sign in to comment.