diff --git a/ufl/algorithms/apply_coefficient_split.py b/ufl/algorithms/apply_coefficient_split.py index 8148b1e65..694599cd8 100644 --- a/ufl/algorithms/apply_coefficient_split.py +++ b/ufl/algorithms/apply_coefficient_split.py @@ -3,6 +3,7 @@ This module contains classes and functions to split coefficients defined on mixed function spaces. """ +import functools from collections import defaultdict import numpy from ufl.algorithms.map_integrands import map_integrand_dags @@ -120,9 +121,22 @@ class FixedIndexRemover(MultiFunction): def __init__(self, fimap): MultiFunction.__init__(self) self.fimap = fimap + self._object_cache = {} expr = MultiFunction.reuse_if_untouched + @staticmethod + def _cached(f): + @functools.wraps(f) + def wrapper(self, *args, **kwargs): + o, = args + if o in self._object_cache: + return self._object_cache[o] + else: + return self._object_cache.setdefault(o, f(self, o)) + return wrapper + + @_cached def zero(self, o): free_indices = [] index_dimensions = [] @@ -137,14 +151,15 @@ def zero(self, o): index_dimensions.append(d) return Zero(shape=o.ufl_shape, free_indices=tuple(free_indices), index_dimensions=tuple(index_dimensions)) + @_cached def list_tensor(self, o): - rule = FixedIndexRemover(self.fimap) cc = [] for o1 in o.ufl_operands: - comp = map_expr_dag(rule, o1) + comp = map_expr_dag(self, o1) cc.append(comp) return ListTensor(*cc) + @_cached def multi_index(self, o): return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices())) @@ -153,18 +168,31 @@ class IndexRemover(MultiFunction): def __init__(self): MultiFunction.__init__(self) + self._object_cache = {} expr = MultiFunction.reuse_if_untouched + @staticmethod + def _cached(f): + @functools.wraps(f) + def wrapper(self, *args, **kwargs): + o, = args + if o in self._object_cache: + return self._object_cache[o] + else: + return self._object_cache.setdefault(o, f(self, o)) + return wrapper + + @_cached def _zero_simplify(self, o): operand, = o.ufl_operands - rule = IndexRemover() - operand = map_expr_dag(rule, operand) + operand = map_expr_dag(self, operand) if isinstance(operand, Zero): return Zero(shape=o.ufl_shape, free_indices=o.ufl_free_indices, index_dimensions=o.ufl_index_dimensions) else: return o._ufl_expr_reconstruct_(operand) + @_cached def indexed(self, o): o1, i1 = o.ufl_operands if isinstance(o1, ComponentTensor): @@ -172,19 +200,16 @@ def indexed(self, o): fimap = dict(zip(i2.indices(), i1.indices(), strict=True)) rule = FixedIndexRemover(fimap) v = map_expr_dag(rule, o2) - rule = IndexRemover() - return map_expr_dag(rule, v) + return map_expr_dag(self, v) elif isinstance(o1, ListTensor): if isinstance(i1[0], FixedIndex): o1 = o1.ufl_operands[i1[0]._value] - rule = IndexRemover() if len(i1) > 1: i1 = MultiIndex(i1[1:]) - return map_expr_dag(rule, Indexed(o1, i1)) + return map_expr_dag(self, Indexed(o1, i1)) else: - return map_expr_dag(rule, o1) - rule = IndexRemover() - o1 = map_expr_dag(rule, o1) + return map_expr_dag(self, o1) + o1 = map_expr_dag(self, o1) return Indexed(o1, i1) # Do something nicer