Skip to content

Commit

Permalink
Indexed: propagate simplications to _ufl_expr_reconstruct_
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 4, 2025
1 parent e0c4e6a commit 6f755d4
Showing 1 changed file with 11 additions and 23 deletions.
34 changes: 11 additions & 23 deletions ufl/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,12 @@ class Indexed(Operator):

def __new__(cls, expression, multiindex):
"""Create a new Indexed."""
# cyclic import
from ufl.tensors import ListTensor
try:
return expression[multiindex]
except ValueError:
pass

flattened = False
indices = multiindex.indices()

while (
len(indices) > 0
and isinstance(expression, ListTensor)
and isinstance(indices[0], FixedIndex)
):
# Simplify indexed ListTensor objects
expression = expression[indices[0]]
indices = indices[1:]
flattened = True

if isinstance(expression, Indexed):
# Simplify nested Indexed objects
indices = expression.ufl_operands[1].indices() + indices
expression = expression.ufl_operands[0]
flattened = True

if len(indices) == 0:
return expression
elif isinstance(expression, Zero):
Expand All @@ -65,9 +49,6 @@ def __new__(cls, expression, multiindex):
else:
fi, fid = (), ()
return Zero(shape=(), free_indices=fi, index_dimensions=fid)
elif flattened:
# Simplified Indexed expression
return Indexed(expression, MultiIndex(indices))
else:
return Operator.__new__(cls)

Expand Down Expand Up @@ -139,3 +120,10 @@ def __getitem__(self, key):
f"Attempting to index with {ufl_err_str(key)}, "
f"but object is already indexed: {ufl_err_str(self)}"
)

def _ufl_expr_reconstruct_(self, expression, multiindex):
"""Return a new object of the same type with new operands."""
try:
return expression[multiindex]
except ValueError:
return Operator._ufl_expr_reconstruct_(self, expression, multiindex)

0 comments on commit 6f755d4

Please sign in to comment.