Skip to content

Commit

Permalink
Sharpen some types in loopy.symbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 6, 2024
1 parent 57fbd84 commit ab68888
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@
import pymbolic.primitives as p
import pytools.lex
from islpy import dim_type
from pymbolic import Variable
from pymbolic import ArithmeticExpressionT, Variable
from pymbolic.mapper import (
CachedCombineMapper as CombineMapperBase,
CachedIdentityMapper as IdentityMapperBase,
CachedWalkMapper as WalkMapperBase,
CallbackMapper as CallbackMapperBase,
CSECachingMapperMixin,
IdentityMapper as UncachedIdentityMapperBase,
Mapper,
P,
WalkMapper as UncachedWalkMapperBase,
)
from pymbolic.mapper.coefficient import CoefficientCollector as CoefficientCollectorBase
Expand All @@ -80,7 +82,7 @@
UnableToDetermineAccessRangeError,
)
from loopy.types import LoopyType, NumpyType, ToLoopyTypeConvertible
from loopy.typing import ExpressionT
from loopy.typing import ExpressionT, auto


if TYPE_CHECKING:
Expand Down Expand Up @@ -128,11 +130,11 @@

# {{{ mappers with support for loopy-specific primitives

class IdentityMapperMixin:
def map_literal(self, expr, *args, **kwargs):
class IdentityMapperMixin(Mapper[ExpressionT, P]):
def map_literal(self, expr: Literal, *args, **kwargs):
return expr

def map_array_literal(self, expr, *args, **kwargs):
def map_array_literal(self, expr: ArrayLiteral, *args, **kwargs):
return type(expr)(tuple(self.rec(ch, *args, **kwargs)
for ch in expr.children))

Expand Down Expand Up @@ -221,7 +223,7 @@ class UncachedIdentityMapper(UncachedIdentityMapperBase,


class PartialEvaluationMapper(
EvaluationMapperBase, CSECachingMapperMixin, IdentityMapperMixin):
EvaluationMapperBase, IdentityMapperMixin[P]):
def map_variable(self, expr):
return expr

Expand Down Expand Up @@ -315,7 +317,7 @@ def map_sub_array_ref(self, expr, *args, **kwargs):


class SubstitutionMapper(
CSECachingMapperMixin, SubstitutionMapperBase, IdentityMapperMixin):
SubstitutionMapperBase, IdentityMapperMixin[[]]):
def map_common_subexpression_uncached(self, expr):
return type(expr)(self.rec(expr.child), expr.prefix, expr.scope)

Expand All @@ -325,7 +327,7 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase,
pass


class StringifyMapper(StringifyMapperBase):
class StringifyMapper(StringifyMapperBase[[]]):
def map_literal(self, expr, *args):
return expr.s

Expand Down Expand Up @@ -963,7 +965,7 @@ def _get_dependencies_and_reduction_inames(expr):
return deps, reduction_inames


def get_dependencies(expr: ExpressionT) -> AbstractSet[str]:
def get_dependencies(expr: ExpressionT | type[auto]) -> AbstractSet[str]:
return _get_dependencies_and_reduction_inames(expr)[0]


Expand Down Expand Up @@ -1706,7 +1708,7 @@ def map_subscript(self, expr):

# {{{ (pw)aff to expr conversion

def aff_to_expr(aff: isl.Aff) -> ExpressionT:
def aff_to_expr(aff: isl.Aff) -> ArithmeticExpressionT:
from pymbolic import var

denom = aff.get_denominator_val().to_python()
Expand Down

0 comments on commit ab68888

Please sign in to comment.