Update the Expansion Mapper for index lambda and move the packer to pack the items in the later axes as opposed to prepending the new axes.
nkoskelo committed Jul 16, 2024
1 parent 95e74ef commit 804ed42
Showing 5 changed files with 105 additions and 67 deletions.
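The layout change named in the commit message, in isolation: the packer used to prepend the new study axes to each array's shape and now appends them. A minimal NumPy-only sketch of the two orderings (not part of the commit; shapes chosen to match examples/parameter_study.py below):

import numpy as np

instance = np.zeros((15, 5))                         # one instance of the base array
copies = [instance] * 3                              # a 3-point parameter study

old_layout = np.stack(copies, axis=0)                # prepend: study axis comes first
new_layout = np.stack(copies, axis=instance.ndim)    # append: study axis comes last

assert old_layout.shape == (3, 15, 5)
assert new_layout.shape == (15, 5, 3)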
5 changes: 4 additions & 1 deletion arraycontext/impl/pytato/__init__.py
@@ -563,7 +563,10 @@ def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray:
evt, out_dict = pt_prg(self.queue,
allocator=self.allocator,
**bound_arguments)
evt.wait()
if isinstance(evt, list):
[_evt.wait() for _evt in evt]
else:
evt.wait()
assert len(set(out_dict) & set(key_to_frozen_subary)) == 0

key_to_frozen_subary = {
5 changes: 4 additions & 1 deletion arraycontext/impl/pytato/compile.py
@@ -646,7 +646,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
# FIXME Kernels (for now) allocate tons of memory in temporaries. If we
# race too far ahead with enqueuing, there is a distinct risk of
# running out of memory. This mitigates that risk a bit, for now.
evt.wait()
if isinstance(evt, list):
[_evt.wait() for _evt in evt]
else:
evt.wait()

def to_output_template(keys, _):
name_in_program = self.output_id_to_name_in_program[keys]
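Both hunks above replace a bare evt.wait() with a branch that also accepts a list of events. A standalone sketch of the same pattern; the Event class below is a stand-in for pyopencl.Event and is not part of the commit:

class Event:
    """Stand-in for pyopencl.Event: wait() blocks until the work item finishes."""
    def __init__(self) -> None:
        self.completed = False

    def wait(self) -> None:
        self.completed = True


def wait_on(evt) -> None:
    # Mirrors the pattern in the diff: a compiled program may hand back a
    # single event or a list of events, and both cases must be waited on.
    if isinstance(evt, list):
        for _evt in evt:
            _evt.wait()
    else:
        evt.wait()


events = [Event(), Event()]
wait_on(events)        # list of events
wait_on(Event())       # a single event still works
assert all(e.completed for e in events)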
27 changes: 15 additions & 12 deletions arraycontext/parameter_study/__init__.py
@@ -115,14 +115,15 @@ def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyA
assert len(args) == np.prod(newshape)

orig_shape = args[0].shape
out = actx.np.stack(args)
outshape = tuple([newshape] + list(orig_shape))

if len(newshape) > 1:
# Reshape the object
out = out.reshape(outshape)
for i in range(len(newshape)):
out = out.with_tagged_axis(i, [study_name_tag(i, newshape[i])])
out = actx.np.stack(args, axis=args[0].ndim)
outshape = tuple(list(orig_shape) + [newshape] )

#if len(newshape) > 1:
# # Reshape the object
# out = out.reshape(outshape)

for i in range(len(orig_shape), len(outshape)):
out = out.with_tagged_axis(i, [study_name_tag(i - len(orig_shape), newshape[i-len(orig_shape)])])
return out
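The reworked pack_for_parameter_study stacks the instances along a new trailing axis (axis=args[0].ndim) and then tags only those trailing axes. A NumPy-only sketch of the packing and of what each trailing slice should hold; the axis tagging itself is arraycontext/pytato-specific and is omitted here:

import numpy as np

rng = np.random.default_rng(12345)
base_shape = (15, 5)
args = [rng.random(base_shape) for _ in range(3)]    # three study instances

packed = np.stack(args, axis=args[0].ndim)           # new trailing study axis

assert packed.shape == base_shape + (3,)
for i, arg in enumerate(args):
    # Instance i lives in slice i of the trailing axis.
    assert np.array_equal(packed[..., i], arg)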


@@ -140,6 +141,7 @@ def unpack_parameter_study(data: ArrayT,
ndim: int = len(data.axes)
out: Dict[int, List[ArrayT]] = {}

study_count = 0
for i in range(ndim):
axis_tags = data.axes[i].tags_of_type(study_name_tag)
if axis_tags:
@@ -150,11 +152,12 @@ def unpack_parameter_study(data: ArrayT,
tmp[i] = j
the_slice: Tuple[slice] = tuple(tmp)
# Needs to be a tuple of slices not list of slices.
if i in out.keys():
out[i].append(data[the_slice])
if study_count in out.keys():
out[study_count].append(data[the_slice])
else:
out[i] = [data[the_slice]]

out[study_count] = [data[the_slice]]
if study_count in out.keys():
study_count += 1
# yield data[the_slice]

return out
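unpack_parameter_study finds the axes tagged with the requested study and slices one instance out per index along that axis. A NumPy-only sketch of the slicing, with the tag lookup replaced by a hard-coded trailing axis number (an assumption for illustration):

import numpy as np

packed = np.arange(15 * 5 * 3).reshape(15, 5, 3)     # study axis is the last one
study_axis = 2                                       # found via axis tags in the real code

instances = []
for j in range(packed.shape[study_axis]):
    # Tuple of slices: pick index j on the study axis, keep every other axis whole.
    the_slice = tuple(j if ax == study_axis else slice(None)
                      for ax in range(packed.ndim))
    instances.append(packed[the_slice])

assert len(instances) == 3
assert instances[0].shape == (15, 5)
assert np.array_equal(instances[1], packed[:, :, 1])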
100 changes: 71 additions & 29 deletions arraycontext/parameter_study/transform.py
@@ -50,8 +50,14 @@

import numpy as np
import pytato as pt
import loopy as lp
from immutabledict import immutabledict


from pytato.scalar_expr import IdentityMapper
import pymbolic.primitives as prim


from pytools import memoize_method
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag

@@ -129,6 +135,7 @@ def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: A


def map_stack(self, expr: Stack) -> Array:
# TODO: Fix
return super().map_stack(expr)

def map_concatenate(self, expr: Concatenate) -> Array:
@@ -149,7 +156,6 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array:
new_array = self.rec(expr.array)
prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
new_array)
breakpoint()
axis_permute = tuple([expr.axis_permutation[i] + len(prepend_shape) for i
in range(len(expr.axis_permutation))])
# Include the axes we are adding to the system.
@@ -199,8 +205,8 @@ def map_placeholder(self, expr: Placeholder) -> Array:
axes=correct_axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

def map_index_lambda(self, expr: IndexLambda) -> Array:
# TODO: Fix
# Update bindings first.
new_bindings: Mapping[str, Array] = { name: self.rec(bnd)
for name, bnd in sorted(expr.bindings.items())}
@@ -210,43 +216,77 @@ def map_index_lambda(self, expr: IndexLambda) -> Array:
from pytools.obj_array import flat_obj_array

all_axis_tags: Set[Tag] = set()
studies_by_variable: Mapping[str, Mapping[Tag, bool]] = {}
for name, bnd in sorted(new_bindings.items()):
axis_tags_for_bnd: Set[Tag] = set()
studies_by_variable[name] = {}
for i in range(len(bnd.axes)):
axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag))
breakpoint()
for tag in axis_tags_for_bnd:
studies_by_variable[name][tag] = 1
all_axis_tags = all_axis_tags.union(axis_tags_for_bnd)

# Freeze the set now.
all_axis_tags = frozenset(all_axis_tags)


breakpoint()
active_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags))
axes: Optional[Tuple[Axis]] = tuple([])
study_to_axis_number: Mapping[ParameterStudyAxisTag, int] = {}


count = 0
new_shape: List[int] = [0 for i in
range(len(active_studies) + len(expr.shape))]
new_axes: List[Axis] = [Axis() for i in range(len(new_shape))]
new_shape = expr.shape
new_axes = expr.axes

for study in active_studies:
if isinstance(study, ParameterStudyAxisTag):
# Just defensive programming
study_to_axis_number[type(study)] = count
new_shape[count] = study.axis_size # We are recording the size of each parameter study.
new_axes[count] = new_axes[count].tagged([study])
count += 1
# The active studies are added to the end of the bindings.
study_to_axis_number[study] = len(new_shape)
new_shape = new_shape + (study.axis_size,)
new_axes = new_axes + (Axis(tags=frozenset((study,))),)
# This assumes that the axis only has 1 tag,
# because there should be no dependence across instances.

# Now we need to update the expressions.
scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, study_to_axis_number)

return IndexLambda(expr=scalar_expr,
bindings=type(expr.bindings)(new_bindings),
shape=new_shape,
var_to_reduction_descr=expr.var_to_reduction_descr,
dtype=expr.dtype,
axes=new_axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

for i in range(len(expr.shape)):
new_shape[count] = expr.shape[i]
count += 1
new_shape: Tuple[int] = tuple(new_shape)

breakpoint()
return super().map_index_lambda(expr)
class ParamAxisExpander(IdentityMapper):

def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str, Mapping[ParameterStudyAxisTag, bool]],
study_to_axis_number: Mapping[ParameterStudyAxisTag, int]):
# We know that we are not changing the variable that we are indexing into.
# This is stored in the aggregate member of the class Subscript.

# We only need to modify the indexing which is stored in the index member.
name = expr.aggregate.name
if name in studies_by_variable.keys():
# These are the single instance information.
index = self.rec(expr.index, studies_by_variable,
study_to_axis_number)

new_vars: Tuple[prim.Variable] = tuple([])

for key, val in sorted(study_to_axis_number.items(), key=lambda item: item[1]):
if key in studies_by_variable[name]:
new_vars = new_vars + (prim.Variable(f"_{study_to_axis_number[key]}"),)

if isinstance(index, tuple):
index = index + new_vars
else:
index = tuple(index) + new_vars
return type(expr)(aggregate=expr.aggregate, index=index)
return expr
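ParamAxisExpander.map_subscript appends one index variable per relevant study to a binding's subscript, so a binding that gained trailing study axes gets indexed with the matching _N variables of the expanded IndexLambda. A pymbolic-only sketch of that appending step; the axis number used here is hypothetical, standing in for an entry of study_to_axis_number:

import pymbolic.primitives as prim

# Subscript as it appears in the single-instance IndexLambda: x[_0, _1].
subscript = prim.Subscript(prim.Variable("x"),
                           (prim.Variable("_0"), prim.Variable("_1")))

# Suppose this binding takes part in one study that was appended as axis 2
# of the expanded array.
study_axis_numbers = (2,)

new_index = subscript.index + tuple(prim.Variable(f"_{ax}")
                                    for ax in study_axis_numbers)
expanded = prim.Subscript(subscript.aggregate, new_index)

print(expanded)   # x[_0, _1, _2]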


# {{{ ParamStudyPytatoPyOpenCLArrayContext

@@ -267,6 +307,10 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
return ParamStudyLazyPyOpenCLFunctionCaller(self, f)


def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
# Update in a subclass if you want.
return t_unit

# }}}


@@ -338,18 +382,21 @@ def _as_dict_of_named_arrays(keys, ary):
input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys()}
input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()}
myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies
breakpoint()

dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays)

breakpoint()
dict_of_named_arrays = myMapper(dict_of_named_arrays) # Update the arrays.

# Use the normal compiler now.

compiled_func = self._dag_to_compiled_func(
pt.make_dict_of_named_arrays(dict_of_named_arrays),

compiled_func = self._dag_to_compiled_func(dict_of_named_arrays,
input_id_to_name_in_program=input_id_to_name_in_program,
output_id_to_name_in_program=output_id_to_name_in_program,
output_template=output_template)

breakpoint()
self.program_cache[arg_id_to_descr] = compiled_func
return compiled_func(arg_id_to_arg)

@@ -364,20 +411,15 @@ def _cut_if_in_param_study(name, arg) -> Array:
"""
ndim: int = len(arg.shape)
newshape = []
update_tags: set = set()
update_axes = []
for i in range(ndim):
axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag)
if axis_tags:
# We need to remove those tags.
update_tags.add(axis_tags)
else:
if not axis_tags:
update_axes.append(arg.axes[i])
newshape.append(arg.shape[i])

update_tags.update(arg.tags)
update_axes = tuple(update_axes)
update_tags = list(update_tags)[0] # Return just the frozenset.
update_tags: FrozenSet[Tag] = arg.tags
return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes,
tags=update_tags)

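_cut_if_in_param_study rebuilds a placeholder that keeps only the non-study axes, so tracing the user function sees the single-instance shape. A shape-filtering sketch with the tags_of_type(ParameterStudyAxisTag) check replaced by a plain boolean list (an assumption for illustration):

# Shape of a packed argument: (15, 5) base shape plus two trailing study axes.
packed_shape = (15, 5, 3, 4)
is_study_axis = [False, False, True, True]   # stands in for the axis-tag check

single_instance_shape = tuple(dim
                              for dim, cut in zip(packed_shape, is_study_axis)
                              if not cut)

assert single_instance_shape == (15, 5)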
35 changes: 11 additions & 24 deletions examples/parameter_study.py
@@ -15,19 +15,8 @@
queue = cl.CommandQueue(ctx)
actx = ParamStudyPytatoPyOpenCLArrayContext(queue)


# Eq: z = x + y
# Assumptions: x and y are undergoing independent parameter studies.
base_shape = (15, 5)
def rhs(param1, param2):
return param1 + param2
return pt.transpose(param1)
return pt.roll(param1, shift=3, axis=1)
return param1.reshape(np.prod(base_shape))
return param1 + param2


# Experimental setup
base_shape = (15, 5)
seed = 12345
rng = np.random.default_rng(seed)
x = actx.from_numpy(rng.random(base_shape))
@@ -39,41 +28,39 @@ def rhs(param1, param2):
y2 = actx.from_numpy(rng.random(base_shape))
y3 = actx.from_numpy(rng.random(base_shape))

# Eq: z = x + y
# Assumptions: x and y are undergoing independent parameter studies.
def rhs(param1, param2):
return param1 + param2

@dataclass(frozen=True)
class ParameterStudyForX(ParameterStudyAxisTag):
pass


@dataclass(frozen=True)
class ParameterStudyForY(ParameterStudyAxisTag):
pass


# Pack a parameter study of 3 instances for both x and y.
# We are assuming these are distinct parameter studies.
packx = pack_for_parameter_study(actx, ParameterStudyForX, (3,), x, x1, x2)
packy = pack_for_parameter_study(actx, ParameterStudyForY, (4,), y, y1, y2, y3)
output_x = unpack_parameter_study(packx, ParameterStudyForX)

print(packx)

compiled_rhs = actx.compile(rhs) # Build the function caller

# Builds a trace for a single instance of evaluating the RHS and
# then converts it to a program which takes our multiple instances
# of `x` and `y`.
breakpoint()
# then converts it to a program which takes our multiple instances of `x` and `y`.
output = compiled_rhs(packx, packy)
breakpoint()
output_2 = compiled_rhs(x,y)

assert output.shape == (3, 4, 15, 5) # Distinct parameter studies.
assert output.shape == (15, 5, 3, 4) # Distinct parameter studies.

output_x = unpack_parameter_study(output, ParameterStudyForX)
output_y = unpack_parameter_study(output, ParameterStudyForY)
assert len(output_x) == 1 # Number of parameter studies involving "x"
assert len(output_x[0]) == 3 # Number of inputs in the 0th parameter study
# All outputs across every other parameter study.
assert output_x[0][0].shape == (4, 15, 5)
assert output_x[0][0].shape == (15, 5, 4)
assert len(output_y) == 1
assert len(output_y[0]) == 4
assert output_y[0][0].shape == (3, 15, 5)
assert output_y[0][0].shape == (15, 5, 3)
