Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Jan 10, 2025
1 parent da547b6 commit c67eea6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
56 changes: 34 additions & 22 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,30 @@
from dataclasses import dataclass, replace
from enum import Enum, auto as enum_auto
from functools import cached_property, partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Iterable,
TypeVar,
Union,
cast,
)

from immutabledict import immutabledict
from typing import TYPE_CHECKING, ClassVar

import islpy as isl
import pymbolic.primitives as p
from islpy import PwQPolynomial, dim_type
from pymbolic.mapper import CombineMapper
from pymbolic.typing import ArithmeticExpressionT
from pytools import memoize_method
from pytools.tag import Tag

import loopy as lp
from loopy.diagnostic import LoopyError, warn_with_kernel
from loopy.kernel import LoopKernel
from loopy.kernel.array import ArrayBase
from loopy.kernel.data import AddressSpace, MultiAssignmentBase
from loopy.kernel.function_interface import CallableKernel
from loopy.kernel.instruction import InstructionBase
from loopy.symbolic import (
CoefficientCollector,
Reduction,
Expand All @@ -58,12 +63,19 @@
flatten,
)
from loopy.translation_unit import ConcreteCallablesTable, TranslationUnit
from loopy.types import LoopyType
from loopy.typing import Expression, ExpressionT, auto


if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Mapping, Sequence

import pymbolic.primitives as p
from pymbolic.typing import ArithmeticExpressionT
from pytools.tag import Tag

from loopy.kernel.array import ArrayBase
from loopy.kernel.instruction import InstructionBase
from loopy.types import LoopyType
from loopy.typing import Expression, ExpressionT, auto


__doc__ = """
Expand Down Expand Up @@ -245,7 +257,7 @@ def __add__(self, other: ToCountMap[CountT]) -> ToCountMap[CountT]:
result[k] = self.count_map.get(k, 0) + v
return self.copy(count_map=result)

def __radd__(self, other: Union[int, ToCountMap[CountT]]) -> ToCountMap[CountT]:
def __radd__(self, other: int | ToCountMap[CountT]) -> ToCountMap[CountT]:
if other != 0:
raise ValueError("ToCountMap: Attempted to add ToCountMap "
"to {} {}. ToCountMap may only be added to "
Expand Down Expand Up @@ -487,7 +499,7 @@ def to_bytes(self) -> ToCountMap[CountT]:
new_count_map = {}

for key, val in self.count_map.items():
new_count_map[key] = int(key.dtype.itemsize) * val # type: ignore[union-attr] # noqa: E501
new_count_map[key] = int(key.dtype.itemsize) * val # type: ignore[union-attr]

return self.copy(new_count_map)

Expand Down Expand Up @@ -821,7 +833,7 @@ class MemAccess:
A :class:`frozenset` of tags to the operation.
"""

