Skip to content

Commit

Permalink
Untangle ComponentTensor early
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 16, 2025
1 parent 66d0601 commit d2e31a7
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions ufl/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class Indexed(Operator):

def __new__(cls, expression, multiindex):
"""Create a new Indexed."""
# Cyclic import
from ufl.tensors import ComponentTensor, ListTensor

if len(multiindex) == 0:
return expression
if isinstance(expression, Zero):
Expand All @@ -45,17 +48,27 @@ def __new__(cls, expression, multiindex):
fi, fid = (), ()
return Zero(shape=(), free_indices=fi, index_dimensions=fid)

try:
ii = tuple(multiindex)
if isinstance(expression, ListTensor) and isinstance(ii[0], FixedIndex):
# Simplify indexed ListTensor
# The multiindex needs to be split to avoid
# Expr.__getitem__ contracting repeated indices
ii = tuple(multiindex)
e0 = expression[MultiIndex(ii[:1])]
return e0 if len(ii) == 1 else Indexed(e0, MultiIndex(ii[1:]))
except ValueError:
self = Operator.__new__(cls)
self._initialised = False
return self
C = expression.ufl_operands[int(ii[0])]
return C if len(ii) == 1 else Indexed(C, MultiIndex(ii[1:]))

if isinstance(expression, ComponentTensor):
# Untangle as_tensor(C[kk], jj)[ii] -> C[ll]
B, jj = expression.ufl_operands
if isinstance(B, Indexed):
C, kk = B.ufl_operands

kk = list(kk)
if all(j in kk for j in jj):
rep = dict(zip(jj, ii))
Cind = tuple(rep.get(k, k) for k in kk)
return Indexed(C, MultiIndex(Cind))

self = Operator.__new__(cls)
self._initialised = False
return self

def __init__(self, expression, multiindex):
"""Initialise."""
Expand Down

0 comments on commit d2e31a7

Please sign in to comment.