From 6f755d40593c63ca2ffc0a1b52b0614d398b920e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 4 Jan 2025 17:32:16 -0600 Subject: [PATCH] Indexed: propagate simplications to _ufl_expr_reconstruct_ --- ufl/indexed.py | 34 +++++++++++----------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/ufl/indexed.py b/ufl/indexed.py index 422dc942b..45584094e 100644 --- a/ufl/indexed.py +++ b/ufl/indexed.py @@ -26,28 +26,12 @@ class Indexed(Operator): def __new__(cls, expression, multiindex): """Create a new Indexed.""" - # cyclic import - from ufl.tensors import ListTensor + try: + return expression[multiindex] + except ValueError: + pass - flattened = False indices = multiindex.indices() - - while ( - len(indices) > 0 - and isinstance(expression, ListTensor) - and isinstance(indices[0], FixedIndex) - ): - # Simplify indexed ListTensor objects - expression = expression[indices[0]] - indices = indices[1:] - flattened = True - - if isinstance(expression, Indexed): - # Simplify nested Indexed objects - indices = expression.ufl_operands[1].indices() + indices - expression = expression.ufl_operands[0] - flattened = True - if len(indices) == 0: return expression elif isinstance(expression, Zero): @@ -65,9 +49,6 @@ def __new__(cls, expression, multiindex): else: fi, fid = (), () return Zero(shape=(), free_indices=fi, index_dimensions=fid) - elif flattened: - # Simplified Indexed expression - return Indexed(expression, MultiIndex(indices)) else: return Operator.__new__(cls) @@ -139,3 +120,10 @@ def __getitem__(self, key): f"Attempting to index with {ufl_err_str(key)}, " f"but object is already indexed: {ufl_err_str(self)}" ) + + def _ufl_expr_reconstruct_(self, expression, multiindex): + """Return a new object of the same type with new operands.""" + try: + return expression[multiindex] + except ValueError: + return Operator._ufl_expr_reconstruct_(self, expression, multiindex)