Skip to content

Commit

Permalink
ExpressionT
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Jan 13, 2025
1 parent c67eea6 commit 9dce573
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
from loopy.kernel.array import ArrayBase
from loopy.kernel.instruction import InstructionBase
from loopy.types import LoopyType
from loopy.typing import Expression, ExpressionT, auto
from loopy.typing import Expression, auto


__doc__ = """
Expand Down Expand Up @@ -1259,8 +1259,8 @@ def map_floor_div(self, expr):
# {{{ _get_lid_and_gid_strides

def _get_lid_and_gid_strides(
knl: LoopKernel, array: ArrayBase, index: tuple[ExpressionT, ...]
) -> tuple[Mapping[int, ExpressionT], Mapping[int, ExpressionT]]:
knl: LoopKernel, array: ArrayBase, index: tuple[Expression, ...]
) -> tuple[Mapping[int, Expression], Mapping[int, Expression]]:
# find all local and global index tags and corresponding inames
from loopy.symbolic import get_dependencies
my_inames = get_dependencies(index) & knl.all_inames()
Expand Down Expand Up @@ -1377,7 +1377,7 @@ def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap:
def count_var_access(self,
dtype: LoopyType,
name: str,
index: ExpressionT | None,
index: Expression | None,
tags: frozenset[Tag],
var_tags: frozenset[Tag] = frozenset()
) -> ToCountPolynomialMap:
Expand Down
6 changes: 3 additions & 3 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
LoopyError,
UnableToDetermineAccessRangeError,
)
from loopy.typing import Expression, ExpressionT, not_none
from loopy.typing import Expression, not_none


if TYPE_CHECKING:
Expand Down Expand Up @@ -159,7 +159,7 @@

# {{{ mappers with support for loopy-specific primitives

class IdentityMapperMixin(Mapper[ExpressionT, P]):
class IdentityMapperMixin(Mapper[Expression, P]):
def map_tagged_expression(self, expr: TaggedExpression, *args, **kwargs):
new_expr = self.rec(expr.expr, *args, **kwargs)
return TaggedExpression(expr.tags, new_expr)
Expand Down Expand Up @@ -799,7 +799,7 @@ class TaggedExpression(LoopyExpressionBase):
init_arg_names = ("tags", "expr")

tags: frozenset[Tag]
expr: ExpressionT
expr: Expression


@p.expr_dataclass(init=False)
Expand Down

0 comments on commit 9dce573

Please sign in to comment.