diff --git a/test/test_simplify.py b/test/test_simplify.py index dcd3e06ba..80599dca6 100755 --- a/test/test_simplify.py +++ b/test/test_simplify.py @@ -32,11 +32,13 @@ triangle, ) from ufl.algorithms import compute_form_data -from ufl.core.multiindex import FixedIndex, MultiIndex +from ufl.constantvalue import Zero +from ufl.core.multiindex import FixedIndex, Index, MultiIndex, indices from ufl.finiteelement import FiniteElement from ufl.indexed import Indexed from ufl.pullback import identity_pullback from ufl.sobolevspace import H1 +from ufl.tensors import ComponentTensor, ListTensor def xtest_zero_times_argument(self): @@ -193,3 +195,76 @@ def test_nested_indexed(self): multiindex = MultiIndex((FixedIndex(0),)) assert Indexed(expr, multiindex) is expr[0] assert Indexed(expr, multiindex) is comps[1] + + +def test_repeated_indexing(self): + # Test that an Indexed with repeated indices does not contract indices + shape = (2, 2) + element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1) + domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1)) + space = FunctionSpace(domain, element) + x = Coefficient(space) + C = as_tensor([x, x]) + + fi = FixedIndex(0) + i = Index() + ii = MultiIndex((fi, i, i)) + expr = Indexed(C, ii) + assert i.count() in expr.ufl_free_indices + assert isinstance(expr, Indexed) + B, jj = expr.ufl_operands + assert B is x + assert tuple(jj) == tuple(ii[1:]) + + +def test_untangle_indexed_component_tensor(self): + shape = (2, 2, 2, 2) + element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1) + domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1)) + space = FunctionSpace(domain, element) + C = Coefficient(space) + + r = len(shape) + kk = indices(r) + + # Untangle as_tensor(C[kk], kk) -> C + B = as_tensor(Indexed(C, MultiIndex(kk)), kk) + assert B is C + + # Untangle as_tensor(C[kk], jj)[ii] -> C[ll] + jj = kk[2:] + A = as_tensor(Indexed(C, MultiIndex(kk)), jj) + assert A is not C + + ii = kk + expr = Indexed(A, MultiIndex(ii)) + assert isinstance(expr, Indexed) + B, ll = expr.ufl_operands + assert B is C + + rep = dict(zip(jj, ii)) + expected = tuple(rep.get(k, k) for k in kk) + assert tuple(ll) == expected + + +def test_simplify_indexed(self): + element = FiniteElement("Lagrange", triangle, 1, (3,), identity_pullback, H1) + domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1)) + space = FunctionSpace(domain, element) + u = Coefficient(space) + z = Zero(()) + i = Index() + j = Index() + # ListTensor + lt = ListTensor(z, z, u[1]) + assert Indexed(lt, MultiIndex((FixedIndex(2),))) == u[1] + # ListTensor -- nested + l0 = ListTensor(z, u[1], z) + l1 = ListTensor(z, z, u[2]) + l2 = ListTensor(u[0], z, z) + ll = ListTensor(l0, l1, l2) + assert Indexed(ll, MultiIndex((FixedIndex(1), FixedIndex(2)))) == u[2] + assert Indexed(ll, MultiIndex((FixedIndex(2), i))) == l2[i] + # ComponentTensor + ListTensor + c = ComponentTensor(Indexed(ll, MultiIndex((i, j))), MultiIndex((j, i))) + assert Indexed(c, MultiIndex((FixedIndex(1), FixedIndex(2)))) == l2[1] diff --git a/ufl/algorithms/apply_derivatives.py b/ufl/algorithms/apply_derivatives.py index da7b61da1..848c405f6 100644 --- a/ufl/algorithms/apply_derivatives.py +++ b/ufl/algorithms/apply_derivatives.py @@ -18,7 +18,6 @@ from ufl.checks import is_cellwise_constant from ufl.classes import ( Coefficient, - ComponentTensor, Conj, ConstantValue, ExprList, @@ -219,28 +218,12 @@ def variable(self, o, df, unused_l): # --- Indexing and component handling - def indexed(self, o, Ap, ii): # TODO: (Partially) duplicated in nesting rules + def indexed(self, o, Ap, ii): """Differentiate an indexed.""" # Propagate zeros if isinstance(Ap, Zero): return self.independent_operator(o) - # Untangle as_tensor(C[kk], jj)[ii] -> C[ll] to simplify - # resulting expression - if isinstance(Ap, ComponentTensor): - B, jj = Ap.ufl_operands - if isinstance(B, Indexed): - C, kk = B.ufl_operands - kk = list(kk) - if all(j in kk for j in jj): - rep = dict(zip(jj, ii)) - Cind = [rep.get(k, k) for k in kk] - expr = Indexed(C, MultiIndex(tuple(Cind))) - assert expr.ufl_free_indices == o.ufl_free_indices - assert expr.ufl_shape == o.ufl_shape - return expr - - # Otherwise a more generic approach r = len(Ap.ufl_shape) - len(ii) if r: kk = indices(r) @@ -1450,29 +1433,12 @@ def base_form_coordinate_derivative(self, o, f, dummy_w, dummy_v, dummy_cd): o_[3], ) - def indexed(self, o, Ap, ii): # TODO: (Partially) duplicated in generic rules + def indexed(self, o, Ap, ii): """Apply to an indexed.""" # Reuse if untouched if Ap is o.ufl_operands[0]: return o - # Untangle as_tensor(C[kk], jj)[ii] -> C[ll] to simplify - # resulting expression - if isinstance(Ap, ComponentTensor): - B, jj = Ap.ufl_operands - if isinstance(B, Indexed): - C, kk = B.ufl_operands - - kk = list(kk) - if all(j in kk for j in jj): - rep = dict(zip(jj, ii)) - Cind = [rep.get(k, k) for k in kk] - expr = Indexed(C, MultiIndex(tuple(Cind))) - assert expr.ufl_free_indices == o.ufl_free_indices - assert expr.ufl_shape == o.ufl_shape - return expr - - # Otherwise a more generic approach r = len(Ap.ufl_shape) - len(ii) if r: kk = indices(r) diff --git a/ufl/algorithms/expand_indices.py b/ufl/algorithms/expand_indices.py index 316998341..de994bd12 100644 --- a/ufl/algorithms/expand_indices.py +++ b/ufl/algorithms/expand_indices.py @@ -137,7 +137,7 @@ def index_sum(self, x): # TODO: For the list tensor purging algorithm, do something like: # if index not in self._to_expand: - # return self.expr(x, *[self.visit(o) for o in x.ufl_operands]) + # return self.expr(x, *map(self.visit, x.ufl_operands)) for value in range(x.dimension()): self._index2value.push(index, value) diff --git a/ufl/core/expr.py b/ufl/core/expr.py index 41b6e55a6..93149e2f6 100644 --- a/ufl/core/expr.py +++ b/ufl/core/expr.py @@ -342,6 +342,10 @@ def _ufl_err_str_(self): """Return a short string to represent this Expr in an error message.""" return f"<{self._ufl_class_.__name__} id={id(self)}>" + def _simplify_indexed(self, multiindex): + """Return a simplified Expr used in the constructor of Indexed(self, multiindex).""" + raise NotImplementedError(self.__class__._simplify_indexed) + # --- Special functions used for processing expressions --- def __eq__(self, other): diff --git a/ufl/corealg/map_dag.py b/ufl/corealg/map_dag.py index 9b9196f17..7cd6d11ee 100644 --- a/ufl/corealg/map_dag.py +++ b/ufl/corealg/map_dag.py @@ -100,7 +100,7 @@ def traversal(expression): if cutoff_types[v._ufl_typecode_]: r = handlers[v._ufl_typecode_](v) else: - r = handlers[v._ufl_typecode_](v, *[vcache[u] for u in v.ufl_operands]) + r = handlers[v._ufl_typecode_](v, *(vcache[u] for u in v.ufl_operands)) # Optionally check if r is in rcache, a memory optimization # to be able to keep representation of result compact diff --git a/ufl/index_combination_utils.py b/ufl/index_combination_utils.py index 8bd5087a8..50d4c5e17 100644 --- a/ufl/index_combination_utils.py +++ b/ufl/index_combination_utils.py @@ -215,11 +215,9 @@ def merge_overlapping_indices(afi, afid, bfi, bfid): # Find repeated indices, brute force version for i0 in range(an): - for i1 in range(bn): - if afi[i0] == bfi[i1]: - repeated_indices.append(afi[i0]) - repeated_index_dimensions.append(afid[i0]) - break + if afi[i0] in bfi: + repeated_indices.append(afi[i0]) + repeated_index_dimensions.append(afid[i0]) # Collect only non-repeated indices, brute force version for i, d in sorted(zip(afi + bfi, afid + bfid)): diff --git a/ufl/indexed.py b/ufl/indexed.py index 338033413..815f9d150 100644 --- a/ufl/indexed.py +++ b/ufl/indexed.py @@ -46,9 +46,10 @@ def __new__(cls, expression, multiindex): return Zero(shape=(), free_indices=fi, index_dimensions=fid) try: - # Simplify indexed ListTensor - return expression[multiindex] - except ValueError: + # Simplify if possible + return expression._simplify_indexed(multiindex) + except NotImplementedError: + # Construct a new instance to be initialised self = Operator.__new__(cls) self._initialised = False return self @@ -124,11 +125,3 @@ def __getitem__(self, key): f"Attempting to index with {ufl_err_str(key)}, " f"but object is already indexed: {ufl_err_str(self)}" ) - - def _ufl_expr_reconstruct_(self, expression, multiindex): - """Reconstruct.""" - try: - # Simplify indexed ListTensor - return expression[multiindex] - except ValueError: - return Operator._ufl_expr_reconstruct_(self, expression, multiindex) diff --git a/ufl/tensors.py b/ufl/tensors.py index b54b5eecf..4c00dd8e9 100644 --- a/ufl/tensors.py +++ b/ufl/tensors.py @@ -22,7 +22,7 @@ class ListTensor(Operator): """Wraps a list of expressions into a tensor valued expression of one higher rank.""" - __slots__ = () + __slots__ = ("_initialised",) def __new__(cls, *expressions): """Create a new ListTensor.""" @@ -88,10 +88,15 @@ def sub(e, *indices): if all(i[0] == k for k, i in enumerate(indices)): return sub(e0, 0, 0) - return Operator.__new__(cls) + # Construct a new instance to be initialised + self = Operator.__new__(cls) + self._initialised = False + return self def __init__(self, *expressions): """Initialise.""" + if self._initialised: + return Operator.__init__(self, expressions) # Checks @@ -100,6 +105,7 @@ def __init__(self, *expressions): raise ValueError( "Can't combine subtensor expressions with different sets of free indices." ) + self._initialised = True @property def ufl_shape(self): @@ -120,6 +126,14 @@ def evaluate(self, x, mapping, component, index_values, derivatives=()): else: return a.evaluate(x, mapping, component, index_values) + def _simplify_indexed(self, multiindex): + """Return a simplified Expr used in the constructor of Indexed(self, multiindex).""" + k = multiindex[0] + if isinstance(k, FixedIndex): + sub = self.ufl_operands[int(k)] + return Indexed(sub, MultiIndex(multiindex[1:])) + return Operator._simplify_indexed(self, multiindex) + def __getitem__(self, key): """Get an item.""" origkey = key @@ -128,6 +142,8 @@ def __getitem__(self, key): key = key.indices() if not isinstance(key, tuple): key = (key,) + if len(key) == 0: + return self k = key[0] if isinstance(k, (int, FixedIndex)): sub = self.ufl_operands[int(k)] @@ -160,11 +176,11 @@ def substring(expressions, indent): class ComponentTensor(Operator): """Maps the free indices of a scalar valued expression to tensor axes.""" - __slots__ = ("ufl_free_indices", "ufl_index_dimensions", "ufl_shape") + __slots__ = ("_initialised", "ufl_free_indices", "ufl_index_dimensions", "ufl_shape") def __new__(cls, expression, indices): """Create a new ComponentTensor.""" - # Simplify + # Zero-simplify if isinstance(expression, Zero): fi, fid, sh = remove_indices( expression.ufl_free_indices, @@ -173,11 +189,21 @@ def __new__(cls, expression, indices): ) return Zero(sh, fi, fid) - # Construct - return Operator.__new__(cls) + # Special case for simplification as_tensor(A[ii], ii) -> A + if isinstance(expression, Indexed): + A, ii = expression.ufl_operands + if indices == ii: + return A + + # Construct a new instance to be initialised + self = Operator.__new__(cls) + self._initialised = False + return self def __init__(self, expression, indices): """Initialise.""" + if self._initialised: + return if not isinstance(expression, Expr): raise ValueError("Expecting ufl expression.") if expression.ufl_shape != (): @@ -197,15 +223,21 @@ def __init__(self, expression, indices): self.ufl_free_indices = fi self.ufl_index_dimensions = fid self.ufl_shape = sh - - def _ufl_expr_reconstruct_(self, expressions, indices): - """Reconstruct.""" - # Special case for simplification as_tensor(A[ii], ii) -> A - if isinstance(expressions, Indexed): - A, ii = expressions.ufl_operands - if indices == ii: - return A - return Operator._ufl_expr_reconstruct_(self, expressions, indices) + self._initialised = True + + def _simplify_indexed(self, multiindex): + """Return a simplified Expr used in the constructor of Indexed(self, multiindex).""" + # Untangle as_tensor(C[kk], jj)[ii] -> C[ll] + B, jj = self.ufl_operands + if isinstance(B, Indexed): + C, kk = B.ufl_operands + if all(j in kk for j in jj): + ii = tuple(multiindex) + rep = dict(zip(jj, ii)) + Cind = tuple(rep.get(k, k) for k in kk) + return Indexed(C, MultiIndex(Cind)) + + return Operator._simplify_indexed(self, multiindex) def indices(self): """Get indices.""" @@ -213,8 +245,7 @@ def indices(self): def evaluate(self, x, mapping, component, index_values): """Evaluate.""" - indices = self.ufl_operands[1] - a = self.ufl_operands[0] + a, indices = self.ufl_operands if len(indices) != len(component): raise ValueError("Expecting a component matching the indices tuple.")