Skip to content

Commit

Permalink
optimise
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Jul 22, 2024
1 parent d574909 commit eedb48d
Showing 1 changed file with 36 additions and 11 deletions.
47 changes: 36 additions & 11 deletions ufl/algorithms/apply_coefficient_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This module contains classes and functions to split coefficients defined on mixed function spaces.
"""

import functools
from collections import defaultdict
import numpy
from ufl.algorithms.map_integrands import map_integrand_dags
Expand Down Expand Up @@ -120,9 +121,22 @@ class FixedIndexRemover(MultiFunction):
def __init__(self, fimap):
MultiFunction.__init__(self)
self.fimap = fimap
self._object_cache = {}

expr = MultiFunction.reuse_if_untouched

@staticmethod
def _cached(f):
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
o, = args
if o in self._object_cache:
return self._object_cache[o]
else:
return self._object_cache.setdefault(o, f(self, o))
return wrapper

@_cached
def zero(self, o):
free_indices = []
index_dimensions = []
Expand All @@ -137,14 +151,15 @@ def zero(self, o):
index_dimensions.append(d)
return Zero(shape=o.ufl_shape, free_indices=tuple(free_indices), index_dimensions=tuple(index_dimensions))

@_cached
def list_tensor(self, o):
rule = FixedIndexRemover(self.fimap)
cc = []
for o1 in o.ufl_operands:
comp = map_expr_dag(rule, o1)
comp = map_expr_dag(self, o1)
cc.append(comp)
return ListTensor(*cc)

@_cached
def multi_index(self, o):
return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices()))

Expand All @@ -153,38 +168,48 @@ class IndexRemover(MultiFunction):

def __init__(self):
MultiFunction.__init__(self)
self._object_cache = {}

expr = MultiFunction.reuse_if_untouched

@staticmethod
def _cached(f):
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
o, = args
if o in self._object_cache:
return self._object_cache[o]
else:
return self._object_cache.setdefault(o, f(self, o))
return wrapper

@_cached
def _zero_simplify(self, o):
operand, = o.ufl_operands
rule = IndexRemover()
operand = map_expr_dag(rule, operand)
operand = map_expr_dag(self, operand)
if isinstance(operand, Zero):
return Zero(shape=o.ufl_shape, free_indices=o.ufl_free_indices, index_dimensions=o.ufl_index_dimensions)
else:
return o._ufl_expr_reconstruct_(operand)

@_cached
def indexed(self, o):
o1, i1 = o.ufl_operands
if isinstance(o1, ComponentTensor):
o2, i2 = o1.ufl_operands
fimap = dict(zip(i2.indices(), i1.indices(), strict=True))
rule = FixedIndexRemover(fimap)
v = map_expr_dag(rule, o2)
rule = IndexRemover()
return map_expr_dag(rule, v)
return map_expr_dag(self, v)
elif isinstance(o1, ListTensor):
if isinstance(i1[0], FixedIndex):
o1 = o1.ufl_operands[i1[0]._value]
rule = IndexRemover()
if len(i1) > 1:
i1 = MultiIndex(i1[1:])
return map_expr_dag(rule, Indexed(o1, i1))
return map_expr_dag(self, Indexed(o1, i1))
else:
return map_expr_dag(rule, o1)
rule = IndexRemover()
o1 = map_expr_dag(rule, o1)
return map_expr_dag(self, o1)
o1 = map_expr_dag(self, o1)
return Indexed(o1, i1)

# Do something nicer
Expand Down

0 comments on commit eedb48d

Please sign in to comment.