Skip to content

Commit

Permalink
Type infer_arg_descr
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Aug 6, 2024
1 parent 0dc0c8f commit 5c84e95
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions loopy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@
auto,
filter_iname_tags_by_type,
)
from loopy.kernel.function_interface import CallableKernel, ScalarCallable
from loopy.kernel.function_interface import (
ArgDescriptor,
CallableKernel,
ScalarCallable,
)

# from loopy.transform.iname import remove_any_newly_unused_inames
from loopy.kernel.instruction import (
Expand Down Expand Up @@ -655,7 +659,7 @@ def traverse_to_infer_arg_descr(kernel, callables_table):
return descr_inferred_kernel, arg_descr_inf_mapper.clbl_inf_ctx


def infer_arg_descr(program):
def infer_arg_descr(t_unit: TranslationUnit) -> TranslationUnit:
"""
Returns a copy of *program* with the
:attr:`loopy.InKernelCallable.arg_id_to_descr` inferred for all the
Expand All @@ -666,12 +670,12 @@ def infer_arg_descr(program):
from loopy.kernel.function_interface import ArrayArgDescriptor, ValueArgDescriptor
from loopy.translation_unit import make_clbl_inf_ctx, resolve_callables

program = resolve_callables(program)
t_unit = resolve_callables(t_unit)

clbl_inf_ctx = make_clbl_inf_ctx(program.callables_table,
program.entrypoints)
clbl_inf_ctx = make_clbl_inf_ctx(t_unit.callables_table,
t_unit.entrypoints)

for e in program.entrypoints:
for e in t_unit.entrypoints:
def _tuple_or_none(s):
if isinstance(s, tuple):
return s
Expand All @@ -680,8 +684,8 @@ def _tuple_or_none(s):
else:
return s,

arg_id_to_descr = {}
for arg in program[e].args:
arg_id_to_descr: dict[str, ArgDescriptor] = {}
for arg in t_unit[e].args:
if isinstance(arg, ArrayBase):
if arg.shape not in (None, auto):
arg_id_to_descr[arg.name] = ArrayArgDescriptor(
Expand All @@ -691,12 +695,12 @@ def _tuple_or_none(s):
arg_id_to_descr[arg.name] = ValueArgDescriptor()
else:
raise NotImplementedError()
new_callable, clbl_inf_ctx = program.callables_table[e].with_descrs(
new_callable, clbl_inf_ctx = t_unit.callables_table[e].with_descrs(
arg_id_to_descr, clbl_inf_ctx)
clbl_inf_ctx, new_name = clbl_inf_ctx.with_callable(e, new_callable,
is_entrypoint=True)

return clbl_inf_ctx.finish_program(program)
return clbl_inf_ctx.finish_program(t_unit)

# }}}

Expand Down

0 comments on commit 5c84e95

Please sign in to comment.