Skip to content

Commit

Permalink
Enable, fix UP rules
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Dec 19, 2024
1 parent 79b9a0a commit 5005f91
Show file tree
Hide file tree
Showing 38 changed files with 342 additions and 389 deletions.
28 changes: 14 additions & 14 deletions loopy/auto_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -80,26 +80,26 @@ def fill_rand(ary):
@dataclass
class TestArgInfo:
name: str
ref_array: "cla.Array"
ref_storage_array: "cla.Array"
ref_array: cla.Array
ref_storage_array: cla.Array

ref_pre_run_array: "cla.Array"
ref_pre_run_storage_array: "cla.Array"
ref_pre_run_array: cla.Array
ref_pre_run_storage_array: cla.Array

ref_shape: Tuple[int, ...]
ref_strides: Tuple[int, ...]
ref_shape: tuple[int, ...]
ref_strides: tuple[int, ...]
ref_alloc_size: int
ref_numpy_strides: Tuple[int, ...]
ref_numpy_strides: tuple[int, ...]
needs_checking: bool

# The attributes below are being modified in make_args, hence this dataclass
# cannot be frozen.
test_storage_array: Optional["cla.Array"] = None
test_array: Optional["cla.Array"] = None
test_shape: Optional[Tuple[int, ...]] = None
test_strides: Optional[Tuple[int, ...]] = None
test_numpy_strides: Optional[Tuple[int, ...]] = None
test_alloc_size: Optional[Tuple[int, ...]] = None
test_storage_array: cla.Array | None = None
test_array: cla.Array | None = None
test_shape: tuple[int, ...] | None = None
test_strides: tuple[int, ...] | None = None
test_numpy_strides: tuple[int, ...] | None = None
test_alloc_size: tuple[int, ...] | None = None


# {{{ "reference" arguments
Expand Down
7 changes: 3 additions & 4 deletions loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from collections import defaultdict
from collections.abc import Mapping, Sequence
from functools import reduce
from typing import List, Optional, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -225,7 +224,7 @@ def check_offsets_and_dim_tags(kernel: LoopKernel) -> None:
dep_mapper: DependencyMapper[[]] = DependencyMapper()

def ensure_depends_only_on_arguments(
what: str, expr: Union[str, Expression]) -> None:
what: str, expr: str | Expression) -> None:
if isinstance(expr, str):
expr = Variable(expr)

