Skip to content

Commit

Permalink
Merge branch 'main' into better_loop_around_nest_map
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Aug 16, 2021
2 parents 43edda9 + 7c90b7e commit e8b03c8
Show file tree
Hide file tree
Showing 12 changed files with 450 additions and 41 deletions.
1 change: 1 addition & 0 deletions doc/ref_other.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ following always works::

.. autofunction:: show_dependency_graph

.. autofunction:: t_unit_to_python
6 changes: 4 additions & 2 deletions loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@
get_global_barrier_order,
find_most_recent_global_barrier,
get_subkernels,
get_subkernel_to_insn_id_map)
get_subkernel_to_insn_id_map,
)
from loopy.types import to_loopy_type
from loopy.kernel.creation import make_kernel, UniqueName, make_function
from loopy.library.reduction import register_reduction_parser
Expand Down Expand Up @@ -152,7 +153,7 @@
from loopy.target.ispc import ISPCTarget
from loopy.target.numba import NumbaTarget, NumbaCudaTarget

from loopy.tools import Optional
from loopy.tools import Optional, t_unit_to_python


__all__ = [
Expand Down Expand Up @@ -255,6 +256,7 @@
"find_most_recent_global_barrier",
"get_subkernels",
"get_subkernel_to_insn_id_map",
"t_unit_to_python",

"to_loopy_type",

Expand Down
2 changes: 1 addition & 1 deletion loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def is_type_specialized(self):

class ScalarCallable(InKernelCallable):
"""
An abstract interface the to a scalar callable encountered in a kernel.
An abstract interface to a scalar callable encountered in a kernel.
.. attribute:: name_in_target
Expand Down
20 changes: 20 additions & 0 deletions loopy/library/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,24 @@ def __call__(self, dtype, operand1, operand2, callables_table, target):
return operand1 * operand2, callables_table


class AnyReductionOperation(ScalarReductionOperation):
    """
    Reduction that combines its operands with a logical *or*.

    The neutral element is *False*.
    """

    def neutral_element(self, dtype, callables_table, target):
        # *callables_table* is handed back unchanged.
        neutral = False
        return neutral, callables_table

    def __call__(self, dtype, operand1, operand2, callables_table, target):
        from pymbolic.primitives import LogicalOr

        combined = LogicalOr((operand1, operand2))
        return combined, callables_table


class AllReductionOperation(ScalarReductionOperation):
    """
    Reduction that combines its operands with a logical *and*.

    The neutral element is *True*.
    """

    def neutral_element(self, dtype, callables_table, target):
        # *callables_table* is handed back unchanged.
        neutral = True
        return neutral, callables_table

    def __call__(self, dtype, operand1, operand2, callables_table, target):
        from pymbolic.primitives import LogicalAnd

        combined = LogicalAnd((operand1, operand2))
        return combined, callables_table


def get_le_neutral(dtype):
"""Return a number y that satisfies (x <= y) for all x."""

Expand Down Expand Up @@ -489,6 +507,8 @@ class ArgMinReductionOperation(_ArgExtremumReductionOperation):
"product": ProductReductionOperation,
"max": MaxReductionOperation,
"min": MinReductionOperation,
"any": AnyReductionOperation,
"all": AllReductionOperation,
"argmax": ArgMaxReductionOperation,
"argmin": ArgMinReductionOperation,
"segmented(sum)": SegmentedSumReductionOperation,
Expand Down
34 changes: 34 additions & 0 deletions loopy/schedule/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,40 @@ def add_extra_args_to_schedule(kernel):
# }}}


# {{{ get_return_from_kernel_mapping

def get_return_from_kernel_mapping(kernel):
    """
    Build a lookup from the schedule index of each schedule item in
    *kernel*'s linearization to the schedule index of the
    :class:`loopy.schedule.ReturnFromKernel` that closes the sub-kernel
    active at that item.

    A :class:`loopy.schedule.ReturnFromKernel` maps to its own index, and a
    :class:`loopy.schedule.CallKernel` maps to the index of the
    :class:`loopy.schedule.ReturnFromKernel` seen after it. Items for which no
    sub-kernel is active map to *None*.

    :arg kernel: a :class:`loopy.kernel.LoopKernel` whose *linearization*
        is a :class:`list`.
    :returns: a :class:`dict` keyed by schedule index.
    :raises NotImplementedError: for a schedule item of an unhandled type.
    """
    from loopy.kernel import LoopKernel
    from loopy.schedule import (RunInstruction, EnterLoop, LeaveLoop,
                                CallKernel, ReturnFromKernel, Barrier)

    assert isinstance(kernel, LoopKernel)
    assert isinstance(kernel.linearization, list)

    # Item types that simply inherit the currently pending return index.
    passthrough_types = (RunInstruction, EnterLoop, LeaveLoop, Barrier)

    idx_to_return_idx = {}
    pending_return_idx = None

    # Walk the schedule back to front so that each ReturnFromKernel is seen
    # before the items of the sub-kernel it terminates.
    for idx, item in reversed(list(enumerate(kernel.linearization))):
        if isinstance(item, CallKernel):
            idx_to_return_idx[idx] = pending_return_idx
            # Everything before the CallKernel lies outside this sub-kernel.
            pending_return_idx = None
        elif isinstance(item, ReturnFromKernel):
            # Seeing a second return before a call would mean nesting.
            assert pending_return_idx is None
            pending_return_idx = idx
            idx_to_return_idx[idx] = pending_return_idx
        elif isinstance(item, passthrough_types):
            idx_to_return_idx[idx] = pending_return_idx
        else:
            raise NotImplementedError(type(item))

    return idx_to_return_idx

# }}}


def _pull_out_loop_nest(tree, loop_nests, inames_to_pull_out):
"""
Returns a copy of *tree* that realizes *inames_to_pull_out* as loop
Expand Down
17 changes: 14 additions & 3 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,7 @@ def map_variable(self, expr):
array = self.knl.arg_dict[name]
else:
# this is a temporary variable
# FIXME temporary variable could have global address space
return self.new_zero_poly_map()

if not isinstance(array, lp.ArrayArg):
Expand All @@ -1316,14 +1317,24 @@ def map_subscript(self, expr):
except AttributeError:
var_tags = frozenset()

is_global_temp = False
if name in self.knl.arg_dict:
array = self.knl.arg_dict[name]
elif name in self.knl.temporary_variables:
# This is a temporary, but might have global address space
from loopy.kernel.data import AddressSpace
array = self.knl.temporary_variables[name]
if array.address_space != AddressSpace.GLOBAL:
# This temporary does not have global address space
return self.rec(expr.index)
# This temporary has global address space
is_global_temp = True
else:
# this is a temporary variable
# This temporary does not have global address space
return self.rec(expr.index)

if not isinstance(array, lp.ArrayArg):
# this array is not in global memory
if (not is_global_temp) and not isinstance(array, lp.ArrayArg):
# This array is not in global memory
return self.rec(expr.index)

index_tuple = expr.index # could be tuple or scalar index
Expand Down
5 changes: 5 additions & 0 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,11 @@ def tag(self):
def __getinitargs__(self):
    """Return the ``(name, tags)`` tuple read from this object."""
    init_args = (self.name, self.tags)
    return init_args

def copy(self, *, name=None, tags=None):
    """
    Return a new :class:`TaggedVariable` built from this one.

    :arg name: replacement name; *None* keeps ``self.name``.
    :arg tags: replacement tags; *None* keeps ``self.tags``.
    """
    if name is None:
        name = self.name
    if tags is None:
        tags = self.tags
    return TaggedVariable(name, tags)

mapper_method = intern("map_tagged_variable")


Expand Down
Loading

0 comments on commit e8b03c8

Please sign in to comment.