Skip to content

Commit

Permalink
Enable, fix TC rules
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Dec 19, 2024
1 parent 90f4b88 commit e1d8cc9
Show file tree
Hide file tree
Showing 44 changed files with 315 additions and 220 deletions.
14 changes: 0 additions & 14 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,6 @@
"pyrsistent": ("https://pyrsistent.readthedocs.io/en/latest/", None),
}

# Some modules need to import things just so that sphinx can resolve symbols in
# type annotations. Often, we do not want these imports (e.g. of PyOpenCL) when
# in normal use (because they would introduce unintended side effects or hard
# dependencies). This flag exists so that these imports only occur during doc
# build. Since sphinx appears to resolve type hints lexically (as it should),
# this needs to be cross-module (since, e.g. an inherited arraycontext
# docstring can be read by sphinx when building meshmode, a dependent package),
# this needs a setting of the same name across all packages involved, that's
# why this name is as global-sounding as it is.
import sys


sys._BUILDING_SPHINX_DOCS = True

nitpicky = True

nitpick_ignore_regex = [
Expand Down
12 changes: 9 additions & 3 deletions loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
from functools import reduce
from typing import TYPE_CHECKING

import numpy as np

Expand All @@ -41,7 +41,6 @@
WriteRaceConditionWarning,
warn_with_kernel,
)
from loopy.kernel import LoopKernel
from loopy.kernel.array import (
ArrayBase,
FixedStrideArrayDimTag,
Expand Down Expand Up @@ -73,6 +72,14 @@
from loopy.typing import not_none


if TYPE_CHECKING:
from collections.abc import Mapping, Sequence

from pymbolic.typing import Expression

from loopy.kernel import LoopKernel


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -216,7 +223,6 @@ def check_separated_array_consistency(kernel: LoopKernel) -> None:
@check_each_kernel
def check_offsets_and_dim_tags(kernel: LoopKernel) -> None:
from pymbolic.primitives import ExpressionNode, Variable
from pymbolic.typing import Expression

from loopy.symbolic import DependencyMapper

Expand Down
22 changes: 7 additions & 15 deletions loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"""

import logging
import sys
from dataclasses import dataclass, replace
from typing import (
TYPE_CHECKING,
Expand All @@ -35,10 +34,6 @@

import immutables

from loopy.codegen.result import CodeGenerationResult
from loopy.library.reduction import ReductionOpFunction
from loopy.translation_unit import CallablesTable, TranslationUnit


logger = logging.getLogger(__name__)

Expand All @@ -51,24 +46,21 @@
from pytools.persistent_dict import WriteOncePersistentDict

from loopy.diagnostic import LoopyError, warn
from loopy.kernel import LoopKernel
from loopy.kernel.function_interface import CallableKernel
from loopy.symbolic import CombineMapper
from loopy.target import TargetBase
from loopy.tools import LoopyKeyBuilder, caches
from loopy.types import LoopyType
from loopy.typing import Expression
from loopy.version import DATA_MODEL_VERSION


if TYPE_CHECKING:
from loopy.codegen.result import GeneratedProgram
from loopy.codegen.tools import CodegenOperationCacheManager


if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
from loopy.codegen.result import GeneratedProgram
from loopy.codegen.result import CodeGenerationResult, GeneratedProgram
from loopy.codegen.tools import CodegenOperationCacheManager
from loopy.kernel import LoopKernel
from loopy.library.reduction import ReductionOpFunction
from loopy.target import TargetBase
from loopy.translation_unit import CallablesTable, TranslationUnit
from loopy.types import LoopyType
from loopy.typing import Expression


__doc__ = """
Expand Down
8 changes: 6 additions & 2 deletions loopy/codegen/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@
"""


from typing import TYPE_CHECKING

import islpy as isl
from islpy import dim_type

from loopy.codegen.tools import CodegenOperationCacheManager
from loopy.kernel import LoopKernel

if TYPE_CHECKING:
from loopy.codegen.tools import CodegenOperationCacheManager
from loopy.kernel import LoopKernel


# {{{ approximate, convex bounds check generator
Expand Down
10 changes: 4 additions & 6 deletions loopy/codegen/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
Sequence,
)

import islpy as isl


if TYPE_CHECKING:
import islpy

from loopy.codegen import CodeGenerationState


Expand All @@ -58,8 +58,6 @@ def process_preambles(preambles: Sequence[tuple[int, str]]) -> Sequence[str]:
__doc__ = """
.. currentmodule:: loopy.codegen.result
.. autoclass:: GeneratedProgram
.. autoclass:: CodeGenerationResult
.. autofunction:: merge_codegen_results
Expand Down Expand Up @@ -121,7 +119,7 @@ class CodeGenerationResult:
"""
host_program: GeneratedProgram | None
device_programs: Sequence[GeneratedProgram]
implemented_domains: Mapping[str, isl.Set]
implemented_domains: Mapping[str, islpy.Set]
host_preambles: Sequence[tuple[str, str]] = ()
device_preambles: Sequence[tuple[str, str]] = ()

Expand Down Expand Up @@ -249,7 +247,7 @@ def merge_codegen_results(
new_device_programs = []
new_device_preambles: list[tuple[str, str]] = []
dev_program_names = set()
implemented_domains: dict[str, isl.Set] = {}
implemented_domains: dict[str, islpy.Set] = {}
codegen_result = None

block_cls = codegen_state.ast_builder.ast_block_class
Expand Down
10 changes: 7 additions & 3 deletions loopy/codegen/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@

from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING

from pytools import memoize_method

from loopy.kernel import LoopKernel
from loopy.kernel.data import Iname
from loopy.kernel.instruction import InstructionBase
from loopy.schedule import (
Barrier,
BeginBlockItem,
Expand All @@ -43,6 +42,11 @@
)


if TYPE_CHECKING:
import loopy.kernel.data
from loopy.kernel.instruction import InstructionBase


__doc__ = """
.. autoclass:: KernelProxyForCodegenOperationCacheManager
Expand All @@ -58,7 +62,7 @@ class KernelProxyForCodegenOperationCacheManager:
"""
instructions: list[InstructionBase]
linearization: list[ScheduleItem]
inames: dict[str, Iname]
inames: dict[str, loopy.kernel.data.Iname]

@cached_property
def id_to_insn(self):
Expand Down
10 changes: 7 additions & 3 deletions loopy/frontend/fortran/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,21 @@
"""

import re
from collections.abc import Mapping
from sys import intern
from typing import ClassVar
from typing import TYPE_CHECKING, ClassVar

import numpy as np

import pytools.lex
from pymbolic.parser import Parser as ExpressionParserBase

from loopy.frontend.fortran.diagnostic import TranslationError
from loopy.symbolic import LexTable


if TYPE_CHECKING:
from collections.abc import Mapping

from loopy.symbolic import LexTable


_less_than = intern("less_than")
Expand Down
15 changes: 8 additions & 7 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import islpy # to help out Sphinx
import islpy as isl
from islpy import dim_type
from pymbolic import ArithmeticExpression
from pytools import (
UniqueNameGenerator,
generate_unique_names,
Expand All @@ -62,6 +61,7 @@
)
from pytools.tag import Tag, Taggable

import loopy.codegen
import loopy.kernel.data # to help out Sphinx
from loopy.diagnostic import CannotBranchDomainTree, LoopyError, StaticValueFindingError
from loopy.kernel.data import (
Expand All @@ -73,18 +73,19 @@
_ArraySeparationInfo,
filter_iname_tags_by_type,
)
from loopy.kernel.instruction import InstructionBase
from loopy.options import Options
from loopy.schedule import ScheduleItem
from loopy.target import TargetBase
from loopy.tools import update_persistent_hash
from loopy.types import LoopyType, NumpyType
from loopy.typing import Expression, InameStr


if TYPE_CHECKING:
import loopy.codegen # to help out Sphinx
from pymbolic import ArithmeticExpression

from loopy.kernel.function_interface import InKernelCallable
from loopy.kernel.instruction import InstructionBase
from loopy.options import Options
from loopy.schedule import ScheduleItem
from loopy.target import TargetBase
from loopy.typing import Expression, InameStr


# {{{ loop kernel object
Expand Down
12 changes: 4 additions & 8 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"""

import re
import sys
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Expand All @@ -41,7 +40,6 @@
import numpy as np # noqa
from typing_extensions import Self, TypeAlias

from pymbolic import ArithmeticExpression
from pymbolic.primitives import is_arithmetic_expression
from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable
Expand All @@ -53,15 +51,13 @@


if TYPE_CHECKING:
from pymbolic import ArithmeticExpression

from loopy.codegen import VectorizationInfo
from loopy.kernel import LoopKernel
from loopy.kernel.data import ArrayArg, TemporaryVariable
from loopy.target import TargetBase

if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
from loopy.target import TargetBase


T = TypeVar("T")


Expand Down Expand Up @@ -629,7 +625,7 @@ def _parse_shape_or_strides(
x_tup: tuple[Expression | str, ...] = x_parsed
else:
assert x_parsed is not auto
x_tup = (cast(Expression, x_parsed),)
x_tup = (cast("Expression", x_parsed),)

def parse_arith(x: Expression | str) -> ArithmeticExpression:
if isinstance(x, str):
Expand Down Expand Up @@ -1296,7 +1292,7 @@ def eval_expr_assert_integer_constant(i, expr) -> int:

index = tuple(remaining_index)
# only arguments (not temporaries) may be sep-tagged
ary = cast(ArrayArg,
ary = cast("ArrayArg",
kernel.arg_dict[ary._separation_info.subarray_names[tuple(sep_index)]])

# }}}
Expand Down
20 changes: 13 additions & 7 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from enum import IntEnum
from sys import intern
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Sequence,
Expand All @@ -40,9 +41,7 @@

import numpy # FIXME: imported as numpy to allow sphinx to resolve things
import numpy as np
from immutables import Map

from pymbolic import ArithmeticExpression, Variable
from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable, UniqueTag as UniqueTagBase

Expand All @@ -61,10 +60,17 @@
VarAtomicity,
make_assignment,
)
from loopy.types import LoopyType, ToLoopyTypeConvertible
from loopy.typing import Expression, ShapeType, auto


if TYPE_CHECKING:
from immutables import Map

from pymbolic import ArithmeticExpression, Variable

from loopy.types import LoopyType, ToLoopyTypeConvertible


__doc__ = """
.. autofunction:: filter_iname_tags_by_type
Expand Down Expand Up @@ -110,7 +116,7 @@ def _names_from_expr(expr: Expression | str | None) -> frozenset[str]:
if isinstance(expr, str):
return frozenset({expr})
elif isinstance(expr, ExpressionNode):
return frozenset(cast(Variable, v).name for v in dep_mapper(expr))
return frozenset(cast("Variable", v).name for v in dep_mapper(expr))
elif expr is None:
return frozenset()
elif isinstance(expr, Number):
Expand Down Expand Up @@ -435,7 +441,7 @@ class _ArraySeparationInfo:


class ArrayArg(ArrayBase, KernelArgument):
__doc__ = cast(str, ArrayBase.__doc__) + (
__doc__ = cast("str", ArrayBase.__doc__) + (
"""
.. attribute:: address_space
Expand Down Expand Up @@ -637,7 +643,7 @@ def get_arg_decl(self, ast_builder):
# {{{ temporary variable

class TemporaryVariable(ArrayBase):
__doc__ = cast(str, ArrayBase.__doc__) + """
__doc__ = cast("str", ArrayBase.__doc__) + """
.. autoattribute:: storage_shape
.. autoattribute:: base_indices
.. autoattribute:: address_space
Expand Down Expand Up @@ -814,7 +820,7 @@ def nbytes(self) -> Expression:
raise ValueError("shape is None")
if self.shape is auto:
raise ValueError("shape is auto")
shape = cast(Tuple[ArithmeticExpression], self.shape)
shape = cast("Tuple[ArithmeticExpression]", self.shape)

if self.dtype is None:
raise ValueError("data type is indeterminate")
Expand Down
Loading

0 comments on commit e1d8cc9

Please sign in to comment.