Skip to content

Commit

Permalink
add WithTag
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Oct 11, 2022
1 parent fdfa2ed commit e41625f
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 7 deletions.
4 changes: 2 additions & 2 deletions doc/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,7 @@ information provided. Now we will count the operations:

>>> op_map = lp.get_op_map(knl, subgroup_size=32)
>>> print(op_map)
Op(np:dtype('float32'), add, subgroup, "stats_knl"): ...
Op(np:dtype('float32'), add, subgroup, "stats_knl", None): ...

Each line of output will look roughly like::

Expand Down Expand Up @@ -1628,7 +1628,7 @@ together into keys containing only the specified fields:

>>> op_map_dtype = op_map.group_by('dtype')
>>> print(op_map_dtype)
Op(np:dtype('float32'), None, None): ...
Op(np:dtype('float32'), None, None, None): ...
<BLANKLINE>
>>> f32op_count = op_map_dtype[lp.Op(dtype=np.float32)
... ].eval_with_dict(param_dict)
Expand Down
22 changes: 17 additions & 5 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,10 +636,14 @@ class Op(ImmutableRecord):
A :class:`str` representing the kernel name where the operation occurred.
.. attribute:: tags
A :class:`frozenset` of tags attached to the operation.
"""

def __init__(self, dtype=None, name=None, count_granularity=None,
kernel_name=None):
kernel_name=None, tags=None):
if count_granularity not in CountGranularity.ALL+[None]:
raise ValueError("Op.__init__: count_granularity '%s' is "
"not allowed. count_granularity options: %s"
Expand All @@ -651,15 +655,17 @@ def __init__(self, dtype=None, name=None, count_granularity=None,

super().__init__(dtype=dtype, name=name,
count_granularity=count_granularity,
kernel_name=kernel_name)
kernel_name=kernel_name,
tags=tags)

def __repr__(self):
# Record.__repr__ overridden for consistent ordering and conciseness
if self.kernel_name is not None:
return (f"Op({self.dtype}, {self.name}, {self.count_granularity},"
f' "{self.kernel_name}")')
f' "{self.kernel_name}", {self.tags})')
else:
return f"Op({self.dtype}, {self.name}, {self.count_granularity})"
return f"Op({self.dtype}, {self.name}, " + \
f"{self.count_granularity}, {self.tags})"

# }}}

Expand Down Expand Up @@ -724,7 +730,7 @@ class MemAccess(ImmutableRecord):
work-group executes on a single compute unit with all work-items within
the work-group sharing local memory. A sub-group is an
implementation-dependent grouping of work-items within a work-group,
analagous to an NVIDIA CUDA warp.
analogous to an NVIDIA CUDA warp.
.. attribute:: kernel_name
Expand Down Expand Up @@ -922,6 +928,12 @@ def map_constant(self, expr):
map_tagged_variable = map_constant
map_variable = map_constant

def map_with_tag(self, expr):
    """Count the operations of a tagged expression and attach its tags.

    The wrapped expression ``expr.expr`` is counted first. Every operation
    key of the resulting map is then replaced by a copy whose ``tags`` are
    ``expr.tags``, so tags set by a nested tagged expression are overwritten
    by the enclosing one.

    :arg expr: the tagged expression; ``expr.expr`` and ``expr.tags`` are
        read.
    :returns: the map obtained from ``self.rec(expr.expr)``, with its
        ``count_map`` re-keyed.
    """
    opmap = self.rec(expr.expr)

    # Never assign to ``op.tags`` while *op* is a key of ``count_map``: the
    # mapping keeps the hash computed at insertion time, so a key mutated in
    # place can no longer be found by an equal key carrying the new tags.
    # Build fresh keys instead.
    retagged = {}
    for op, count in opmap.count_map.items():
        new_op = op.copy(tags=expr.tags)
        if new_op in retagged:
            # Keys that differed only in their tags coincide after
            # re-tagging; their counts belong together.
            retagged[new_op] = retagged[new_op] + count
        else:
            retagged[new_op] = count

    # Refill the existing mapping object so that other references to it
    # observe the re-keyed counts.
    opmap.count_map.clear()
    opmap.count_map.update(retagged)
    return opmap

def map_call(self, expr):
from loopy.symbolic import ResolvedFunction
assert isinstance(expr.function, ResolvedFunction)
Expand Down
46 changes: 46 additions & 0 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@
# {{{ mappers with support for loopy-specific primitives

class IdentityMapperMixin:
def map_with_tag(self, expr, *args, **kwargs):
    """Rebuild a tagged expression around its recursively mapped child.

    The tags of *expr* are carried over unchanged; only ``expr.expr`` is
    passed through :meth:`rec`.
    """
    return WithTag(expr.tags, self.rec(expr.expr, *args, **kwargs))

def map_literal(self, expr, *args, **kwargs):
    """Return *expr* itself; extra arguments are ignored."""
    return expr

Expand Down Expand Up @@ -207,6 +211,12 @@ def map_common_subexpression_uncached(self, expr):


class WalkMapperMixin:
def map_with_tag(self, expr, *args, **kwargs):
    """Visit a tagged expression and, if the visit allows it, its child.

    ``self.visit`` is called on *expr* first; the wrapped expression
    ``expr.expr`` is only traversed when that call returns a true value.
    Nothing is returned.
    """
    if self.visit(expr, *args, **kwargs):
        self.rec(expr.expr, *args, **kwargs)

def map_literal(self, expr, *args, **kwargs):
    """Visit *expr*; there is no child to descend into.

    The result of ``self.visit`` is discarded and nothing is returned.
    """
    self.visit(expr, *args, **kwargs)

Expand Down Expand Up @@ -273,6 +283,9 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase):


class CombineMapper(CombineMapperBase):
def map_with_tag(self, expr, *args, **kwargs):
    """Return the result of mapping the wrapped expression.

    The tags of *expr* are not consulted; only ``expr.expr`` is passed to
    :meth:`rec`.
    """
    wrapped = expr.expr
    return self.rec(wrapped, *args, **kwargs)

def map_reduction(self, expr, *args, **kwargs):
    """Return the result of mapping the reduction's operand ``expr.expr``.

    No other attribute of *expr* is read here.
    """
    operand = expr.expr
    return self.rec(operand, *args, **kwargs)

Expand All @@ -298,6 +311,10 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase,


class StringifyMapper(StringifyMapperBase):
def map_with_tag(self, expr, *args):
    """Render a tagged expression as ``WithTag(<tags>, <expr>)``.

    :arg expr: the tagged expression; ``expr.tags`` and ``expr.expr`` are
        read.
    :returns: a :class:`str` with balanced parentheses. Additional
        positional arguments are accepted but not used.
    """
    from pymbolic.mapper.stringifier import PREC_NONE

    # The closing parenthesis of the ``WithTag(`` call notation must be part
    # of the output; the wrapped expression is stringified with PREC_NONE.
    return f"WithTag({expr.tags}, {self.rec(expr.expr, PREC_NONE)})"

def map_literal(self, expr, *args):
    """Return the ``s`` attribute of *expr* as its string form."""
    text = expr.s
    return text

Expand Down Expand Up @@ -440,6 +457,10 @@ def map_tagged_variable(self, expr, *args, **kwargs):
def map_loopy_function_identifier(self, expr, *args, **kwargs):
    """Return a new empty :class:`set`; *expr* contributes nothing."""
    return set()

def map_with_tag(self, expr, *args, **kwargs):
    """Return whatever :meth:`rec` yields for the wrapped expression.

    The tags of *expr* are not inspected.
    """
    return self.rec(expr.expr, *args, **kwargs)

def map_sub_array_ref(self, expr, *args, **kwargs):
    """Return the result for ``expr.subscript`` minus the swept inames.

    The value obtained from :meth:`rec` for the subscript has every element
    of ``expr.swept_inames`` removed by set difference.
    """
    return self.rec(expr.subscript, *args, **kwargs) - set(expr.swept_inames)
Expand Down Expand Up @@ -712,6 +733,31 @@ def copy(self, *, name=None, tags=None):
mapper_method = intern("map_tagged_variable")


class WithTag(LoopyExpressionBase):
    """
    An expression node that pairs :attr:`expr` with a collection of
    :attr:`tags` (a :class:`frozenset` by convention).

    .. attribute:: tags

        The tags attached to :attr:`expr`.

    .. attribute:: expr

        The wrapped expression.
    """

    init_arg_names = ("tags", "expr")

    def __init__(self, tags, expr):
        self.tags = tags
        self.expr = expr

    def __getinitargs__(self):
        # Same order as init_arg_names.
        return (self.tags, self.expr)

    def get_hash(self):
        # The class is part of the key, alongside both attributes.
        key = (self.__class__, self.tags, self.expr)
        return hash(key)

    def is_equal(self, other):
        # Only compare attributes once the classes are known to match.
        if not (other.__class__ == self.__class__):
            return False
        return other.tags == self.tags and other.expr == self.expr

    mapper_method = intern("map_with_tag")


class Reduction(LoopyExpressionBase):
"""
Represents a reduction operation on :attr:`expr` across :attr:`inames`.
Expand Down
58 changes: 58 additions & 0 deletions test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,64 @@ def test_no_loop_ops():
assert f64_mul == 1


