Skip to content

Commit

Permalink
fix (?) remaining mypy errors, somewhat sketchy in parts
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Nov 11, 2024
1 parent b54217d commit a693fb0
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 59 deletions.
137 changes: 79 additions & 58 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,35 @@
from dataclasses import dataclass, replace
from enum import Enum, auto as enum_auto
from functools import cached_property, partial
from typing import Any, Callable, Generic, Iterable, Mapping, Optional, TypeVar, Union
from typing import (
Any,
Callable,
Generic,
Iterable,
Mapping,
Sequence,
Type,
TypeVar,
Union,
cast,
)

from immutabledict import immutabledict

import islpy as isl
import pymbolic.primitives as p
from islpy import PwQPolynomial, dim_type
from pymbolic.mapper import CombineMapper
from pymbolic.typing import ArithmeticExpressionT
from pytools import memoize_method
from pytools.tag import Tag

import loopy as lp
from loopy.diagnostic import LoopyError, warn_with_kernel
from loopy.kernel import LoopKernel
from loopy.kernel.array import ArrayBase
from loopy.kernel.data import AddressSpace, InameImplementationTag, MultiAssignmentBase
from loopy.kernel.function_interface import CallableKernel, InKernelCallable
from loopy.kernel.data import AddressSpace, MultiAssignmentBase
from loopy.kernel.function_interface import CallableKernel
from loopy.kernel.instruction import InstructionBase
from loopy.symbolic import (
CoefficientCollector,
Expand All @@ -56,9 +68,9 @@
TaggedExpression,
flatten,
)
from loopy.translation_unit import TranslationUnit
from loopy.translation_unit import ConcreteCallablesTable, TranslationUnit
from loopy.types import LoopyType
from loopy.typing import Expression
from loopy.typing import Expression, ExpressionT, auto


__doc__ = """
Expand Down Expand Up @@ -277,7 +289,7 @@ def __len__(self) -> int:
return len(self.count_map)

def get(self,
key: Countable, default: Optional[CountT] = None) -> Optional[CountT]:
key: Countable, default: CountT | None = None) -> CountT | None:
return self.count_map.get(key, default)

def items(self):
Expand All @@ -290,7 +302,7 @@ def values(self):
return self.count_map.values()

def copy(
self, count_map: Optional[dict[Countable, CountT]] = None
self, count_map: dict[Countable, CountT] | None = None
) -> ToCountMap[CountT]:
if count_map is None:
count_map = self.count_map
Expand Down Expand Up @@ -686,8 +698,8 @@ class Op:
.. attribute:: count_granularity
A :class:`str` that specifies whether this operation should be counted
once per *work-item*, *sub-group*, or *work-group*. The granularities
A :class:`CountGranularity` that specifies whether this operation should be
counted once per *work-item*, *sub-group*, or *work-group*. The granularities
allowed can be found in :class:`CountGranularity`, and may be accessed,
e.g., as ``CountGranularity.WORKITEM``. A work-item is a single instance
of computation executing on a single processor (think "thread"), a
Expand Down Expand Up @@ -816,16 +828,16 @@ class MemAccess:
A :class:`frozenset` of tags attached to the operation.
"""

address_space: Optional[AddressSpace] = None
dtype: Optional[LoopyType] = None
lid_strides: Optional[Mapping[int, Expression]] = None
gid_strides: Optional[Mapping[int, Expression]] = None
read_write: Optional[AccessDirection] = None
variable: Optional[str] = None
address_space: AddressSpace | Type[auto] | None = None
dtype: LoopyType | None = None
lid_strides: Mapping[int, Expression] | None = None
gid_strides: Mapping[int, Expression] | None = None
read_write: AccessDirection | None = None
variable: str | None = None

variable_tags: frozenset[Tag] = frozenset()
count_granularity: Optional[CountGranularity] = None
kernel_name: Optional[str] = None
count_granularity: CountGranularity | None = None
kernel_name: str | None = None
tags: frozenset[Tag] = frozenset()

