Skip to content

Commit

Permalink
fix(?) MemoryAccessCounter
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Nov 8, 2024
1 parent f461722 commit 48b6b94
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 86 deletions.
165 changes: 103 additions & 62 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from functools import cached_property, partial
from typing import Any, Callable, Generic, Iterable, Mapping, Optional, TypeVar, Union

from immutabledict import immutabledict

import islpy as isl
import pymbolic.primitives as p
from islpy import PwQPolynomial, dim_type
Expand Down Expand Up @@ -256,13 +258,6 @@ def __mul__(self, other: GuardedPwQPolynomial) -> ToCountMap[CountT]:
__rmul__ = __mul__

def __getitem__(self, index: Countable) -> CountT:
print("HERE", index, hash(index))

if index not in self.count_map:
for k, v in self.count_map.items():
print(f" {k=}, {hash(k.dtype) == hash(index.dtype)} {hash(k)}")
# assert hash(k) == hash(k.dtype), k
# print(f"{self.count_map=}")
return self.count_map[index]

def __repr__(self) -> str:
Expand Down Expand Up @@ -623,7 +618,7 @@ def subst_into_to_count_map(
# {{{ CountGranularity

class CountGranularity(Enum):
"""Strings specifying whether an operation should be counted once per
"""Specify whether an operation should be counted once per
*work-item*, *sub-group*, or *work-group*.
.. attribute:: WORKITEM
Expand Down Expand Up @@ -721,9 +716,18 @@ def __post_init__(self):

assert isinstance(self.op_type, OpType) or self.op_type is None, self.op_type

if not (self.count_granularity is None
or isinstance(self.count_granularity, CountGranularity)):
raise ValueError(
f"unexpected count_granularity: '{self.count_granularity}'")

def __repr__(self):
return (f"Op({self.dtype}, {self.op_type}, {self.count_granularity},"
f' "{self.kernel_name if self.kernel_name is not None else ""}", {self.tags})')
if self.kernel_name is not None:
return (f"Op({self.dtype}, {self.name}, {self.count_granularity},"
f' "{self.kernel_name}", {self.tags})')
else:
return f"Op({self.dtype}, {self.name}, " + \
f"{self.count_granularity}, {self.tags})"

# }}}

Expand All @@ -745,7 +749,7 @@ class MemAccess:
.. attribute:: address_space
A :class:`str` that specifies the memory type accessed as **global**
A :class:`AddressSpace` that specifies the memory type accessed as **global**
or **local**
.. attribute:: dtype
Expand Down Expand Up @@ -825,15 +829,19 @@ def __post_init__(self):
from loopy.types import to_loopy_type
object.__setattr__(self, "dtype", to_loopy_type(self.dtype))

from immutabledict import immutabledict

if isinstance(self.lid_strides, dict):
object.__setattr__(self, "lid_strides", immutabledict(self.lid_strides))

if isinstance(self.gid_strides, dict):
object.__setattr__(self, "gid_strides", immutabledict(self.gid_strides))

assert self.address_space is None or isinstance(self.address_space, AddressSpace)
assert (self.address_space is None
or isinstance(self.address_space, AddressSpace))

if not (self.count_granularity is None
or isinstance(self.count_granularity, CountGranularity)):
raise ValueError(
f"unexpected count_granularity: '{self.count_granularity}'")

@property
def mtype(self) -> str:
Expand Down Expand Up @@ -911,10 +919,6 @@ class Sync:
kernel_name: Optional[str] = None
tags: frozenset[Tag] = frozenset()

def __repr__(self):
# Record.__repr__ overridden for consistent ordering and conciseness
return f"Sync({self.sync_kind}, {self.kernel_name}, {self.tags})"

# }}}


Expand Down Expand Up @@ -1332,7 +1336,7 @@ def get_iname_strides(
# }}}


# {{{ MemAccessCounterBase
# {{{ MemAccessCounter

class MemAccessCounter(CounterBase):
def map_sub_array_ref(
Expand All @@ -1351,59 +1355,96 @@ def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap:
else:
return super().map_call(expr, tags)

local_mem_count_granularity = CountGranularity.SUBGROUP

def count_var_access(self,
dtype: LoopyType,
name: str,
index: Optional[tuple[Expression, ...]],
tags: frozenset[Tag]
) -> ToCountPolynomialMap:
count_map = {}

from loopy.kernel.data import TemporaryVariable
array = self.knl.get_var_descriptor(name)

if index is None:
# no subscript
count_map[MemAccess(
if isinstance(array, TemporaryVariable) and (
array.address_space == AddressSpace.LOCAL):
# local memory access
local_mem_count_granularity = CountGranularity.SUBGROUP

if index is None:
return self.new_poly_map({MemAccess(
address_space=AddressSpace.LOCAL,
tags=tags,
dtype=dtype,
count_granularity=self.local_mem_count_granularity,
kernel_name=self.knl.name)] = self.one
return self.new_poly_map(count_map)

# could be tuple or scalar index
index_tuple = index
if not isinstance(index_tuple, tuple):
index_tuple = (index_tuple,)

lid_strides, gid_strides = _get_lid_and_gid_strides(
self.knl, array, index_tuple)

# print(MemAccess(
# address_space=array.address_space,
# dtype=dtype,
# tags=tags,
# lid_strides=lid_strides,
# gid_strides=gid_strides,
# variable=name,
# count_granularity=self.local_mem_count_granularity,
# kernel_name=self.knl.name))

from immutabledict import immutabledict

count_map[MemAccess(
address_space=array.address_space,
dtype=dtype,
tags=tags,
lid_strides=immutabledict(lid_strides),
gid_strides=immutabledict(gid_strides),
variable=name,
count_granularity=self.local_mem_count_granularity,
kernel_name=self.knl.name)] = self.one

return self.new_poly_map(count_map)
tags=tags,
count_granularity=local_mem_count_granularity,
kernel_name=self.knl.name): self.one})

# could be tuple or scalar index
index_tuple = index
if not isinstance(index_tuple, tuple):
index_tuple = (index_tuple,)

lid_strides, gid_strides = _get_lid_and_gid_strides(
self.knl, array, index_tuple)

return self.new_poly_map({MemAccess(
address_space=array.address_space,
dtype=dtype,
tags=tags,
lid_strides=immutabledict(lid_strides),
gid_strides=immutabledict(gid_strides),
variable=name,
count_granularity=self.local_mem_count_granularity,
kernel_name=self.knl.name): self.one})

elif (isinstance(array, TemporaryVariable) and (
array.address_space == AddressSpace.GLOBAL)) or (
isinstance(array, lp.ArrayArg)):
if index is None:
return self.new_poly_map({MemAccess(
address_space=AddressSpace.GLOBAL,
dtype=dtype,
lid_strides=immutabledict({}),
gid_strides=immutabledict({}),
variable=name,
tags=tags,
count_granularity=CountGranularity.WORKITEM,
kernel_name=self.knl.name): self.one})

# could be tuple or scalar index
index_tuple = index
if not isinstance(index_tuple, tuple):
index_tuple = (index_tuple,)

lid_strides, gid_strides = _get_lid_and_gid_strides(
self.knl, array, index_tuple)
# Account for broadcasts once per subgroup
count_granularity = CountGranularity.WORKITEM if (
# if the stride in lid.0 is known
0 in lid_strides
and
# it is nonzero
lid_strides[0] != 0
) else CountGranularity.SUBGROUP

try:
# var_tags = expr.aggregate.tags # FIXME
var_tags = frozenset()
except AttributeError:
var_tags = frozenset()

return self.new_poly_map({MemAccess(
address_space=AddressSpace.GLOBAL,
dtype=dtype,
lid_strides=immutabledict(lid_strides),
gid_strides=immutabledict(gid_strides),
variable=name,
tags=tags,
variable_tags=var_tags,
count_granularity=count_granularity,
kernel_name=self.knl.name,
): self.one}
)
else:
return self._new_zero_map()

def map_variable(
self, expr: p.Variable, tags: frozenset[Tag]
Expand Down
20 changes: 4 additions & 16 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ def copy(self, *, name=None, tags=None):
return TaggedVariable(name, tags)


@p.expr_dataclass(init=False)
@p.expr_dataclass()
class TaggedExpression(LoopyExpressionBase):
"""
Represents a frozenset of tags attached to an :attr:`expr`.
Expand All @@ -720,22 +720,10 @@ class TaggedExpression(LoopyExpressionBase):

init_arg_names = ("tags", "expr")

def __init__(self, tags, expr):
self.tags = tags
self.expr = expr

def __getinitargs__(self):
return (self.tags, self.expr)

def get_hash(self):
return hash((self.__class__, self.tags, self.expr))

def is_equal(self, other):
return (other.__class__ == self.__class__
and other.tags == self.tags
and other.expr == self.expr)
tags: frozenset[Tag]
expr: ExpressionT

mapper_method = intern("map_tagged_expression")
mapper_method = "map_tagged_expression"


@p.expr_dataclass(init=False)
Expand Down
26 changes: 18 additions & 8 deletions test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
from pytools import div_ceil

import loopy as lp
from loopy.statistics import CountGranularity as CG, OpType, AccessDirection, AddressSpace
from loopy.statistics import (
AccessDirection,
AddressSpace,
CountGranularity as CG,
OpType,
)
from loopy.types import to_loopy_type
from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa

Expand Down Expand Up @@ -417,9 +422,11 @@ def test_mem_access_counter_reduction():
# uniform: (count-per-sub-group)*n_subgroups
assert f32s == (n*ell)*n_subgroups

ld_bytes = mem_map.filter_by(mtype=["global"], read_write=[AccessDirection.READ]
ld_bytes = mem_map.filter_by(address_space=[AddressSpace.GLOBAL],
read_write=[AccessDirection.READ]
).to_bytes().eval_and_sum(params)
st_bytes = mem_map.filter_by(mtype=["global"], read_write=[AccessDirection.WRITE]
st_bytes = mem_map.filter_by(address_space=[AddressSpace.GLOBAL],
read_write=[AccessDirection.WRITE]
).to_bytes().eval_and_sum(params)
assert ld_bytes == 4*f32l
assert st_bytes == 4*f32s
Expand Down Expand Up @@ -543,8 +550,9 @@ def test_mem_access_counter_special_ops():
assert f32 == (n*m*ell)*n_subgroups
assert f64 == (n*m)*n_subgroups

filtered_map = mem_map.filter_by(read_write=[AccessDirection.READ], variable=["a", "g"],
count_granularity=[CG.SUBGROUP])
filtered_map = mem_map.filter_by(read_write=[AccessDirection.READ],
variable=["a", "g"],
count_granularity=[CG.SUBGROUP])
tot = filtered_map.eval_and_sum(params)

# uniform: (count-per-sub-group)*n_subgroups
Expand Down Expand Up @@ -959,7 +967,7 @@ def test_mem_access_counter_global_temps():

# Count global accesses
global_accesses = mem_map.filter_by(
mtype=["global"]).sum().eval_with_dict(params)
address_space=[AddressSpace.GLOBAL]).sum().eval_with_dict(params)

assert global_accesses == n*m

Expand Down Expand Up @@ -1361,9 +1369,11 @@ def test_summations_and_filters():

# ignore stride and variable names in this map
reduced_map = mem_map.group_by("mtype", "dtype", "direction")
f32lall = reduced_map[lp.MemAccess("global", np.float32, read_write=AccessDirection.READ)
f32lall = reduced_map[lp.MemAccess("global", np.float32,
read_write=AccessDirection.READ)
].eval_with_dict(params)
f64lall = reduced_map[lp.MemAccess("global", np.float64, read_write=AccessDirection.READ)
f64lall = reduced_map[lp.MemAccess("global", np.float64,
read_write=AccessDirection.READ)
].eval_with_dict(params)

# uniform: (count-per-sub-group)*n_subgroups
Expand Down

0 comments on commit 48b6b94

Please sign in to comment.