Skip to content

Commit

Permalink
misc updates
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Nov 7, 2024
1 parent cc7efdc commit 56527b4
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 164 deletions.
93 changes: 66 additions & 27 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,13 @@ def __mul__(self, other: GuardedPwQPolynomial) -> ToCountMap[CountT]:
__rmul__ = __mul__

def __getitem__(self, index: Countable) -> CountT:
print("HERE", index, hash(index))

if index not in self.count_map:
for k, v in self.count_map.items():
print(f" {k=}, {hash(k.dtype) == hash(index.dtype)} {hash(k)}")
# assert hash(k) == hash(k.dtype), k
# print(f"{self.count_map=}")
return self.count_map[index]

def __repr__(self) -> str:
Expand Down Expand Up @@ -653,6 +660,8 @@ class OpType(Enum):
.. attribute:: POW
.. attribute:: SHIFT
.. attribute:: BITWISE
.. attribute:: MAXMIN
.. attribute:: SPECIAL_FUNC
"""
ADD = enum_auto()
MUL = enum_auto()
Expand Down Expand Up @@ -705,13 +714,16 @@ class Op:
kernel_name: str | None = None
tags: frozenset[Tag] = frozenset()

def __post_init__(self):
if self.dtype is not None:
from loopy.types import to_loopy_type
object.__setattr__(self, "dtype", to_loopy_type(self.dtype))

assert isinstance(self.op_type, OpType) or self.op_type is None, self.op_type

def __repr__(self):
if self.kernel_name is not None:
return (f"Op({self.dtype}, {self.name}, {self.count_granularity},"
f' "{self.kernel_name}", {self.tags})')
else:
return f"Op({self.dtype}, {self.name}, " + \
f"{self.count_granularity}, {self.tags})"
return (f"Op({self.dtype}, {self.op_type}, {self.count_granularity},"
f' "{self.kernel_name if self.kernel_name is not None else ""}", {self.tags})')

# }}}

Expand Down Expand Up @@ -797,9 +809,9 @@ class MemAccess:
"""

address_space: Optional[AddressSpace] = None
dtype: Optional[LoopyType] = None
lid_strides: Optional[Mapping[int, Expression]] = None
gid_strides: Optional[Mapping[int, Expression]] = None
dtype: Optional[LoopyType] = None
read_write: Optional[AccessDirection] = None
variable: Optional[str] = None

Expand All @@ -808,6 +820,21 @@ class MemAccess:
kernel_name: Optional[str] = None
tags: frozenset[Tag] = frozenset()

def __post_init__(self):
if self.dtype is not None:
from loopy.types import to_loopy_type
object.__setattr__(self, "dtype", to_loopy_type(self.dtype))

from immutabledict import immutabledict

if isinstance(self.lid_strides, dict):
object.__setattr__(self, "lid_strides", immutabledict(self.lid_strides))

if isinstance(self.gid_strides, dict):
object.__setattr__(self, "gid_strides", immutabledict(self.gid_strides))

assert self.address_space is None or isinstance(self.address_space, AddressSpace)

@property
def mtype(self) -> str:
from warnings import warn
Expand All @@ -832,26 +859,26 @@ def direction(self) -> str:

if self.read_write == AccessDirection.READ:
return "read"
elif self.address_space == AccessDirection.WRITE:
elif self.read_write == AccessDirection.WRITE:
return "write"
else:
raise ValueError(f"unexpected read_write: '{self.read_write}'")

def __repr__(self):
# Record.__repr__ overridden for consistent ordering and conciseness
return "MemAccess({}, {}, {}, {}, {}, {}, {}, {}, {}, {})".format(
self.address_space,
self.dtype,
None if self.lid_strides is None else dict(
sorted(self.lid_strides.items())),
None if self.gid_strides is None else dict(
sorted(self.gid_strides.items())),
self.read_write,
self.variable,
"None" if not self.variable_tags else str(self.variable_tags),
self.count_granularity,
repr(self.kernel_name),
self.tags)
# def __repr__(self):
# # Record.__repr__ overridden for consistent ordering and conciseness
# return "MemAccess({}, {}, {}, {}, {}, {}, {}, {}, {}, {})".format(
# self.address_space,
# self.dtype,
# None if self.lid_strides is None else dict(
# sorted(self.lid_strides.items())),
# None if self.gid_strides is None else dict(
# sorted(self.gid_strides.items())),
# self.read_write,
# self.variable,
# "None" if not self.variable_tags else str(self.variable_tags),
# self.count_granularity,
# repr(self.kernel_name),
# self.tags)

# }}}

Expand Down Expand Up @@ -1324,7 +1351,7 @@ def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap:
else:
return super().map_call(expr, tags)

# local_mem_count_granularity = CountGranularity.SUBGROUP
local_mem_count_granularity = CountGranularity.SUBGROUP

def count_var_access(self,
dtype: LoopyType,
Expand Down Expand Up @@ -1354,12 +1381,24 @@ def count_var_access(self,
lid_strides, gid_strides = _get_lid_and_gid_strides(
self.knl, array, index_tuple)

# print(MemAccess(
# address_space=array.address_space,
# dtype=dtype,
# tags=tags,
# lid_strides=lid_strides,
# gid_strides=gid_strides,
# variable=name,
# count_granularity=self.local_mem_count_granularity,
# kernel_name=self.knl.name))

from immutabledict import immutabledict

count_map[MemAccess(
address_space=array.address_space,
dtype=dtype,
tags=tags,
lid_strides=lid_strides,
gid_strides=gid_strides,
lid_strides=immutabledict(lid_strides),
gid_strides=immutabledict(gid_strides),
variable=name,
count_granularity=self.local_mem_count_granularity,
kernel_name=self.knl.name)] = self.one
Expand Down Expand Up @@ -1846,7 +1885,7 @@ def get_op_map(
assert entrypoint in t_unit.entrypoints

from loopy.preprocess import infer_unknown_types, preprocess_program
program = preprocess_program(t_unit)
t_unit = preprocess_program(t_unit)

from loopy.match import parse_match
within = parse_match(within)
Expand Down
2 changes: 2 additions & 0 deletions loopy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ def __init__(self, dtype: np.dtype):
if dtype == object: # noqa: E721
raise TypeError("loopy does not directly support object arrays")

# Normalize due to https://stackoverflow.com/questions/35293672/why-do-these-dtypes-compare-equal-but-hash-different
self.dtype = np.dtype(dtype)

def __hash__(self) -> int:
# print("hash", repr(self.dtype), hash(self.dtype)== hash(np.float32))
return hash(self.dtype)

def update_persistent_hash(self, key_hash, key_builder):
Expand Down
Loading

0 comments on commit 56527b4

Please sign in to comment.