From ab68888686f2108e406515f39d95897605d5e993 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 6 Nov 2024 14:44:06 -0600 Subject: [PATCH] Sharpen some types in loopy.symbolic --- loopy/symbolic.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 9b1af41e2..1435f6943 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -47,7 +47,7 @@ import pymbolic.primitives as p import pytools.lex from islpy import dim_type -from pymbolic import Variable +from pymbolic import ArithmeticExpressionT, Variable from pymbolic.mapper import ( CachedCombineMapper as CombineMapperBase, CachedIdentityMapper as IdentityMapperBase, @@ -55,6 +55,8 @@ CallbackMapper as CallbackMapperBase, CSECachingMapperMixin, IdentityMapper as UncachedIdentityMapperBase, + Mapper, + P, WalkMapper as UncachedWalkMapperBase, ) from pymbolic.mapper.coefficient import CoefficientCollector as CoefficientCollectorBase @@ -80,7 +82,7 @@ UnableToDetermineAccessRangeError, ) from loopy.types import LoopyType, NumpyType, ToLoopyTypeConvertible -from loopy.typing import ExpressionT +from loopy.typing import ExpressionT, auto if TYPE_CHECKING: @@ -128,11 +130,11 @@ # {{{ mappers with support for loopy-specific primitives -class IdentityMapperMixin: - def map_literal(self, expr, *args, **kwargs): +class IdentityMapperMixin(Mapper[ExpressionT, P]): + def map_literal(self, expr: Literal, *args, **kwargs): return expr - def map_array_literal(self, expr, *args, **kwargs): + def map_array_literal(self, expr: ArrayLiteral, *args, **kwargs): return type(expr)(tuple(self.rec(ch, *args, **kwargs) for ch in expr.children)) @@ -221,7 +223,7 @@ class UncachedIdentityMapper(UncachedIdentityMapperBase, class PartialEvaluationMapper( - EvaluationMapperBase, CSECachingMapperMixin, IdentityMapperMixin): + EvaluationMapperBase, IdentityMapperMixin[P]): def map_variable(self, expr): return expr @@ -315,7 +317,7 @@ def map_sub_array_ref(self, expr, *args, **kwargs): class SubstitutionMapper( - CSECachingMapperMixin, SubstitutionMapperBase, IdentityMapperMixin): + SubstitutionMapperBase, IdentityMapperMixin[[]]): def map_common_subexpression_uncached(self, expr): return type(expr)(self.rec(expr.child), expr.prefix, expr.scope) @@ -325,7 +327,7 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase, pass -class StringifyMapper(StringifyMapperBase): +class StringifyMapper(StringifyMapperBase[[]]): def map_literal(self, expr, *args): return expr.s @@ -963,7 +965,7 @@ def _get_dependencies_and_reduction_inames(expr): return deps, reduction_inames -def get_dependencies(expr: ExpressionT) -> AbstractSet[str]: +def get_dependencies(expr: ExpressionT | type[auto]) -> AbstractSet[str]: return _get_dependencies_and_reduction_inames(expr)[0] @@ -1706,7 +1708,7 @@ def map_subscript(self, expr): # {{{ (pw)aff to expr conversion -def aff_to_expr(aff: isl.Aff) -> ExpressionT: +def aff_to_expr(aff: isl.Aff) -> ArithmeticExpressionT: from pymbolic import var denom = aff.get_denominator_val().to_python()