Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pymbolic typing #868

Merged
merged 10 commits into from
Nov 7, 2024
11 changes: 11 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,15 @@
# As of 2022-06-22, it doesn't look like there's sphinx documentation
# available.
["py:class", r"immutables\.(.+)"],

# Reference not found from "<unknown>"? I'm not even sure where to look.
["py:class", r"Expression"],
]

autodoc_type_aliases = {
"ToLoopyTypeConvertible": "ToLoopyTypeConvertible",
"ExpressionT": "ExpressionT",
"InameStr": "InameStr",
"ShapeType": "ShapeType",
"StridesType": "StridesType",
}
5 changes: 4 additions & 1 deletion loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
)
from loopy.translation_unit import TranslationUnit, for_each_kernel, make_program
from loopy.type_inference import infer_unknown_types
from loopy.types import to_loopy_type
from loopy.types import LoopyType, NumpyType, ToLoopyTypeConvertible, to_loopy_type
from loopy.typing import auto
from loopy.version import MOST_RECENT_LANGUAGE_VERSION, VERSION

Expand Down Expand Up @@ -248,12 +248,14 @@
"LinearSubscript",
"LoopKernel",
"LoopyError",
"LoopyType",
"LoopyWarning",
"MemAccess",
"MemoryOrdering",
"MemoryScope",
"MultiAssignmentBase",
"NoOpInstruction",
"NumpyType",
"Op",
"OpenCLTarget",
"Optional",
Expand All @@ -270,6 +272,7 @@
"TemporaryVariable",
"ToCountMap",
"ToCountPolynomialMap",
"ToLoopyTypeConvertible",
"TranslationUnit",
"TypeCast",
"UniqueName",
Expand Down
8 changes: 7 additions & 1 deletion loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import islpy as isl
from islpy import dim_type
from pymbolic.primitives import Variable, is_arithmetic_expression
from pytools import memoize_method

