Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a CachedMapper and use it for the IdentityMapper #92

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymbolic/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self, to_eliminate, get_key):
self.get_key = get_key

self.canonical_subexprs = {}
super().__init__()

def get_cse(self, expr, key=None):
if key is None:
Expand Down
33 changes: 32 additions & 1 deletion pymbolic/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from abc import ABC, abstractmethod
import pymbolic.primitives as primitives
from typing import Generic, TypeVar, Any, Dict

__doc__ = """
Basic dispatch
Expand Down Expand Up @@ -73,6 +74,8 @@
.. autoclass:: CSECachingMapperMixin
"""

CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper


try:
import numpy
Expand Down Expand Up @@ -200,6 +203,34 @@ def map_foreign(self, expr, *args, **kwargs):
RecursiveMapper = Mapper


# {{{ CachedMapper

class CachedMapper(Mapper, Generic[CachedMapperT]):
"""Mapper class that maps each subexpression exactly once. This loses some
information compared to :class:`Mapper` as a subexpression is visited only from
one of its predecessors.
"""

def __init__(self) -> None:
self._cache: Dict[CachedMapperT, Any] = {}

def cache_key(self, expr: CachedMapperT) -> Any:
return expr

# type-ignore-reason: incompatible with super class
# type: ignore[override]
def rec(self, expr: CachedMapperT, *args, **kwargs) -> Any:
key = self.cache_key(expr)
try:
return self._cache[key]
except KeyError:
result = super().rec(expr, *args, **kwargs) # type: ignore[type-var]
self._cache[key] = result
return result

# }}}


# {{{ combine mapper

class CombineMapper(RecursiveMapper):
Expand Down Expand Up @@ -364,7 +395,7 @@ def map_constant(self, expr):

# {{{ identity mapper

class IdentityMapper(Mapper):
class IdentityMapper(CachedMapper):
"""A :class:`Mapper` whose default mapper methods
make a deep copy of each subexpression.

Expand Down
1 change: 1 addition & 0 deletions pymbolic/mapper/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, parameters=None):
if parameters is None:
parameters = set()
self.parameters = parameters
super().__init__()

def get_dependencies(self, expr):
from pymbolic.mapper.dependency import DependencyMapper
Expand Down
1 change: 1 addition & 0 deletions pymbolic/mapper/constant_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self, real_type, complex_type=None, integer_type=None):
self.complex_type = complex_type

self.integer_type = integer_type
super().__init__()

def map_constant(self, expr):
if expr.imag:
Expand Down
1 change: 1 addition & 0 deletions pymbolic/mapper/cse_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def visit(self, expr):
class CSETagMapper(IdentityMapper):
def __init__(self, walk_mapper):
self.subexpr_histogram = walk_mapper.subexpr_histogram
super().__init__()

def map_call(self, expr):
if self.subexpr_histogram.get(expr, 0) > 1:
Expand Down
1 change: 1 addition & 0 deletions pymbolic/mapper/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, collector=None, const_folder=None):

self.collector = collector
self.const_folder = const_folder
super().__init__()

def collect(self, expr):
return self.collector(self.const_folder(expr))
Expand Down
1 change: 1 addition & 0 deletions pymbolic/mapper/substitutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
class SubstitutionMapper(pymbolic.mapper.IdentityMapper):
def __init__(self, subst_func):
self.subst_func = subst_func
super().__init__()

def map_variable(self, expr):
result = self.subst_func(expr)
Expand Down