Expand All @@ -252,7 +251,7 @@ def ensure_depends_only_on_arguments(
raise LoopyError(f"invalid value of offset for '{arg.name}'")

if arg.dim_tags is None:
new_dim_tags: Optional[Tuple[ArrayDimImplementationTag, ...]] = \
new_dim_tags: tuple[ArrayDimImplementationTag, ...] | None = \
arg.dim_tags
else:
new_dim_tags = ()
Expand Down Expand Up @@ -1327,7 +1326,7 @@ def check_for_nested_base_storage(kernel: LoopKernel) -> None:
# must run after preprocessing has created variables for base_storage

from loopy.kernel.data import ArrayArg
arrays: List[ArrayBase] = [
arrays: list[ArrayBase] = [
arg for arg in kernel.args if isinstance(arg, ArrayArg)
]
arrays = arrays + list(kernel.temporary_variables.values())
Expand Down
41 changes: 18 additions & 23 deletions loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,8 @@
from typing import (
TYPE_CHECKING,
Any,
FrozenSet,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)

from immutables import Map
Expand Down Expand Up @@ -137,8 +132,8 @@ class SeenFunction:
"""
name: str
c_name: str
arg_dtypes: Tuple[LoopyType, ...]
result_dtypes: Tuple[LoopyType, ...]
arg_dtypes: tuple[LoopyType, ...]
result_dtypes: tuple[LoopyType, ...]


@dataclass(frozen=True)
Expand Down Expand Up @@ -203,12 +198,12 @@ class CodeGenerationState:
kernel: LoopKernel
target: TargetBase
implemented_domain: isl.Set
implemented_predicates: FrozenSet[Union[str, Expression]]
implemented_predicates: frozenset[str | Expression]

# /!\ mutable
seen_dtypes: Set[LoopyType]
seen_functions: Set[SeenFunction]
seen_atomic_dtypes: Set[LoopyType]
seen_dtypes: set[LoopyType]
seen_functions: set[SeenFunction]
seen_atomic_dtypes: set[LoopyType]

var_subst_map: Map[str, Expression]
allow_complex: bool
Expand All @@ -218,8 +213,8 @@ class CodeGenerationState:
is_generating_device_code: bool
gen_program_name: str
schedule_index_end: int
codegen_cachemanager: "CodegenOperationCacheManager"
vectorization_info: Optional[VectorizationInfo] = None
codegen_cachemanager: CodegenOperationCacheManager
vectorization_info: VectorizationInfo | None = None

def __post_init__(self):
# FIXME: If this doesn't bomb during testing, we can get rid of target.
Expand All @@ -230,15 +225,15 @@ def __post_init__(self):

# {{{ copy helpers

def copy(self, **kwargs: Any) -> "CodeGenerationState":
def copy(self, **kwargs: Any) -> CodeGenerationState:
return replace(self, **kwargs)

def copy_and_assign(
self, name: str, value: Expression) -> "CodeGenerationState":
self, name: str, value: Expression) -> CodeGenerationState:
"""Make a copy of self with variable *name* fixed to *value*."""
return self.copy(var_subst_map=self.var_subst_map.set(name, value))

def copy_and_assign_many(self, assignments) -> "CodeGenerationState":
def copy_and_assign_many(self, assignments) -> CodeGenerationState:
"""Make a copy of self with *assignments* included."""

return self.copy(var_subst_map=self.var_subst_map.update(assignments))
Expand Down Expand Up @@ -371,9 +366,9 @@ def map_constant(self, expr):
@dataclass(frozen=True)
class PreambleInfo:
kernel: LoopKernel
seen_dtypes: Set[LoopyType]
seen_functions: Set[SeenFunction]
seen_atomic_dtypes: Set[LoopyType]
seen_dtypes: set[LoopyType]
seen_functions: set[SeenFunction]
seen_atomic_dtypes: set[LoopyType]

# FIXME: This makes all the above redundant. It probably shouldn't be here.
codegen_state: CodeGenerationState
Expand Down Expand Up @@ -546,10 +541,10 @@ class TranslationUnitCodeGenerationResult:
.. automethod:: all_code
"""
host_programs: Mapping[str, "GeneratedProgram"]
device_programs: Sequence["GeneratedProgram"]
host_preambles: Sequence[Tuple[int, str]] = ()
device_preambles: Sequence[Tuple[int, str]] = ()
host_programs: Mapping[str, GeneratedProgram]
device_programs: Sequence[GeneratedProgram]
host_preambles: Sequence[tuple[int, str]] = ()
device_preambles: Sequence[tuple[int, str]] = ()

def host_code(self):
from loopy.codegen.result import process_preambles
Expand Down
4 changes: 1 addition & 3 deletions loopy/codegen/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
"""


from typing import FrozenSet

import islpy as isl
from islpy import dim_type

Expand Down Expand Up @@ -65,7 +63,7 @@ def get_approximate_convex_bounds_checks(domain, check_inames,

def get_usable_inames_for_conditional(
kernel: LoopKernel, sched_index: int,
op_cache_manager: CodegenOperationCacheManager) -> FrozenSet[str]:
op_cache_manager: CodegenOperationCacheManager) -> frozenset[str]:
active_inames = op_cache_manager.active_inames[sched_index]
crosses_barrier = op_cache_manager.has_barrier_within[sched_index]

Expand Down
29 changes: 12 additions & 17 deletions loopy/codegen/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,8 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)

import islpy as isl
Expand All @@ -43,7 +38,7 @@
from loopy.codegen import CodeGenerationState


def process_preambles(preambles: Sequence[Tuple[int, str]]) -> Sequence[str]:
def process_preambles(preambles: Sequence[tuple[int, str]]) -> Sequence[str]:
seen_preamble_tags = set()
dedup_preambles = []

Expand Down Expand Up @@ -97,9 +92,9 @@ class GeneratedProgram:
name: str
is_device_program: bool
ast: Any
body_ast: Optional[Any] = None
body_ast: Any | None = None

def copy(self, **kwargs: Any) -> "GeneratedProgram":
def copy(self, **kwargs: Any) -> GeneratedProgram:
return replace(self, **kwargs)


Expand All @@ -124,13 +119,13 @@ class CodeGenerationResult:
.. automethod:: device_code
.. automethod:: all_code
"""
host_program: Optional[GeneratedProgram]
host_program: GeneratedProgram | None
device_programs: Sequence[GeneratedProgram]
implemented_domains: Mapping[str, isl.Set]
host_preambles: Sequence[Tuple[str, str]] = ()
device_preambles: Sequence[Tuple[str, str]] = ()
host_preambles: Sequence[tuple[str, str]] = ()
device_preambles: Sequence[tuple[str, str]] = ()

def copy(self, **kwargs: Any) -> "CodeGenerationResult":
def copy(self, **kwargs: Any) -> CodeGenerationResult:
return replace(self, **kwargs)

@staticmethod
Expand Down Expand Up @@ -188,7 +183,7 @@ def all_code(self):
+ str(self.host_program.ast))

def current_program(
self, codegen_state: "CodeGenerationState") -> GeneratedProgram:
self, codegen_state: CodeGenerationState) -> GeneratedProgram:
if codegen_state.is_generating_device_code:
if self.device_programs:
result = self.device_programs[-1]
Expand Down Expand Up @@ -234,8 +229,8 @@ def with_new_ast(self, codegen_state, new_ast):
# {{{ support code for AST merging

