Skip to content

Commit

Permalink
Remove untangling of ComponentTensor from apply_derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 17, 2025
1 parent 958f11f commit 6ec8114
Showing 1 changed file with 2 additions and 35 deletions.
37 changes: 2 additions & 35 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,28 +219,12 @@ def variable(self, o, df, unused_l):

# --- Indexing and component handling

def indexed(self, o, Ap, ii): # TODO: (Partially) duplicated in nesting rules
def indexed(self, o, Ap, ii):
"""Differentiate an indexed."""
# Propagate zeros
if isinstance(Ap, Zero):
return self.independent_operator(o)

# Untangle as_tensor(C[kk], jj)[ii] -> C[ll] to simplify
# resulting expression
if isinstance(Ap, ComponentTensor):
B, jj = Ap.ufl_operands
if isinstance(B, Indexed):
C, kk = B.ufl_operands
kk = list(kk)
if all(j in kk for j in jj):
rep = dict(zip(jj, ii))
Cind = tuple(rep.get(k, k) for k in kk)
expr = Indexed(C, MultiIndex(Cind))
assert expr.ufl_free_indices == o.ufl_free_indices
assert expr.ufl_shape == o.ufl_shape
return expr

# Otherwise a more generic approach
r = len(Ap.ufl_shape) - len(ii)
if r:
kk = indices(r)
Expand Down Expand Up @@ -1450,29 +1434,12 @@ def base_form_coordinate_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
o_[3],
)

def indexed(self, o, Ap, ii): # TODO: (Partially) duplicated in generic rules
def indexed(self, o, Ap, ii):
"""Apply to an indexed."""
# Reuse if untouched
if Ap is o.ufl_operands[0]:
return o

# Untangle as_tensor(C[kk], jj)[ii] -> C[ll] to simplify
# resulting expression
if isinstance(Ap, ComponentTensor):
B, jj = Ap.ufl_operands
if isinstance(B, Indexed):
C, kk = B.ufl_operands

kk = list(kk)
if all(j in kk for j in jj):
rep = dict(zip(jj, ii))
Cind = tuple(rep.get(k, k) for k in kk)
expr = Indexed(C, MultiIndex(Cind))
assert expr.ufl_free_indices == o.ufl_free_indices
assert expr.ufl_shape == o.ufl_shape
return expr

# Otherwise a more generic approach
r = len(Ap.ufl_shape) - len(ii)
if r:
kk = indices(r)
Expand Down

0 comments on commit 6ec8114

Please sign in to comment.