from pytools.tag import Tag


class MyCostTag1(Tag):
    """Marker tag without any fields, used by the tagged-operation test."""


class MyCostTag2(Tag):
    """Second marker tag without any fields, used by the tagged-operation test."""


class MyCostTagSum(Tag):
    """Marker tag without any fields, placed on the sum in the tagged-operation test."""


def test_op_with_tag():
    """Check that operation counts can be filtered by the tags on a WithTag.

    The kernel computes ``c[i] = a[i] + b[i]`` where the sum and each operand
    carry a different tag; only the tag on the sum is expected to select the
    ``n`` additions.
    """
    from loopy.symbolic import WithTag
    from pymbolic.primitives import Subscript, Variable, Sum

    n = 500

    # Build the tagged right-hand side from the inside out.
    tagged_a = WithTag(frozenset((MyCostTag1(),)),
                       Subscript(Variable("a"), Variable("i")))
    tagged_b = WithTag(frozenset((MyCostTag2(),)),
                       Subscript(Variable("b"), Variable("i")))
    tagged_sum = WithTag(frozenset((MyCostTagSum(),)),
                         Sum((tagged_a, tagged_b)))

    knl = lp.make_kernel(
        "{[i]: 0<=i<n}",
        [lp.Assignment("c[i]", tagged_sum)])
    knl = lp.add_dtypes(knl, {"a": np.float64, "b": np.float64})

    params = {"n": n}
    op_map = lp.get_op_map(knl, subgroup_size=32)

    # All counted operations are float64.
    assert op_map.filter_by(dtype=[np.float64]).eval_and_sum(params) == n

    # Fresh tag instances are used for filtering, so matching relies on tag
    # equality rather than on object identity.
    expected_counts = [
        (frozenset((MyCostTagSum(),)), n),
        (frozenset((MyCostTag1(),)), 0),
        (frozenset((MyCostTag2(),)), 0),
        (frozenset((MyCostTag2(), MyCostTagSum())), 0),
    ]
    for tag_set, expected in expected_counts:
        count = op_map.filter_by(tags=[tag_set]).eval_and_sum(params)
        assert count == expected


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit e41625f

Please sign in to comment.