def merge_codegen_results(
codegen_state: "CodeGenerationState",
elements: Sequence[Union[CodeGenerationResult, Any]], collapse=True
codegen_state: CodeGenerationState,
elements: Sequence[CodeGenerationResult | Any], collapse=True
) -> CodeGenerationResult:
elements = [el for el in elements if el is not None]

Expand All @@ -252,9 +247,9 @@ def merge_codegen_results(

ast_els = []
new_device_programs = []
new_device_preambles: List[Tuple[str, str]] = []
new_device_preambles: list[tuple[str, str]] = []
dev_program_names = set()
implemented_domains: Dict[str, isl.Set] = {}
implemented_domains: dict[str, isl.Set] = {}
codegen_result = None

block_cls = codegen_state.ast_builder.ast_block_class
Expand Down
9 changes: 4 additions & 5 deletions loopy/codegen/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from dataclasses import dataclass
from functools import cached_property
from typing import Dict, FrozenSet, List

from pytools import memoize_method

Expand Down Expand Up @@ -57,9 +56,9 @@ class KernelProxyForCodegenOperationCacheManager:
Proxy to :class:`loopy.LoopKernel` to be used by
:class:`CodegenOperationCacheManager`.
"""
instructions: List[InstructionBase]
linearization: List[ScheduleItem]
inames: Dict[str, Iname]
instructions: list[InstructionBase]
linearization: list[ScheduleItem]
inames: dict[str, Iname]

@cached_property
def id_to_insn(self):
Expand Down Expand Up @@ -209,7 +208,7 @@ def get_insn_ids_for_block_at(self, sched_index):

@memoize_method
def get_concurrent_inames_in_a_callkernel(
self, callkernel_index: int) -> FrozenSet[str]:
self, callkernel_index: int) -> frozenset[str]:
"""
Returns a :class:`frozenset` of concurrent inames in a callkernel
Expand Down
Loading

0 comments on commit 5005f91

Please sign in to comment.