Skip to content

Commit

Permalink
Indexed: avoid contraction of repeated indices (FEniCS#338)
Browse files Browse the repository at this point in the history
* Indexed: avoid contraction of repeated indices

* Untangle ComponentTensor early

* Indexed: _simplify_indexed

* ruff

* comments

* Remove untangling of ComponentTensor from apply_derivatives

* add tests

* Suggestions from review

* add another test

* Update test/test_simplify.py
  • Loading branch information
pbrubeck authored Jan 17, 2025
1 parent cb9052d commit 7d7c676
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 72 deletions.
77 changes: 76 additions & 1 deletion test/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
triangle,
)
from ufl.algorithms import compute_form_data
from ufl.core.multiindex import FixedIndex, MultiIndex
from ufl.constantvalue import Zero
from ufl.core.multiindex import FixedIndex, Index, MultiIndex, indices
from ufl.finiteelement import FiniteElement
from ufl.indexed import Indexed
from ufl.pullback import identity_pullback
from ufl.sobolevspace import H1
from ufl.tensors import ComponentTensor, ListTensor


def xtest_zero_times_argument(self):
Expand Down Expand Up @@ -193,3 +195,76 @@ def test_nested_indexed(self):
multiindex = MultiIndex((FixedIndex(0),))
assert Indexed(expr, multiindex) is expr[0]
assert Indexed(expr, multiindex) is comps[1]


def test_repeated_indexing(self):
# Test that an Indexed with repeated indices does not contract indices
shape = (2, 2)
element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1)
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
space = FunctionSpace(domain, element)
x = Coefficient(space)
C = as_tensor([x, x])

fi = FixedIndex(0)
i = Index()
ii = MultiIndex((fi, i, i))
expr = Indexed(C, ii)
assert i.count() in expr.ufl_free_indices
assert isinstance(expr, Indexed)
B, jj = expr.ufl_operands
assert B is x
assert tuple(jj) == tuple(ii[1:])


def test_untangle_indexed_component_tensor(self):
shape = (2, 2, 2, 2)
element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1)
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
space = FunctionSpace(domain, element)
C = Coefficient(space)

r = len(shape)
kk = indices(r)

# Untangle as_tensor(C[kk], kk) -> C
B = as_tensor(Indexed(C, MultiIndex(kk)), kk)
assert B is C

# Untangle as_tensor(C[kk], jj)[ii] -> C[ll]
jj = kk[2:]
A = as_tensor(Indexed(C, MultiIndex(kk)), jj)
assert A is not C

ii = kk
expr = Indexed(A, MultiIndex(ii))
assert isinstance(expr, Indexed)
B, ll = expr.ufl_operands
assert B is C

rep = dict(zip(jj, ii))
expected = tuple(rep.get(k, k) for k in kk)
assert tuple(ll) == expected


def test_simplify_indexed(self):
element = FiniteElement("Lagrange", triangle, 1, (3,), identity_pullback, H1)
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
space = FunctionSpace(domain, element)
u = Coefficient(space)
z = Zero(())
i = Index()
j = Index()
# ListTensor
lt = ListTensor(z, z, u[1])
assert Indexed(lt, MultiIndex((FixedIndex(2),))) == u[1]
# ListTensor -- nested
l0 = ListTensor(z, u[1], z)
l1 = ListTensor(z, z, u[2])
l2 = ListTensor(u[0], z, z)
ll = ListTensor(l0, l1, l2)
assert Indexed(ll, MultiIndex((FixedIndex(1), FixedIndex(2)))) == u[2]
assert Indexed(ll, MultiIndex((FixedIndex(2), i))) == l2[i]
# ComponentTensor + ListTensor
c = ComponentTensor(Indexed(ll, MultiIndex((i, j))), MultiIndex((j, i)))
assert Indexed(c, MultiIndex((FixedIndex(1), FixedIndex(2)))) == l2[1]
38 changes: 2 additions & 36 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from ufl.checks import is_cellwise_constant
from ufl.classes import (
Coefficient,
ComponentTensor,
Conj,
ConstantValue,
ExprList,
Expand Down Expand Up @@ -219,28 +218,12 @@ def variable(self, o, df, unused_l):

# --- Indexing and component handling

def indexed(self, o, Ap, ii): # TODO: (Partially) duplicated in nesting rules
def indexed(self, o, Ap, ii):
"""Differentiate an indexed."""
# Propagate zeros
if isinstance(Ap, Zero):
return self.independent_operator(o)

# Untangle as_tensor(C[kk], jj)[ii] -> C[ll] to simplify
# resulting expression
if isinstance(Ap, ComponentTensor):
B, jj = Ap.ufl_operands
if isinstance(B, Indexed):
C, kk = B.ufl_operands
kk = list(kk)
if all(j in kk for j in jj):
rep = dict(zip(jj, ii))
Cind = [rep.get(k, k) for k in kk]
expr = Indexed(C, MultiIndex(tuple(Cind)))
assert expr.ufl_free_indices == o.ufl_free_indices
assert expr.ufl_shape == o.ufl_shape
return expr

# Otherwise a more generic approach
r = len(Ap.ufl_shape) - len(ii)
if r:
kk = indices(r)
Expand Down Expand Up @@ -1450,29 +1433,12 @@ def base_form_coordinate_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
o_[3],
)

