diff --git a/loopy/statistics.py b/loopy/statistics.py index a577b58ac..f7d9b85a0 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -33,6 +33,8 @@ from functools import cached_property, partial from typing import Any, Callable, Generic, Iterable, Mapping, Optional, TypeVar, Union +from immutabledict import immutabledict + import islpy as isl import pymbolic.primitives as p from islpy import PwQPolynomial, dim_type @@ -256,13 +258,6 @@ def __mul__(self, other: GuardedPwQPolynomial) -> ToCountMap[CountT]: __rmul__ = __mul__ def __getitem__(self, index: Countable) -> CountT: - print("HERE", index, hash(index)) - - if index not in self.count_map: - for k, v in self.count_map.items(): - print(f" {k=}, {hash(k.dtype) == hash(index.dtype)} {hash(k)}") - # assert hash(k) == hash(k.dtype), k - # print(f"{self.count_map=}") return self.count_map[index] def __repr__(self) -> str: @@ -623,7 +618,7 @@ def subst_into_to_count_map( # {{{ CountGranularity class CountGranularity(Enum): - """Strings specifying whether an operation should be counted once per + """Specify whether an operation should be counted once per *work-item*, *sub-group*, or *work-group*. .. attribute:: WORKITEM @@ -721,9 +716,18 @@ def __post_init__(self): assert isinstance(self.op_type, OpType) or self.op_type is None, self.op_type + if not (self.count_granularity is None + or isinstance(self.count_granularity, CountGranularity)): + raise ValueError( + f"unexpected count_granularity: '{self.count_granularity}'") + def __repr__(self): - return (f"Op({self.dtype}, {self.op_type}, {self.count_granularity}," - f' "{self.kernel_name if self.kernel_name is not None else ""}", {self.tags})') + if self.kernel_name is not None: + return (f"Op({self.dtype}, {self.name}, {self.count_granularity}," + f' "{self.kernel_name}", {self.tags})') + else: + return f"Op({self.dtype}, {self.name}, " + \ + f"{self.count_granularity}, {self.tags})" # }}} @@ -745,7 +749,7 @@ class MemAccess: .. attribute:: address_space - A :class:`str` that specifies the memory type accessed as **global** + A :class:`AddressSpace` that specifies the memory type accessed as **global** or **local** .. attribute:: dtype @@ -825,15 +829,19 @@ def __post_init__(self): from loopy.types import to_loopy_type object.__setattr__(self, "dtype", to_loopy_type(self.dtype)) - from immutabledict import immutabledict - if isinstance(self.lid_strides, dict): object.__setattr__(self, "lid_strides", immutabledict(self.lid_strides)) if isinstance(self.gid_strides, dict): object.__setattr__(self, "gid_strides", immutabledict(self.gid_strides)) - assert self.address_space is None or isinstance(self.address_space, AddressSpace) + assert (self.address_space is None + or isinstance(self.address_space, AddressSpace)) + + if not (self.count_granularity is None + or isinstance(self.count_granularity, CountGranularity)): + raise ValueError( + f"unexpected count_granularity: '{self.count_granularity}'") @property def mtype(self) -> str: @@ -911,10 +919,6 @@ class Sync: kernel_name: Optional[str] = None tags: frozenset[Tag] = frozenset() - def __repr__(self): - # Record.__repr__ overridden for consistent ordering and conciseness - return f"Sync({self.sync_kind}, {self.kernel_name}, {self.tags})" - # }}} @@ -1332,7 +1336,7 @@ def get_iname_strides( # }}} -# {{{ MemAccessCounterBase +# {{{ MemAccessCounter class MemAccessCounter(CounterBase): def map_sub_array_ref( @@ -1351,59 +1355,96 @@ def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap: else: return super().map_call(expr, tags) - local_mem_count_granularity = CountGranularity.SUBGROUP - def count_var_access(self, dtype: LoopyType, name: str, index: Optional[tuple[Expression, ...]], tags: frozenset[Tag] ) -> ToCountPolynomialMap: - count_map = {} - + from loopy.kernel.data import TemporaryVariable array = self.knl.get_var_descriptor(name) - if index is None: - # no subscript - count_map[MemAccess( + if isinstance(array, TemporaryVariable) and ( + array.address_space == AddressSpace.LOCAL): + # local memory access + local_mem_count_granularity = CountGranularity.SUBGROUP + + if index is None: + return self.new_poly_map({MemAccess( address_space=AddressSpace.LOCAL, - tags=tags, dtype=dtype, - count_granularity=self.local_mem_count_granularity, - kernel_name=self.knl.name)] = self.one - return self.new_poly_map(count_map) - - # could be tuple or scalar index - index_tuple = index - if not isinstance(index_tuple, tuple): - index_tuple = (index_tuple,) - - lid_strides, gid_strides = _get_lid_and_gid_strides( - self.knl, array, index_tuple) - - # print(MemAccess( - # address_space=array.address_space, - # dtype=dtype, - # tags=tags, - # lid_strides=lid_strides, - # gid_strides=gid_strides, - # variable=name, - # count_granularity=self.local_mem_count_granularity, - # kernel_name=self.knl.name)) - - from immutabledict import immutabledict - - count_map[MemAccess( - address_space=array.address_space, - dtype=dtype, - tags=tags, - lid_strides=immutabledict(lid_strides), - gid_strides=immutabledict(gid_strides), - variable=name, - count_granularity=self.local_mem_count_granularity, - kernel_name=self.knl.name)] = self.one - - return self.new_poly_map(count_map) + tags=tags, + count_granularity=local_mem_count_granularity, + kernel_name=self.knl.name): self.one}) + + # could be tuple or scalar index + index_tuple = index + if not isinstance(index_tuple, tuple): + index_tuple = (index_tuple,) + + lid_strides, gid_strides = _get_lid_and_gid_strides( + self.knl, array, index_tuple) + + return self.new_poly_map({MemAccess( + address_space=array.address_space, + dtype=dtype, + tags=tags, + lid_strides=immutabledict(lid_strides), + gid_strides=immutabledict(gid_strides), + variable=name, + count_granularity=self.local_mem_count_granularity, + kernel_name=self.knl.name): self.one}) + + elif (isinstance(array, TemporaryVariable) and ( + array.address_space == AddressSpace.GLOBAL)) or ( + isinstance(array, lp.ArrayArg)): + if index is None: + return self.new_poly_map({MemAccess( + address_space=AddressSpace.GLOBAL, + dtype=dtype, + lid_strides=immutabledict({}), + gid_strides=immutabledict({}), + variable=name, + tags=tags, + count_granularity=CountGranularity.WORKITEM, + kernel_name=self.knl.name): self.one}) + + # could be tuple or scalar index + index_tuple = index + if not isinstance(index_tuple, tuple): + index_tuple = (index_tuple,) + + lid_strides, gid_strides = _get_lid_and_gid_strides( + self.knl, array, index_tuple) + # Account for broadcasts once per subgroup + count_granularity = CountGranularity.WORKITEM if ( + # if the stride in lid.0 is known + 0 in lid_strides + and + # it is nonzero + lid_strides[0] != 0 + ) else CountGranularity.SUBGROUP + + try: + # var_tags = expr.aggregate.tags # FIXME + var_tags = frozenset() + except AttributeError: + var_tags = frozenset() + + return self.new_poly_map({MemAccess( + address_space=AddressSpace.GLOBAL, + dtype=dtype, + lid_strides=immutabledict(lid_strides), + gid_strides=immutabledict(gid_strides), + variable=name, + tags=tags, + variable_tags=var_tags, + count_granularity=count_granularity, + kernel_name=self.knl.name, + ): self.one} + ) + else: + return self._new_zero_map() def map_variable( self, expr: p.Variable, tags: frozenset[Tag] diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 4b64ffaeb..0e59fbcde 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -703,7 +703,7 @@ def copy(self, *, name=None, tags=None): return TaggedVariable(name, tags) -@p.expr_dataclass(init=False) +@p.expr_dataclass() class TaggedExpression(LoopyExpressionBase): """ Represents a frozenset of tags attached to an :attr:`expr`. @@ -720,22 +720,10 @@ class TaggedExpression(LoopyExpressionBase): init_arg_names = ("tags", "expr") - def __init__(self, tags, expr): - self.tags = tags - self.expr = expr - - def __getinitargs__(self): - return (self.tags, self.expr) - - def get_hash(self): - return hash((self.__class__, self.tags, self.expr)) - - def is_equal(self, other): - return (other.__class__ == self.__class__ - and other.tags == self.tags - and other.expr == self.expr) + tags: frozenset[Tag] + expr: ExpressionT - mapper_method = intern("map_tagged_expression") + mapper_method = "map_tagged_expression" @p.expr_dataclass(init=False) diff --git a/test/test_statistics.py b/test/test_statistics.py index 967092cde..42e412a4e 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -31,7 +31,12 @@ from pytools import div_ceil import loopy as lp -from loopy.statistics import CountGranularity as CG, OpType, AccessDirection, AddressSpace +from loopy.statistics import ( + AccessDirection, + AddressSpace, + CountGranularity as CG, + OpType, +) from loopy.types import to_loopy_type from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa @@ -417,9 +422,11 @@ def test_mem_access_counter_reduction(): # uniform: (count-per-sub-group)*n_subgroups assert f32s == (n*ell)*n_subgroups - ld_bytes = mem_map.filter_by(mtype=["global"], read_write=[AccessDirection.READ] + ld_bytes = mem_map.filter_by(address_space=[AddressSpace.GLOBAL], + read_write=[AccessDirection.READ] ).to_bytes().eval_and_sum(params) - st_bytes = mem_map.filter_by(mtype=["global"], read_write=[AccessDirection.WRITE] + st_bytes = mem_map.filter_by(address_space=[AddressSpace.GLOBAL], + read_write=[AccessDirection.WRITE] ).to_bytes().eval_and_sum(params) assert ld_bytes == 4*f32l assert st_bytes == 4*f32s @@ -543,8 +550,9 @@ def test_mem_access_counter_special_ops(): assert f32 == (n*m*ell)*n_subgroups assert f64 == (n*m)*n_subgroups - filtered_map = mem_map.filter_by(read_write=[AccessDirection.READ], variable=["a", "g"], - count_granularity=[CG.SUBGROUP]) + filtered_map = mem_map.filter_by(read_write=[AccessDirection.READ], + variable=["a", "g"], + count_granularity=[CG.SUBGROUP]) tot = filtered_map.eval_and_sum(params) # uniform: (count-per-sub-group)*n_subgroups @@ -959,7 +967,7 @@ def test_mem_access_counter_global_temps(): # Count global accesses global_accesses = mem_map.filter_by( - mtype=["global"]).sum().eval_with_dict(params) + address_space=[AddressSpace.GLOBAL]).sum().eval_with_dict(params) assert global_accesses == n*m @@ -1361,9 +1369,11 @@ def test_summations_and_filters(): # ignore stride and variable names in this map reduced_map = mem_map.group_by("mtype", "dtype", "direction") - f32lall = reduced_map[lp.MemAccess("global", np.float32, read_write=AccessDirection.READ) + f32lall = reduced_map[lp.MemAccess("global", np.float32, + read_write=AccessDirection.READ) ].eval_with_dict(params) - f64lall = reduced_map[lp.MemAccess("global", np.float64, read_write=AccessDirection.READ) + f64lall = reduced_map[lp.MemAccess("global", np.float64, + read_write=AccessDirection.READ) ].eval_with_dict(params) # uniform: (count-per-sub-group)*n_subgroups