From cc7efdc79ddd4c5e19469c50acf9c872312036f1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 4 Nov 2024 13:31:00 -0600 Subject: [PATCH] ruff --- loopy/statistics.py | 137 ++++++++++++++++++++++------------------ test/test_statistics.py | 3 +- 2 files changed, 78 insertions(+), 62 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index 298a5f488..944d51347 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -1,5 +1,6 @@ from __future__ import annotations + __copyright__ = """ Copyright (C) 2015 James Stevens Copyright (C) 2018 Kaushik Kulkarni @@ -27,21 +28,35 @@ THE SOFTWARE. """ +from dataclasses import dataclass, replace +from enum import Enum, auto as enum_auto from functools import cached_property, partial +from typing import Any, Callable, Generic, Iterable, Mapping, Optional, TypeVar, Union import islpy as isl -from islpy import dim_type +import pymbolic.primitives as p +from islpy import PwQPolynomial, dim_type from pymbolic.mapper import CombineMapper -from pytools import ImmutableRecord, memoize_method +from pytools import memoize_method +from pytools.tag import Tag import loopy as lp from loopy.diagnostic import LoopyError, warn_with_kernel -from loopy.kernel.data import AddressSpace, MultiAssignmentBase, TemporaryVariable -from loopy.kernel.function_interface import CallableKernel -from loopy.symbolic import CoefficientCollector, flatten +from loopy.kernel import LoopKernel +from loopy.kernel.array import ArrayBase +from loopy.kernel.data import AddressSpace, InameImplementationTag, MultiAssignmentBase +from loopy.kernel.function_interface import CallableKernel, InKernelCallable +from loopy.kernel.instruction import InstructionBase +from loopy.symbolic import ( + CoefficientCollector, + Reduction, + SubArrayRef, + TaggedExpression, + flatten, +) from loopy.translation_unit import TranslationUnit -from loopy.typing import Expression from loopy.types import LoopyType +from loopy.typing import Expression __doc__ = """ @@ -99,7 +114,7 @@ def get_kernel_zero_pwqpolynomial(kernel: LoopKernel) -> PwQPolynomial: # {{{ GuardedPwQPolynomial -def _get_param_tuple(obj) -> Tuple[str, ...]: +def _get_param_tuple(obj) -> tuple[str, ...]: return tuple( obj.get_dim_name(dim_type.param, i) for i in range(obj.dim(dim_type.param))) @@ -202,9 +217,9 @@ class ToCountMap(Generic[CountT]): """ - count_map: Dict[Countable, CountT] + count_map: dict[Countable, CountT] - def __init__(self, count_map: Optional[Dict[Countable, CountT]] = None) -> None: + def __init__(self, count_map: dict[Countable, CountT] | None = None) -> None: if count_map is None: count_map = {} @@ -269,7 +284,7 @@ def values(self): return self.count_map.values() def copy( - self, count_map: Optional[Dict[Countable, CountT]] = None + self, count_map: Optional[dict[Countable, CountT]] = None ) -> ToCountMap[CountT]: if count_map is None: count_map = self.count_map @@ -406,7 +421,7 @@ def group_by(self, *args) -> ToCountMap[CountT]: """ - new_count_map: Dict[Countable, CountT] = {} + new_count_map: dict[Countable, CountT] = {} # make sure all item keys have same type if self.count_map: @@ -490,7 +505,7 @@ class ToCountPolynomialMap(ToCountMap[GuardedPwQPolynomial]): def __init__( self, space: isl.Space, - count_map: Dict[Countable, GuardedPwQPolynomial] + count_map: dict[Countable, GuardedPwQPolynomial] ) -> None: if not isinstance(space, isl.Space): raise TypeError( @@ -527,7 +542,7 @@ def copy(self, count_map=None, space=None): return type(self)(space, count_map) - def eval_and_sum(self, params: Optional[Mapping[str, int]] = None) -> int: + def eval_and_sum(self, params: Mapping[str, int] | None = None) -> int: """Add all counts and evaluate with provided parameter dict *params* :return: An :class:`int` containing the sum of all counts @@ -578,7 +593,7 @@ def subst_into_to_count_map( tcm: ToCountPolynomialMap, subst_dict: Mapping[str, PwQPolynomial]) -> ToCountPolynomialMap: from loopy.isl_helpers import subst_into_pwqpolynomial - new_count_map: Dict[Countable, GuardedPwQPolynomial] = {} + new_count_map: dict[Countable, GuardedPwQPolynomial] = {} for key, value in tcm.count_map.items(): if isinstance(value, GuardedPwQPolynomial): new_count_map[key] = subst_into_guarded_pwqpolynomial( @@ -684,11 +699,11 @@ class Op: A :class:`frozenset` of tags to the operation. """ - dtype: Optional[LoopyType] = None - op_type: Optional[OpType] = None - count_granularity: Optional[CountGranularity] = None - kernel_name: Optional[str] = None - tags: FrozenSet[Tag] = frozenset() + dtype: LoopyType | None = None + op_type: OpType | None = None + count_granularity: CountGranularity | None = None + kernel_name: str | None = None + tags: frozenset[Tag] = frozenset() def __repr__(self): if self.kernel_name is not None: @@ -788,10 +803,10 @@ class MemAccess: read_write: Optional[AccessDirection] = None variable: Optional[str] = None - variable_tags: FrozenSet[Tag] = frozenset() + variable_tags: frozenset[Tag] = frozenset() count_granularity: Optional[CountGranularity] = None kernel_name: Optional[str] = None - tags: FrozenSet[Tag] = frozenset() + tags: frozenset[Tag] = frozenset() @property def mtype(self) -> str: @@ -867,7 +882,7 @@ class Sync: """ sync_kind: Optional[SynchronizationKind] = None kernel_name: Optional[str] = None - tags: FrozenSet[Tag] = frozenset() + tags: frozenset[Tag] = frozenset() def __repr__(self): # Record.__repr__ overridden for consistent ordering and conciseness @@ -903,16 +918,16 @@ def combine(self, values: Iterable[ToCountMap]) -> ToCountPolynomialMap: return sum(values, self._new_zero_map()) def map_tagged_expression( - self, expr: TaggedExpression, tags: FrozenSet[Tag] + self, expr: TaggedExpression, tags: frozenset[Tag] ) -> ToCountPolynomialMap: return self.rec(expr.expr, expr.tags) def map_constant( - self, expr: Expression, tags: FrozenSet[Tag] + self, expr: Expression, tags: frozenset[Tag] ) -> ToCountPolynomialMap: return self._new_zero_map() - def map_call(self, expr: p.Call, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap: from loopy.symbolic import ResolvedFunction assert isinstance(expr.function, ResolvedFunction) clbl = self.callables_table[expr.function.name] @@ -941,18 +956,18 @@ def map_call(self, expr: p.Call, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: raise NotImplementedError() def map_call_with_kwargs( - self, expr: p.CallWithKwargs, tags: FrozenSet[Tag] + self, expr: p.CallWithKwargs, tags: frozenset[Tag] ) -> ToCountPolynomialMap: # See https://github.com/inducer/loopy/pull/323 raise NotImplementedError def map_comparison( - self, expr: p.Comparison, tags: FrozenSet[Tag] + self, expr: p.Comparison, tags: frozenset[Tag] ) -> ToCountPolynomialMap: return self.rec(expr.left, tags) + self.rec(expr.right, tags) def map_if( - self, expr: p.If, tags: FrozenSet[Tag] + self, expr: p.If, tags: frozenset[Tag] ) -> ToCountPolynomialMap: warn_with_kernel(self.knl, "summing_if_branches", "%s counting sum of if-expression branches." @@ -961,7 +976,7 @@ def map_if( + self.rec(expr.else_, tags) def map_if_positive( - self, expr: p.IfPositive, tags: FrozenSet[Tag]) -> ToCountMap: + self, expr: p.IfPositive, tags: frozenset[Tag]) -> ToCountMap: warn_with_kernel(self.knl, "summing_if_branches", "%s counting sum of if-expression branches." % type(self).__name__) @@ -969,7 +984,7 @@ def map_if_positive( + self.rec(expr.else_, tags) def map_common_subexpression( - self, expr: p.CommonSubexpression, tags: FrozenSet[Tag] + self, expr: p.CommonSubexpression, tags: frozenset[Tag] ) -> ToCountPolynomialMap: raise RuntimeError("%s encountered %s--not supposed to happen" % (type(self).__name__, type(expr).__name__)) @@ -979,13 +994,13 @@ def map_common_subexpression( map_slice = map_common_subexpression def map_reduction( - self, expr: Reduction, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + self, expr: Reduction, tags: frozenset[Tag]) -> ToCountPolynomialMap: # preprocessing should have removed these raise RuntimeError("%s encountered %s--not supposed to happen" % (type(self).__name__, type(expr).__name__)) def __call__( - self, expr, tags: Optional[FrozenSet[Tag]] = None + self, expr, tags: Optional[frozenset[Tag]] = None ) -> ToCountPolynomialMap: if tags is None: tags = frozenset() @@ -1004,14 +1019,14 @@ def __init__(self, knl: LoopKernel, callables_table, kernel_rec, arithmetic_count_granularity = CountGranularity.SUBGROUP - def map_constant(self, expr: Any, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + def map_constant(self, expr: Any, tags: frozenset[Tag]) -> ToCountPolynomialMap: return self._new_zero_map() map_tagged_variable = map_constant map_variable = map_constant map_nan = map_constant - def map_call(self, expr: p.Call, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap: from loopy.symbolic import ResolvedFunction assert isinstance(expr.function, ResolvedFunction) clbl = self.callables_table[expr.function.name] @@ -1029,18 +1044,18 @@ def map_call(self, expr: p.Call, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: return super().map_call(expr, tags) def map_subscript( - self, expr: p.Subscript, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + self, expr: p.Subscript, tags: frozenset[Tag]) -> ToCountPolynomialMap: if self.count_within_subscripts: return self.rec(expr.index, tags) else: return self._new_zero_map() def map_sub_array_ref( - self, expr: SubArrayRef, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + self, expr: SubArrayRef, tags: frozenset[Tag]) -> ToCountPolynomialMap: # generates an array view, considered free return self._new_zero_map() - def map_sum(self, expr: p.Sum, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + def map_sum(self, expr: p.Sum, tags: frozenset[Tag]) -> ToCountPolynomialMap: assert expr.children return self.new_poly_map( {Op(dtype=self.type_inf(expr), @@ -1052,7 +1067,7 @@ def map_sum(self, expr: p.Sum, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: ) + sum(self.rec(child, tags) for child in expr.children) def map_product( - self, expr: p.Product, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + self, expr: p.Product, tags: frozenset[Tag]) -> ToCountPolynomialMap: from pymbolic.primitives import is_zero assert expr.children return sum(self.new_poly_map({Op(dtype=self.type_inf(expr), @@ -1072,7 +1087,7 @@ def map_product( kernel_name=self.knl.name): -self.one}) def map_quotient( - self, expr: p.QuotientBase, tags: FrozenSet[Tag] + self, expr: p.QuotientBase, tags: frozenset[Tag] ) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), op_type=OpType.DIV, @@ -1085,7 +1100,7 @@ def map_quotient( map_floor_div = map_quotient map_remainder = map_quotient - def map_power(self, expr: p.Power, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + def map_power(self, expr: p.Power, tags: frozenset[Tag]) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), op_type=OpType.POW, tags=tags, @@ -1095,7 +1110,7 @@ def map_power(self, expr: p.Power, tags: FrozenSet[Tag]) -> ToCountPolynomialMap + self.rec(expr.exponent, tags) def map_left_shift( - self, expr: Union[p.LeftShift, p.RightShift], tags: FrozenSet[Tag] + self, expr: Union[p.LeftShift, p.RightShift], tags: frozenset[Tag] ) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), op_type=OpType.SHIFT, @@ -1108,7 +1123,7 @@ def map_left_shift( map_right_shift = map_left_shift def map_bitwise_not( - self, expr: p.BitwiseNot, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + self, expr: p.BitwiseNot, tags: frozenset[Tag]) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), op_type=OpType.BITWISE, tags=tags, @@ -1118,7 +1133,7 @@ def map_bitwise_not( def map_bitwise_or( self, expr: Union[p.BitwiseOr, p.BitwiseAnd, p.BitwiseXor], - tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + tags: frozenset[Tag]) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), op_type=OpType.BITWISE, tags=tags, @@ -1130,7 +1145,7 @@ def map_bitwise_or( map_bitwise_xor = map_bitwise_or map_bitwise_and = map_bitwise_or - def map_if(self, expr: p.If, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + def map_if(self, expr: p.If, tags: frozenset[Tag]) -> ToCountPolynomialMap: warn_with_kernel(self.knl, "summing_if_branches_ops", "ExpressionOpCounter counting ops as sum of " "if-statement branches.") @@ -1138,7 +1153,7 @@ def map_if(self, expr: p.If, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + self.rec(expr.else_, tags) def map_if_positive( - self, expr: p.IfPositive, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + self, expr: p.IfPositive, tags: frozenset[Tag]) -> ToCountPolynomialMap: warn_with_kernel(self.knl, "summing_ifpos_branches_ops", "ExpressionOpCounter counting ops as sum of " "if_pos-statement branches.") @@ -1146,7 +1161,7 @@ def map_if_positive( + self.rec(expr.else_, tags) def map_min( - self, expr: Union[p. Min, p.Max], tags: FrozenSet[Tag] + self, expr: Union[p. Min, p.Max], tags: frozenset[Tag] ) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), op_type=OpType.MAXMIN, @@ -1196,8 +1211,8 @@ def map_floor_div(self, expr): # {{{ _get_lid_and_gid_strides def _get_lid_and_gid_strides( - knl: LoopKernel, array: ArrayBase, index: Tuple[Expression, ...] - ) -> Tuple[Mapping[int, Expression], Mapping[int, Expression]]: + knl: LoopKernel, array: ArrayBase, index: tuple[Expression, ...] + ) -> tuple[Mapping[int, Expression], Mapping[int, Expression]]: # find all local and global index tags and corresponding inames from loopy.symbolic import get_dependencies my_inames = get_dependencies(index) & knl.all_inames() @@ -1294,11 +1309,11 @@ def get_iname_strides( class MemAccessCounter(CounterBase): def map_sub_array_ref( - self, expr: SubArrayRef, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + self, expr: SubArrayRef, tags: frozenset[Tag]) -> ToCountPolynomialMap: # generates an array view, considered free return self._new_zero_map() - def map_call(self, expr: p.Call, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap: from loopy.symbolic import ResolvedFunction assert isinstance(expr.function, ResolvedFunction) clbl = self.callables_table[expr.function.name] @@ -1314,8 +1329,8 @@ def map_call(self, expr: p.Call, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: def count_var_access(self, dtype: LoopyType, name: str, - index: Optional[Tuple[Expression, ...]], - tags: FrozenSet[Tag] + index: Optional[tuple[Expression, ...]], + tags: frozenset[Tag] ) -> ToCountPolynomialMap: count_map = {} @@ -1352,7 +1367,7 @@ def count_var_access(self, return self.new_poly_map(count_map) def map_variable( - self, expr: p.Variable, tags: FrozenSet[Tag] + self, expr: p.Variable, tags: frozenset[Tag] ) -> ToCountPolynomialMap: return self.count_var_access( self.type_inf(expr), expr.name, None, tags) @@ -1360,7 +1375,7 @@ def map_variable( map_tagged_variable = map_variable def map_subscript( - self, expr: p.Subscript, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + self, expr: p.Subscript, tags: frozenset[Tag]) -> ToCountPolynomialMap: return (self.count_var_access(self.type_inf(expr), expr.aggregate.name, expr.index, tags) @@ -1371,7 +1386,7 @@ def map_subscript( # {{{ AccessFootprintGatherer -FootprintsT = Dict[str, isl.Set] +FootprintsT = dict[str, isl.Set] class AccessFootprintGatherer(CombineMapper): @@ -1600,7 +1615,7 @@ def mult_grid_factor(used_axes, sizes): def count_inames_domain( - knl: LoopKernel, inames: FrozenSet[str]) -> GuardedPwQPolynomial: + knl: LoopKernel, inames: frozenset[str]) -> GuardedPwQPolynomial: space = get_kernel_parameter_space(knl) if not inames: return add_assumptions_guard(knl, @@ -1831,7 +1846,7 @@ def get_op_map( assert entrypoint in t_unit.entrypoints from loopy.preprocess import infer_unknown_types, preprocess_program - program = preprocess_program(program) + program = preprocess_program(t_unit) from loopy.match import parse_match within = parse_match(within) @@ -2184,7 +2199,7 @@ def get_synchronization_map( entrypoint = list(t_unit.entrypoints)[0] - assert entrypoint in program.entrypoints + assert entrypoint in t_unit.entrypoints from loopy.preprocess import infer_unknown_types, preprocess_program t_unit = preprocess_program(t_unit) @@ -2203,7 +2218,7 @@ def get_synchronization_map( def _gather_access_footprints_for_single_kernel( kernel: LoopKernel, ignore_uncountable: bool - ) -> Tuple[FootprintsT, FootprintsT]: + ) -> tuple[FootprintsT, FootprintsT]: write_footprints = [] read_footprints = [] @@ -2296,8 +2311,8 @@ def gather_access_footprint_bytes( nonlinear indices) """ - from loopy.preprocess import infer_unknown_types, preprocess_program - kernel = infer_unknown_types(program, expect_completion=True) + from loopy.preprocess import infer_unknown_types + kernel = infer_unknown_types(t_unit, expect_completion=True) fp = gather_access_footprints(t_unit, ignore_uncountable=ignore_uncountable) diff --git a/test/test_statistics.py b/test/test_statistics.py index 35e49e0e8..07633171e 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -1547,8 +1547,9 @@ class MyCostTagSum(Tag): def test_op_taggedexpression(): + from pymbolic.primitives import Subscript, Sum, Variable + from loopy.symbolic import TaggedExpression - from pymbolic.primitives import Subscript, Variable, Sum n = 500