address_space: AddressSpace | Type[auto] | None = None
address_space: AddressSpace | type[auto] | None = None
dtype: LoopyType | None = None
lid_strides: Mapping[int, Expression] | None = None
gid_strides: Mapping[int, Expression] | None = None
Expand Down Expand Up @@ -1127,7 +1139,7 @@ def map_product(
kernel_name=self.knl.name): self.one})
+ self.rec(child, tags)
for child in expr.children
if not is_zero(cast(ArithmeticExpressionT, child) + 1)) + \
if not is_zero(cast("ArithmeticExpressionT", child) + 1)) + \
self.new_poly_map({Op(dtype=self.type_inf(expr),
op_type=OpType.MUL,
tags=tags,
Expand Down Expand Up @@ -1159,7 +1171,7 @@ def map_power(self, expr: p.Power, tags: frozenset[Tag]) -> ToCountPolynomialMap
+ self.rec(expr.exponent, tags)

def map_left_shift(
self, expr: Union[p.LeftShift, p.RightShift], tags: frozenset[Tag]
self, expr: p.LeftShift | p.RightShift, tags: frozenset[Tag]
) -> ToCountPolynomialMap:
return self.new_poly_map({Op(dtype=self.type_inf(expr),
op_type=OpType.SHIFT,
Expand All @@ -1181,7 +1193,7 @@ def map_bitwise_not(
+ self.rec(expr.child, tags)

def map_bitwise_or(
self, expr: Union[p.BitwiseOr, p.BitwiseAnd, p.BitwiseXor],
self, expr: p.BitwiseOr | p.BitwiseAnd | p.BitwiseXor,
tags: frozenset[Tag]) -> ToCountPolynomialMap:
return self.new_poly_map({Op(dtype=self.type_inf(expr),
op_type=OpType.BITWISE,
Expand All @@ -1202,7 +1214,7 @@ def map_if(self, expr: p.If, tags: frozenset[Tag]) -> ToCountPolynomialMap:
+ self.rec(expr.else_, tags)

def map_min(
self, expr: Union[p. Min, p.Max], tags: frozenset[Tag]
self, expr: p.Min | p.Max, tags: frozenset[Tag]
) -> ToCountPolynomialMap:
return self.new_poly_map({Op(dtype=self.type_inf(expr),
op_type=OpType.MAXMIN,
Expand Down Expand Up @@ -1847,7 +1859,7 @@ def _get_op_map_for_single_kernel(
if isinstance(insn, (CallInstruction, Assignment)):
ops = op_counter(insn.assignees) + op_counter(insn.expression)
for key, val in ops.count_map.items():
key = cast(Op, key)
key = cast("Op", key)
count = _get_insn_count(knl, callables_table, insn.id,
subgroup_size, count_redundant_work,
key.count_granularity)
Expand Down Expand Up @@ -1931,7 +1943,7 @@ def get_op_map(
if len(t_unit.entrypoints) > 1:
raise LoopyError("Must provide entrypoint")

entrypoint = next(iter(program.entrypoints))
entrypoint = next(iter(t_unit.entrypoints))

assert entrypoint in t_unit.entrypoints

Expand Down Expand Up @@ -2052,7 +2064,7 @@ def _get_mem_access_map_for_single_kernel(
).with_set_attributes(read_write=AccessDirection.WRITE)

for key, val in insn_access_map.count_map.items():
key = cast(MemAccess, key)
key = cast("MemAccess", key)
count = _get_insn_count(knl, callables_table, insn.id,
subgroup_size, count_redundant_work,
key.count_granularity)
Expand Down Expand Up @@ -2162,7 +2174,7 @@ def get_mem_access_map(
if len(t_unit.entrypoints) > 1:
raise LoopyError("Must provide entrypoint")

entrypoint = next(iter(program.entrypoints))
entrypoint = next(iter(t_unit.entrypoints))

assert entrypoint in t_unit.entrypoints

Expand Down Expand Up @@ -2295,7 +2307,7 @@ def get_synchronization_map(
if len(t_unit.entrypoints) > 1:
raise LoopyError("Must provide entrypoint")

entrypoint = next(iter(program.entrypoints))
entrypoint = next(iter(t_unit.entrypoints))

assert entrypoint in t_unit.entrypoints
from loopy.preprocess import infer_unknown_types, preprocess_program
Expand Down Expand Up @@ -2360,7 +2372,7 @@ def gather_access_footprints(
if len(t_unit.entrypoints) > 1:
raise LoopyError("Must provide entrypoint")

entrypoint = next(iter(program.entrypoints))
entrypoint = next(iter(t_unit.entrypoints))

assert entrypoint in t_unit.entrypoints

Expand Down
6 changes: 3 additions & 3 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
LoopyError,
UnableToDetermineAccessRangeError,
)
from loopy.typing import Expression, not_none
from loopy.typing import Expression, ExpressionT, not_none


if TYPE_CHECKING:
Expand Down Expand Up @@ -283,7 +283,7 @@ def map_tagged_expression(self, expr, *args, **kwargs):
return

self.rec(expr.expr, *args, **kwargs)

def map_literal(self, expr, *args: P.args, **kwargs: P.kwargs) -> None:
self.visit(expr, *args, **kwargs)

Expand Down Expand Up @@ -363,7 +363,7 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase):
class CombineMapper(CombineMapperBase[ResultT, P]):
def map_tagged_expression(self, expr, *args, **kwargs):
return self.rec(expr.expr, *args, **kwargs)

def map_reduction(self, expr, *args: P.args, **kwargs: P.kwargs):
return self.rec(expr.expr, *args, **kwargs)

Expand Down

0 comments on commit c67eea6

Please sign in to comment.