def indexed(self, o, Ap, ii): # TODO: (Partially) duplicated in generic rules
def indexed(self, o, Ap, ii):
"""Apply to an indexed."""
# Reuse if untouched
if Ap is o.ufl_operands[0]:
return o

# Untangle as_tensor(C[kk], jj)[ii] -> C[ll] to simplify
# resulting expression
if isinstance(Ap, ComponentTensor):
B, jj = Ap.ufl_operands
if isinstance(B, Indexed):
C, kk = B.ufl_operands

kk = list(kk)
if all(j in kk for j in jj):
rep = dict(zip(jj, ii))
Cind = [rep.get(k, k) for k in kk]
expr = Indexed(C, MultiIndex(tuple(Cind)))
assert expr.ufl_free_indices == o.ufl_free_indices
assert expr.ufl_shape == o.ufl_shape
return expr

# Otherwise a more generic approach
r = len(Ap.ufl_shape) - len(ii)
if r:
kk = indices(r)
Expand Down
2 changes: 1 addition & 1 deletion ufl/algorithms/expand_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def index_sum(self, x):

# TODO: For the list tensor purging algorithm, do something like:
# if index not in self._to_expand:
# return self.expr(x, *[self.visit(o) for o in x.ufl_operands])
# return self.expr(x, *map(self.visit, x.ufl_operands))

for value in range(x.dimension()):
self._index2value.push(index, value)
Expand Down
4 changes: 4 additions & 0 deletions ufl/core/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ def _ufl_err_str_(self):
"""Return a short string to represent this Expr in an error message."""
return f"<{self._ufl_class_.__name__} id={id(self)}>"

def _simplify_indexed(self, multiindex):
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
raise NotImplementedError(self.__class__._simplify_indexed)

# --- Special functions used for processing expressions ---

def __eq__(self, other):
Expand Down
2 changes: 1 addition & 1 deletion ufl/corealg/map_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def traversal(expression):
if cutoff_types[v._ufl_typecode_]:
r = handlers[v._ufl_typecode_](v)
else:
r = handlers[v._ufl_typecode_](v, *[vcache[u] for u in v.ufl_operands])
r = handlers[v._ufl_typecode_](v, *(vcache[u] for u in v.ufl_operands))

# Optionally check if r is in rcache, a memory optimization
# to be able to keep representation of result compact
Expand Down
8 changes: 3 additions & 5 deletions ufl/index_combination_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,9 @@ def merge_overlapping_indices(afi, afid, bfi, bfid):

# Find repeated indices, brute force version
for i0 in range(an):
for i1 in range(bn):
if afi[i0] == bfi[i1]:
repeated_indices.append(afi[i0])
repeated_index_dimensions.append(afid[i0])
break
if afi[i0] in bfi:
repeated_indices.append(afi[i0])
repeated_index_dimensions.append(afid[i0])

# Collect only non-repeated indices, brute force version
for i, d in sorted(zip(afi + bfi, afid + bfid)):
Expand Down
15 changes: 4 additions & 11 deletions ufl/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ def __new__(cls, expression, multiindex):
return Zero(shape=(), free_indices=fi, index_dimensions=fid)

try:
# Simplify indexed ListTensor
return expression[multiindex]
except ValueError:
# Simplify if possible
return expression._simplify_indexed(multiindex)
except NotImplementedError:
# Construct a new instance to be initialised
self = Operator.__new__(cls)
self._initialised = False
return self
Expand Down Expand Up @@ -124,11 +125,3 @@ def __getitem__(self, key):
f"Attempting to index with {ufl_err_str(key)}, "
f"but object is already indexed: {ufl_err_str(self)}"
)

def _ufl_expr_reconstruct_(self, expression, multiindex):
"""Reconstruct."""
try:
# Simplify indexed ListTensor
return expression[multiindex]
except ValueError:
return Operator._ufl_expr_reconstruct_(self, expression, multiindex)
65 changes: 48 additions & 17 deletions ufl/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class ListTensor(Operator):
"""Wraps a list of expressions into a tensor valued expression of one higher rank."""