def __post_init__(self):
Expand Down Expand Up @@ -927,8 +939,8 @@ class Sync:
A :class:`frozenset` of tags attached to the synchronization.
"""
sync_kind: Optional[SynchronizationKind] = None
kernel_name: Optional[str] = None
sync_kind: SynchronizationKind | None = None
kernel_name: str | None = None
tags: frozenset[Tag] = frozenset()

def __post_init__(self):
Expand Down Expand Up @@ -1044,7 +1056,7 @@ def map_reduction(
% (type(self).__name__, type(expr).__name__))

def __call__(
self, expr, tags: Optional[frozenset[Tag]] = None
self, expr, tags: frozenset[Tag] | None = None
) -> ToCountPolynomialMap:
if tags is None:
tags = frozenset()
Expand Down Expand Up @@ -1111,7 +1123,8 @@ def map_sum(self, expr: p.Sum, tags: frozenset[Tag]) -> ToCountPolynomialMap:
) + sum(self.rec(child, tags) for child in expr.children)

def map_product(
self, expr: p.Product, tags: frozenset[Tag]) -> ToCountPolynomialMap:
self, expr: p.Product, tags: frozenset[Tag]) \
-> ToCountMap[GuardedPwQPolynomial]:
from pymbolic.primitives import is_zero
assert expr.children
return sum(self.new_poly_map({Op(dtype=self.type_inf(expr),
Expand All @@ -1122,7 +1135,7 @@ def map_product(
kernel_name=self.knl.name): self.one})
+ self.rec(child, tags)
for child in expr.children
if not is_zero(child + 1)) + \
if not is_zero(cast(ArithmeticExpressionT, child) + 1)) + \
self.new_poly_map({Op(dtype=self.type_inf(expr),
op_type=OpType.MUL,
tags=tags,
Expand Down Expand Up @@ -1247,8 +1260,8 @@ def map_floor_div(self, expr):
# {{{ _get_lid_and_gid_strides

def _get_lid_and_gid_strides(
knl: LoopKernel, array: ArrayBase, index: tuple[Expression, ...]
) -> tuple[Mapping[int, Expression], Mapping[int, Expression]]:
knl: LoopKernel, array: ArrayBase, index: tuple[ExpressionT, ...]
) -> tuple[Mapping[int, ExpressionT], Mapping[int, ExpressionT]]:
# find all local and global index tags and corresponding inames
from loopy.symbolic import get_dependencies
my_inames = get_dependencies(index) & knl.all_inames()
Expand Down Expand Up @@ -1285,18 +1298,20 @@ def _get_lid_and_gid_strides(
from loopy.symbolic import simplify_using_aff

def get_iname_strides(
tag_to_iname_dict: Mapping[InameImplementationTag, str]
) -> Mapping[InameImplementationTag, Expression]:
tag_to_iname_dict: Mapping[int, str]
) -> Mapping[int, Expression]:
tag_to_stride_dict = {}

from loopy.kernel.array import ArrayDimImplementationTag

if array.dim_tags is None:
assert len(index) <= 1
dim_tags = (None,) * len(index)
dim_tags: Sequence[ArrayDimImplementationTag | None] = (None,) * len(index)
else:
dim_tags = array.dim_tags

for tag in tag_to_iname_dict:
total_iname_stride = 0
total_iname_stride: Any = 0
# find total stride of this iname for each axis
for idx, axis_tag in zip(index, dim_tags):
# collect index coefficients
Expand All @@ -1305,7 +1320,7 @@ def get_iname_strides(
[tag_to_iname_dict[tag]])(
simplify_using_aff(knl, idx))
except ExpressionNotAffineError:
total_iname_stride = None
total_iname_stride = 0
break

# check if idx contains this iname
Expand All @@ -1321,7 +1336,7 @@ def get_iname_strides(
axis_tag_stride = axis_tag.stride

if axis_tag_stride is lp.auto:
total_iname_stride = None
total_iname_stride = 0
break

elif axis_tag is None:
Expand Down Expand Up @@ -1363,7 +1378,7 @@ def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap:
def count_var_access(self,
dtype: LoopyType,
name: str,
index: Optional[tuple[Expression, ...]],
index: ExpressionT | None,
tags: frozenset[Tag],
var_tags: frozenset[Tag] = frozenset()
) -> ToCountPolynomialMap:
Expand Down Expand Up @@ -1457,10 +1472,12 @@ def map_variable(
def map_subscript(
self, expr: p.Subscript, tags: frozenset[Tag]) -> ToCountPolynomialMap:
try:
var_tags = expr.aggregate.tags
var_tags = expr.aggregate.tags # type: ignore[union-attr]
except AttributeError:
var_tags = frozenset()

assert hasattr(expr.aggregate, "name")

return (self.count_var_access(self.type_inf(expr),
expr.aggregate.name,
expr.index, tags, var_tags)
Expand Down Expand Up @@ -1713,7 +1730,7 @@ def count_inames_domain(

def count_insn_runs(
knl: LoopKernel,
callables_table: Mapping[str, InKernelCallable],
callables_table: ConcreteCallablesTable,
insn: InstructionBase,
count_redundant_work: bool,
disregard_local_axes: bool = False) -> GuardedPwQPolynomial:
Expand All @@ -1738,11 +1755,11 @@ def count_insn_runs(

def _get_insn_count(
knl: LoopKernel,
callables_table: Mapping[str, InKernelCallable],
insn_id: str,
subgroup_size: Optional[int],
callables_table: ConcreteCallablesTable,
insn_id: str | None,
subgroup_size: int | None,
count_redundant_work: bool,
count_granularity: CountGranularity = CountGranularity.WORKITEM
count_granularity: CountGranularity | None = CountGranularity.WORKITEM
) -> GuardedPwQPolynomial:
insn = knl.id_to_insn[insn_id]

Expand Down Expand Up @@ -1813,10 +1830,10 @@ def _get_insn_count(

def _get_op_map_for_single_kernel(
knl: LoopKernel,
callables_table: Mapping[str, InKernelCallable],
callables_table: ConcreteCallablesTable,
count_redundant_work: bool,
count_within_subscripts: bool,
subgroup_size: int, within) -> ToCountPolynomialMap:
subgroup_size: int | None, within) -> ToCountMap[GuardedPwQPolynomial]:

subgroup_size = _process_subgroup_size(knl, subgroup_size)

Expand All @@ -1828,7 +1845,7 @@ def _get_op_map_for_single_kernel(

op_counter = ExpressionOpCounter(knl, callables_table, kernel_rec,
count_within_subscripts)
op_map = op_counter._new_zero_map()
op_map: ToCountMap[GuardedPwQPolynomial] = op_counter._new_zero_map()

from loopy.kernel.instruction import (
Assignment,
Expand All @@ -1843,6 +1860,7 @@ def _get_op_map_for_single_kernel(
if isinstance(insn, (CallInstruction, Assignment)):
ops = op_counter(insn.assignees) + op_counter(insn.expression)
for key, val in ops.count_map.items():
key = cast(Op, key)
count = _get_insn_count(knl, callables_table, insn.id,
subgroup_size, count_redundant_work,
key.count_granularity)
Expand All @@ -1861,9 +1879,9 @@ def _get_op_map_for_single_kernel(
def get_op_map(
t_unit: TranslationUnit, *, count_redundant_work: bool = False,
count_within_subscripts: bool = True,
subgroup_size: Optional[int] = None,
entrypoint: Optional[str] = None,
within: Any = None):
subgroup_size: int | None = None,
entrypoint: str | None = None,
within: Any = None) -> ToCountMap[GuardedPwQPolynomial]:

"""Count the number of operations in a loopy kernel.
Expand Down Expand Up @@ -1955,7 +1973,7 @@ def get_op_map(

# {{{ subgroup size finding

def _find_subgroup_size_for_knl(knl):
def _find_subgroup_size_for_knl(knl: LoopKernel) -> int | None:
from loopy.target.pyopencl import PyOpenCLTarget
if isinstance(knl.target, PyOpenCLTarget) and knl.target.device is not None:
from pyopencl.characterize import get_simd_group_size
Expand Down Expand Up @@ -2013,9 +2031,9 @@ def _process_subgroup_size(knl, subgroup_size_requested):

def _get_mem_access_map_for_single_kernel(
knl: LoopKernel,
callables_table: Mapping[str, InKernelCallable],
count_redundant_work: bool, subgroup_size: Optional[int],
within: Any) -> ToCountPolynomialMap:
callables_table: ConcreteCallablesTable,
count_redundant_work: bool, subgroup_size: int | None,
within: Any) -> ToCountMap[GuardedPwQPolynomial]:

