diff --git a/pymbolic/cse.py b/pymbolic/cse.py index 5b4ac1ef..c0100413 100644 --- a/pymbolic/cse.py +++ b/pymbolic/cse.py @@ -78,6 +78,7 @@ def __init__(self, to_eliminate, get_key): self.get_key = get_key self.canonical_subexprs = {} + super().__init__() def get_cse(self, expr, key=None): if key is None: diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index a10b5f4f..0677be74 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -22,6 +22,7 @@ from abc import ABC, abstractmethod import pymbolic.primitives as primitives +from typing import Generic, TypeVar, Any, Dict __doc__ = """ Basic dispatch @@ -73,6 +74,8 @@ .. autoclass:: CSECachingMapperMixin """ +CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper + try: import numpy @@ -200,6 +203,34 @@ def map_foreign(self, expr, *args, **kwargs): RecursiveMapper = Mapper +# {{{ CachedMapper + +class CachedMapper(Mapper, Generic[CachedMapperT]): + """Mapper class that maps each subexpression exactly once. This loses some + information compared to :class:`Mapper` as a subexpression is visited only from + one of its predecessors. + """ + + def __init__(self) -> None: + self._cache: Dict[CachedMapperT, Any] = {} + + def cache_key(self, expr: CachedMapperT) -> Any: + return expr + + # type-ignore-reason: incompatible with super class + # type: ignore[override] + def rec(self, expr: CachedMapperT, *args, **kwargs) -> Any: + key = self.cache_key(expr) + try: + return self._cache[key] + except KeyError: + result = super().rec(expr, *args, **kwargs) # type: ignore[type-var] + self._cache[key] = result + return result + +# }}} + + # {{{ combine mapper class CombineMapper(RecursiveMapper): @@ -364,7 +395,7 @@ def map_constant(self, expr): # {{{ identity mapper -class IdentityMapper(Mapper): +class IdentityMapper(CachedMapper): """A :class:`Mapper` whose default mapper methods make a deep copy of each subexpression. diff --git a/pymbolic/mapper/collector.py b/pymbolic/mapper/collector.py index a0922433..e60ec644 100644 --- a/pymbolic/mapper/collector.py +++ b/pymbolic/mapper/collector.py @@ -36,6 +36,7 @@ def __init__(self, parameters=None): if parameters is None: parameters = set() self.parameters = parameters + super().__init__() def get_dependencies(self, expr): from pymbolic.mapper.dependency import DependencyMapper diff --git a/pymbolic/mapper/constant_converter.py b/pymbolic/mapper/constant_converter.py index 3f512713..ea8c6620 100644 --- a/pymbolic/mapper/constant_converter.py +++ b/pymbolic/mapper/constant_converter.py @@ -51,6 +51,7 @@ def __init__(self, real_type, complex_type=None, integer_type=None): self.complex_type = complex_type self.integer_type = integer_type + super().__init__() def map_constant(self, expr): if expr.imag: diff --git a/pymbolic/mapper/cse_tagger.py b/pymbolic/mapper/cse_tagger.py index f6b4c45d..0cbf2975 100644 --- a/pymbolic/mapper/cse_tagger.py +++ b/pymbolic/mapper/cse_tagger.py @@ -37,6 +37,7 @@ def visit(self, expr): class CSETagMapper(IdentityMapper): def __init__(self, walk_mapper): self.subexpr_histogram = walk_mapper.subexpr_histogram + super().__init__() def map_call(self, expr): if self.subexpr_histogram.get(expr, 0) > 1: diff --git a/pymbolic/mapper/distributor.py b/pymbolic/mapper/distributor.py index 13c9831c..3685e036 100644 --- a/pymbolic/mapper/distributor.py +++ b/pymbolic/mapper/distributor.py @@ -48,6 +48,7 @@ def __init__(self, collector=None, const_folder=None): self.collector = collector self.const_folder = const_folder + super().__init__() def collect(self, expr): return self.collector(self.const_folder(expr)) diff --git a/pymbolic/mapper/substitutor.py b/pymbolic/mapper/substitutor.py index c775dd17..b3592946 100644 --- a/pymbolic/mapper/substitutor.py +++ b/pymbolic/mapper/substitutor.py @@ -26,6 +26,7 @@ class SubstitutionMapper(pymbolic.mapper.IdentityMapper): def __init__(self, subst_func): self.subst_func = subst_func + super().__init__() def map_variable(self, expr): result = self.subst_func(expr)