__slots__ = ()
__slots__ = ("_initialised",)

def __new__(cls, *expressions):
"""Create a new ListTensor."""
Expand Down Expand Up @@ -88,10 +88,15 @@ def sub(e, *indices):
if all(i[0] == k for k, i in enumerate(indices)):
return sub(e0, 0, 0)

return Operator.__new__(cls)
# Construct a new instance to be initialised
self = Operator.__new__(cls)
self._initialised = False
return self

def __init__(self, *expressions):
"""Initialise."""
if self._initialised:
return
Operator.__init__(self, expressions)

# Checks
Expand All @@ -100,6 +105,7 @@ def __init__(self, *expressions):
raise ValueError(
"Can't combine subtensor expressions with different sets of free indices."
)
self._initialised = True

@property
def ufl_shape(self):
Expand All @@ -120,6 +126,14 @@ def evaluate(self, x, mapping, component, index_values, derivatives=()):
else:
return a.evaluate(x, mapping, component, index_values)

def _simplify_indexed(self, multiindex):
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
k = multiindex[0]
if isinstance(k, FixedIndex):
sub = self.ufl_operands[int(k)]
return Indexed(sub, MultiIndex(multiindex[1:]))
return Operator._simplify_indexed(self, multiindex)

def __getitem__(self, key):
"""Get an item."""
origkey = key
Expand All @@ -128,6 +142,8 @@ def __getitem__(self, key):
key = key.indices()
if not isinstance(key, tuple):
key = (key,)
if len(key) == 0:
return self
k = key[0]
if isinstance(k, (int, FixedIndex)):
sub = self.ufl_operands[int(k)]
Expand Down Expand Up @@ -160,11 +176,11 @@ def substring(expressions, indent):
class ComponentTensor(Operator):
"""Maps the free indices of a scalar valued expression to tensor axes."""

__slots__ = ("ufl_free_indices", "ufl_index_dimensions", "ufl_shape")
__slots__ = ("_initialised", "ufl_free_indices", "ufl_index_dimensions", "ufl_shape")

def __new__(cls, expression, indices):
"""Create a new ComponentTensor."""
# Simplify
# Zero-simplify
if isinstance(expression, Zero):
fi, fid, sh = remove_indices(
expression.ufl_free_indices,
Expand All @@ -173,11 +189,21 @@ def __new__(cls, expression, indices):
)
return Zero(sh, fi, fid)

# Construct
return Operator.__new__(cls)
# Special case for simplification as_tensor(A[ii], ii) -> A
if isinstance(expression, Indexed):
A, ii = expression.ufl_operands
if indices == ii:
return A

# Construct a new instance to be initialised
self = Operator.__new__(cls)
self._initialised = False
return self

def __init__(self, expression, indices):
"""Initialise."""
if self._initialised:
return
if not isinstance(expression, Expr):
raise ValueError("Expecting ufl expression.")
if expression.ufl_shape != ():
Expand All @@ -197,24 +223,29 @@ def __init__(self, expression, indices):
self.ufl_free_indices = fi
self.ufl_index_dimensions = fid
self.ufl_shape = sh

def _ufl_expr_reconstruct_(self, expressions, indices):
"""Reconstruct."""
# Special case for simplification as_tensor(A[ii], ii) -> A
if isinstance(expressions, Indexed):
A, ii = expressions.ufl_operands
if indices == ii:
return A
return Operator._ufl_expr_reconstruct_(self, expressions, indices)
self._initialised = True

def _simplify_indexed(self, multiindex):
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
# Untangle as_tensor(C[kk], jj)[ii] -> C[ll]
B, jj = self.ufl_operands
if isinstance(B, Indexed):
C, kk = B.ufl_operands
if all(j in kk for j in jj):
ii = tuple(multiindex)
rep = dict(zip(jj, ii))
Cind = tuple(rep.get(k, k) for k in kk)
return Indexed(C, MultiIndex(Cind))

return Operator._simplify_indexed(self, multiindex)

def indices(self):
"""Get indices."""
return self.ufl_operands[1]

def evaluate(self, x, mapping, component, index_values):
"""Evaluate."""
indices = self.ufl_operands[1]
a = self.ufl_operands[0]
a, indices = self.ufl_operands

if len(indices) != len(component):
raise ValueError("Expecting a component matching the indices tuple.")
Expand Down

0 comments on commit 7d7c676

Please sign in to comment.