subgroup_size = _process_subgroup_size(knl, subgroup_size)

Expand All @@ -2025,7 +2043,7 @@ def _get_mem_access_map_for_single_kernel(
subgroup_size=subgroup_size)

access_counter = MemAccessCounter(knl, callables_table, kernel_rec)
access_map = access_counter._new_zero_map()
access_map: ToCountMap[GuardedPwQPolynomial] = access_counter._new_zero_map()

from loopy.kernel.instruction import (
Assignment,
Expand All @@ -2047,6 +2065,7 @@ def _get_mem_access_map_for_single_kernel(
).with_set_attributes(read_write=AccessDirection.WRITE)

for key, val in insn_access_map.count_map.items():
key = cast(MemAccess, key)
count = _get_insn_count(knl, callables_table, insn.id,
subgroup_size, count_redundant_work,
key.count_granularity)
Expand All @@ -2065,9 +2084,9 @@ def _get_mem_access_map_for_single_kernel(

def get_mem_access_map(
t_unit: TranslationUnit, *, count_redundant_work: bool = False,
subgroup_size: Optional[int] = None,
entrypoint: Optional[str] = None,
within: Any = None) -> ToCountPolynomialMap:
subgroup_size: int | None = None,
entrypoint: str | None = None,
within: Any = None) -> ToCountMap[GuardedPwQPolynomial]:
"""Count the number of memory accesses in a loopy kernel.
:arg t_unit: A :class:`loopy.TranslationUnit` whose memory accesses are to be
Expand Down Expand Up @@ -2184,8 +2203,8 @@ def get_mem_access_map(

def _get_synchronization_map_for_single_kernel(
knl: LoopKernel,
callables_table: Mapping[str, InKernelCallable],
subgroup_size: Optional[int] = None):
callables_table: ConcreteCallablesTable,
subgroup_size: int | None = None) -> ToCountMap[GuardedPwQPolynomial]:

knl = lp.get_one_linearized_kernel(knl, callables_table)

Expand All @@ -2203,10 +2222,12 @@ def _get_synchronization_map_for_single_kernel(
subgroup_size=subgroup_size)

sync_counter = CounterBase(knl, callables_table, kernel_rec)
sync_map = sync_counter._new_zero_map()
sync_map: ToCountMap[GuardedPwQPolynomial] = sync_counter._new_zero_map()

iname_list = []

assert knl.linearization is not None

for sched_item in knl.linearization:
if isinstance(sched_item, EnterLoop):
if sched_item.iname: # (if not empty)
Expand Down Expand Up @@ -2246,8 +2267,8 @@ def _get_synchronization_map_for_single_kernel(

def get_synchronization_map(
t_unit: TranslationUnit, *,
subgroup_size: Optional[int] = None,
entrypoint: Optional[str] = None) -> ToCountPolynomialMap:
subgroup_size: int | None = None,
entrypoint: str | None = None) -> ToCountMap[GuardedPwQPolynomial]:
"""Count the number of synchronization events each work-item encounters in
a loopy kernel.
Expand Down Expand Up @@ -2337,7 +2358,7 @@ def _gather_access_footprints_for_single_kernel(

def gather_access_footprints(
t_unit: TranslationUnit, *, ignore_uncountable: bool = False,
entrypoint: Optional[str] = None) -> Mapping[MemAccess, isl.Set]:
entrypoint: str | None = None) -> Mapping[MemAccess, isl.Set]:
"""Return a dictionary mapping ``(var_name, direction)`` to
:class:`islpy.Set` instances capturing which indices of each array
*var_name* are read/written (where *direction* is either ``read`` or
Expand Down Expand Up @@ -2409,7 +2430,7 @@ def gather_access_footprint_bytes(
# FIXME: Only supporting a single kernel for now
kernel = t_unit.default_entrypoint

result = {}
result: dict[Countable, GuardedPwQPolynomial] = {}
for ma, var_fp in fp.items():
assert ma.variable
var_descr = kernel.get_var_descriptor(ma.variable)
Expand Down
Loading

0 comments on commit a693fb0

Please sign in to comment.