from loopy.diagnostic import (
Expand Down Expand Up @@ -1669,6 +1670,8 @@ def _are_sub_array_refs_equivalent(
if len(sar1.swept_inames) != len(sar2.swept_inames):
return False

assert isinstance(sar1.subscript.aggregate, Variable)
assert isinstance(sar2.subscript.aggregate, Variable)
if sar1.subscript.aggregate.name != sar2.subscript.aggregate.name:
return False

Expand All @@ -1692,7 +1695,10 @@ def _are_sub_array_refs_equivalent(

for idx1, idx2 in zip(sar1.subscript.index_tuple,
sar2.subscript.index_tuple):
if simplify_via_aff(subst_mapper(idx1) - idx2) != 0:
subst_idx1 = subst_mapper(idx1)
assert is_arithmetic_expression(subst_idx1)
assert is_arithmetic_expression(idx2)
if simplify_via_aff(subst_idx1 - idx2) != 0:
return False
return True

Expand Down
16 changes: 15 additions & 1 deletion loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from immutables import Map

from loopy.codegen.result import CodeGenerationResult
from loopy.library.reduction import ReductionOpFunction
from loopy.translation_unit import CallablesTable, TranslationUnit


Expand Down Expand Up @@ -86,6 +87,12 @@
.. automodule:: loopy.codegen.result

.. automodule:: loopy.codegen.tools

References
^^^^^^^^^^
.. class:: Expression

See :class:`pymbolic.Expression`.
"""


Expand Down Expand Up @@ -661,8 +668,15 @@ def generate_code_v2(t_unit: TranslationUnit) -> CodeGenerationResult:
ast=t_unit.target.get_device_ast_builder().ast_module.Collection(
callee_fdecls+[device_programs[0].ast]))] +
device_programs[1:])

def not_reduction_op(name: str | ReductionOpFunction) -> str:
assert isinstance(name, str)
return name

cgr = TranslationUnitCodeGenerationResult(
host_programs=host_programs,
host_programs={
not_reduction_op(name): prg
for name, prg in host_programs.items()},
device_programs=device_programs,
device_preambles=device_preambles)

Expand Down
3 changes: 2 additions & 1 deletion loopy/frontend/fortran/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class FortranExpressionParser(ExpressionParserBase):
(_not, pytools.lex.RE(r"\.not\.", re.I)),
(_and, pytools.lex.RE(r"\.and\.", re.I)),
(_or, pytools.lex.RE(r"\.or\.", re.I)),
] + ExpressionParserBase.lex_table
*ExpressionParserBase.lex_table,
]

def __init__(self, tree_walker):
self.tree_walker = tree_walker
Expand Down
5 changes: 4 additions & 1 deletion loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

import islpy as isl
from islpy import dim_type
from pymbolic import ArithmeticExpressionT
from pytools import (
UniqueNameGenerator,
generate_unique_names,
Expand Down Expand Up @@ -1042,7 +1043,9 @@ def get_grid_size_upper_bounds(self, callables_table, ignore_auto=False,
def get_grid_size_upper_bounds_as_exprs(
self, callables_table,
ignore_auto=False, return_dict=False
) -> Tuple[Tuple[ExpressionT, ...], Tuple[ExpressionT, ...]]:
) -> Tuple[
Tuple[ArithmeticExpressionT, ...],
Tuple[ArithmeticExpressionT, ...]]:
"""Return a tuple (global_size, local_size) containing a grid that
could accommodate execution of *all* instructions in the kernel.

Expand Down
55 changes: 32 additions & 23 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from loopy.symbolic import flatten


__copyright__ = "Copyright (C) 2012 Andreas Kloeckner"

Expand Down Expand Up @@ -47,13 +45,15 @@
import numpy as np # noqa
from typing_extensions import TypeAlias

from pymbolic import ArithmeticExpressionT
from pymbolic.primitives import is_arithmetic_expression
from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable

from loopy.diagnostic import LoopyError
from loopy.tools import is_integer
from loopy.symbolic import flatten
from loopy.types import LoopyType
from loopy.typing import ExpressionT, ShapeType, auto
from loopy.typing import ExpressionT, ShapeType, auto, is_integer


if TYPE_CHECKING:
Expand Down Expand Up @@ -625,17 +625,35 @@ def _parse_shape_or_strides(
if x is auto:
return auto

if isinstance(x, str):
x = parse(x)
if not isinstance(x, str):
x_parsed = x
else:
x_parsed = parse(x)

if isinstance(x, list):
if isinstance(x_parsed, list):
raise ValueError("shape can't be a list")

if not isinstance(x, tuple):
assert x is not auto
x = (x,)
if isinstance(x_parsed, tuple):
x_tup: tuple[ExpressionT | str, ...] = x_parsed
else:
assert x_parsed is not auto
x_tup = (cast(ExpressionT, x_parsed),)

def parse_arith(x: ExpressionT | str) -> ArithmeticExpressionT:
if isinstance(x, str):
res = parse(x)
else:
res = x

# The Fortran parser may do this, but this is (deliberately) outside
# the behavior allowed by types, because the hope is to phase it out.
if x is None:
return x

assert is_arithmetic_expression(res)
return res

return tuple(parse(xi) if isinstance(xi, str) else xi for xi in x)
return tuple(parse_arith(xi) for xi in x_tup)


class ArrayBase(ImmutableRecord, Taggable):
Expand Down Expand Up @@ -1026,16 +1044,6 @@ def __str__(self):
def __repr__(self):
return "<%s>" % self.__str__()

def update_persistent_hash_for_shape(self, key_hash, key_builder, shape):
if isinstance(shape, tuple):
for shape_i in shape:
if shape_i is None:
key_builder.rec(key_hash, shape_i)
else:
key_builder.update_for_pymbolic_expression(key_hash, shape_i)
else:
key_builder.rec(key_hash, shape)

def update_persistent_hash(self, key_hash, key_builder):
"""Custom hash computation function for use with
:class:`pytools.persistent_dict.PersistentDict`.
Expand All @@ -1044,7 +1052,7 @@ def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, type(self).__name__)
key_builder.rec(key_hash, self.name)
key_builder.rec(key_hash, self.dtype)
self.update_persistent_hash_for_shape(key_hash, key_builder, self.shape)
key_builder.rec(key_hash, self.shape)
key_builder.rec(key_hash, self.dim_tags)
key_builder.rec(key_hash, self.offset)
key_builder.rec(key_hash, self.dim_names)
Expand Down Expand Up @@ -1232,11 +1240,12 @@ def get_access_info(kernel: "LoopKernel",

import loopy as lp

def eval_expr_assert_integer_constant(i, expr):
def eval_expr_assert_integer_constant(i, expr) -> int:
from pymbolic.mapper.evaluator import UnknownVariableError
try:
result = eval_expr(expr)
except UnknownVariableError as e:
assert ary.dim_tags is not None
raise LoopyError("When trying to index the array '%s' along axis "
"%d (tagged '%s'), the index was not a compile-time "
"constant (but it has to be in order for code to be "
Expand Down
15 changes: 11 additions & 4 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import numpy as np
from immutables import Map

from pymbolic import ArithmeticExpressionT
from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable, UniqueTag as UniqueTagBase

Expand Down Expand Up @@ -87,6 +88,13 @@
.. autoclass:: UnrollTag

.. autoclass:: Iname

References
^^^^^^^^^^

.. class:: ToLoopyTypeConvertible

See :class:`loopy.ToLoopyTypeConvertible`.
"""

# This docstring is included in ref_internals. Do not include parts of the public
Expand Down Expand Up @@ -809,7 +817,7 @@ def nbytes(self) -> ExpressionT:
raise ValueError("shape is None")
if self.shape is auto:
raise ValueError("shape is auto")
shape = cast(Tuple[ExpressionT], self.shape)
shape = cast(Tuple[ArithmeticExpressionT], self.shape)

if self.dtype is None:
raise ValueError("data type is indeterminate")
Expand Down Expand Up @@ -853,8 +861,7 @@ def update_persistent_hash(self, key_hash, key_builder):
"""

super().update_persistent_hash(key_hash, key_builder)
self.update_persistent_hash_for_shape(key_hash, key_builder,
self.storage_shape)
key_builder.rec(key_hash, self.storage_shape)
key_builder.rec(key_hash, self.base_indices)
key_builder.rec(key_hash, self.address_space)
key_builder.rec(key_hash, self.base_storage)
Expand Down Expand Up @@ -899,7 +906,7 @@ def copy(self, **kwargs: Any) -> SubstitutionRule:
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, self.name)
key_builder.rec(key_hash, self.arguments)
key_builder.update_for_pymbolic_expression(key_hash, self.expression)
key_builder.rec(key_hash, self.expression)


# }}}
Expand Down
2 changes: 1 addition & 1 deletion loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def depends_on(self):
return frozenset(var.name for var in result)

def update_persistent_hash(self, key_hash, key_builder):
key_builder.update_for_pymbolic_expression(key_hash, self.shape)
key_builder.rec(key_hash, self.shape)
key_builder.rec(key_hash, self.address_space)
key_builder.rec(key_hash, self.dim_tags)

Expand Down
Loading
Loading