From 804ed42cd1a69558bb6c0c2ae0459e3a8d7c18d8 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 16 Jul 2024 16:26:17 -0500 Subject: [PATCH] Update the Expansion Mapper for index lambda and move the packer to pack the items in the later axes as opposed to prepending the new axes. --- arraycontext/impl/pytato/__init__.py | 5 +- arraycontext/impl/pytato/compile.py | 5 +- arraycontext/parameter_study/__init__.py | 27 +++--- arraycontext/parameter_study/transform.py | 100 +++++++++++++++------- examples/parameter_study.py | 35 +++----- 5 files changed, 105 insertions(+), 67 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 5aefbeaa..fac98e8f 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -563,7 +563,10 @@ def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: evt, out_dict = pt_prg(self.queue, allocator=self.allocator, **bound_arguments) - evt.wait() + if isinstance(evt, list): + [_evt.wait() for _evt in evt] + else: + evt.wait() assert len(set(out_dict) & set(key_to_frozen_subary)) == 0 key_to_frozen_subary = { diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 523ef3a4..4d3a14d8 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -646,7 +646,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: # FIXME Kernels (for now) allocate tons of memory in temporaries. If we # race too far ahead with enqueuing, there is a distinct risk of # running out of memory. This mitigates that risk a bit, for now. - evt.wait() + if isinstance(evt, list): + [_evt.wait() for _evt in evt] + else: + evt.wait() def to_output_template(keys, _): name_in_program = self.output_id_to_name_in_program[keys] diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index f1630674..46dcfdf4 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -115,14 +115,15 @@ def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyA assert len(args) == np.prod(newshape) orig_shape = args[0].shape - out = actx.np.stack(args) - outshape = tuple([newshape] + list(orig_shape)) - - if len(newshape) > 1: - # Reshape the object - out = out.reshape(outshape) - for i in range(len(newshape)): - out = out.with_tagged_axis(i, [study_name_tag(i, newshape[i])]) + out = actx.np.stack(args, axis=args[0].ndim) + outshape = tuple(list(orig_shape) + [newshape] ) + + #if len(newshape) > 1: + # # Reshape the object + # out = out.reshape(outshape) + + for i in range(len(orig_shape), len(outshape)): + out = out.with_tagged_axis(i, [study_name_tag(i - len(orig_shape), newshape[i-len(orig_shape)])]) return out @@ -140,6 +141,7 @@ def unpack_parameter_study(data: ArrayT, ndim: int = len(data.axes) out: Dict[int, List[ArrayT]] = {} + study_count = 0 for i in range(ndim): axis_tags = data.axes[i].tags_of_type(study_name_tag) if axis_tags: @@ -150,11 +152,12 @@ def unpack_parameter_study(data: ArrayT, tmp[i] = j the_slice: Tuple[slice] = tuple(tmp) # Needs to be a tuple of slices not list of slices. - if i in out.keys(): - out[i].append(data[the_slice]) + if study_count in out.keys(): + out[study_count].append(data[the_slice]) else: - out[i] = [data[the_slice]] - + out[study_count] = [data[the_slice]] + if study_count in out.keys(): + study_count += 1 # yield data[the_slice] return out diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index 4ff17e23..b95d08e4 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -50,8 +50,14 @@ import numpy as np import pytato as pt +import loopy as lp from immutabledict import immutabledict + +from pytato.scalar_expr import IdentityMapper +import pymbolic.primitives as prim + + from pytools import memoize_method from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag @@ -129,6 +135,7 @@ def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: A def map_stack(self, expr: Stack) -> Array: + # TODO: Fix return super().map_stack(expr) def map_concatenate(self, expr: Concatenate) -> Array: @@ -149,7 +156,6 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array: new_array = self.rec(expr.array) prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) - breakpoint() axis_permute = tuple([expr.axis_permutation[i] + len(prepend_shape) for i in range(len(expr.axis_permutation))]) # Include the axes we are adding to the system. @@ -199,8 +205,8 @@ def map_placeholder(self, expr: Placeholder) -> Array: axes=correct_axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_index_lambda(self, expr: IndexLambda) -> Array: - # TODO: Fix # Update bindings first. new_bindings: Mapping[str, Array] = { name: self.rec(bnd) for name, bnd in sorted(expr.bindings.items())} @@ -210,43 +216,77 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: from pytools.obj_array import flat_obj_array all_axis_tags: Set[Tag] = set() + studies_by_variable: Mapping[str, Mapping[Tag, bool]] = {} for name, bnd in sorted(new_bindings.items()): axis_tags_for_bnd: Set[Tag] = set() + studies_by_variable[name] = {} for i in range(len(bnd.axes)): axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag)) - breakpoint() + for tag in axis_tags_for_bnd: + studies_by_variable[name][tag] = 1 all_axis_tags = all_axis_tags.union(axis_tags_for_bnd) # Freeze the set now. all_axis_tags = frozenset(all_axis_tags) - - breakpoint() active_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) axes: Optional[Tuple[Axis]] = tuple([]) study_to_axis_number: Mapping[ParameterStudyAxisTag, int] = {} - count = 0 - new_shape: List[int] = [0 for i in - range(len(active_studies) + len(expr.shape))] - new_axes: List[Axis] = [Axis() for i in range(len(new_shape))] + new_shape = expr.shape + new_axes = expr.axes for study in active_studies: if isinstance(study, ParameterStudyAxisTag): # Just defensive programming - study_to_axis_number[type(study)] = count - new_shape[count] = study.axis_size # We are recording the size of each parameter study. - new_axes[count] = new_axes[count].tagged([study]) - count += 1 + # The active studies are added to the end of the bindings. + study_to_axis_number[study] = len(new_shape) + new_shape = new_shape + (study.axis_size,) + new_axes = new_axes + (Axis(tags=frozenset((study,))),) + # This assumes that the axis only has 1 tag, + # because there should be no dependence across instances. + + # Now we need to update the expressions. + scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, study_to_axis_number) + + return IndexLambda(expr=scalar_expr, + bindings=type(expr.bindings)(new_bindings), + shape=new_shape, + var_to_reduction_descr=expr.var_to_reduction_descr, + dtype=expr.dtype, + axes=new_axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) - for i in range(len(expr.shape)): - new_shape[count] = expr.shape[i] - count += 1 - new_shape: Tuple[int] = tuple(new_shape) - breakpoint() - return super().map_index_lambda(expr) +class ParamAxisExpander(IdentityMapper): + + def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str, Mapping[ParameterStudyAxisTag, bool]], + study_to_axis_number: Mapping[ParameterStudyAxisTag, int]): + # We know that we are not changing the variable that we are indexing into. + # This is stored in the aggregate member of the class Subscript. + + # We only need to modify the indexing which is stored in the index member. + name = expr.aggregate.name + if name in studies_by_variable.keys(): + # These are the single instance information. + index = self.rec(expr.index, studies_by_variable, + study_to_axis_number) + + new_vars: Tuple[prim.Variable] = tuple([]) + + for key, val in sorted(study_to_axis_number.items(), key=lambda item: item[1]): + if key in studies_by_variable[name]: + new_vars = new_vars + (prim.Variable(f"_{study_to_axis_number[key]}"),) + + if isinstance(index, tuple): + index = index + new_vars + else: + index = tuple(index) + new_vars + return type(expr)(aggregate=expr.aggregate, index=index) + return expr + # {{{ ParamStudyPytatoPyOpenCLArrayContext @@ -267,6 +307,10 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: return ParamStudyLazyPyOpenCLFunctionCaller(self, f) + def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit: + # Update in a subclass if you want. + return t_unit + # }}} @@ -338,18 +382,21 @@ def _as_dict_of_named_arrays(keys, ary): input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys()} input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies + breakpoint() + dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) + breakpoint() dict_of_named_arrays = myMapper(dict_of_named_arrays) # Update the arrays. # Use the normal compiler now. - - compiled_func = self._dag_to_compiled_func( - pt.make_dict_of_named_arrays(dict_of_named_arrays), + + compiled_func = self._dag_to_compiled_func(dict_of_named_arrays, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, output_template=output_template) + breakpoint() self.program_cache[arg_id_to_descr] = compiled_func return compiled_func(arg_id_to_arg) @@ -364,20 +411,15 @@ def _cut_if_in_param_study(name, arg) -> Array: """ ndim: int = len(arg.shape) newshape = [] - update_tags: set = set() update_axes = [] for i in range(ndim): axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) - if axis_tags: - # We need to remove those tags. - update_tags.add(axis_tags) - else: + if not axis_tags: update_axes.append(arg.axes[i]) newshape.append(arg.shape[i]) - update_tags.update(arg.tags) update_axes = tuple(update_axes) - update_tags = list(update_tags)[0] # Return just the frozenset. + update_tags: FrozenSet[Tag] = arg.tags return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes, tags=update_tags) diff --git a/examples/parameter_study.py b/examples/parameter_study.py index ef9087bf..78adc2c5 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -15,19 +15,8 @@ queue = cl.CommandQueue(ctx) actx = ParamStudyPytatoPyOpenCLArrayContext(queue) - -# Eq: z = x + y -# Assumptions: x and y are undergoing independent parameter studies. -base_shape = (15, 5) -def rhs(param1, param2): - return param1 + param2 - return pt.transpose(param1) - return pt.roll(param1, shift=3, axis=1) - return param1.reshape(np.prod(base_shape)) - return param1 + param2 - - # Experimental setup +base_shape = (15, 5) seed = 12345 rng = np.random.default_rng(seed) x = actx.from_numpy(rng.random(base_shape)) @@ -39,41 +28,39 @@ def rhs(param1, param2): y2 = actx.from_numpy(rng.random(base_shape)) y3 = actx.from_numpy(rng.random(base_shape)) +# Eq: z = x + y +# Assumptions: x and y are undergoing independent parameter studies. +def rhs(param1, param2): + return param1 + param2 @dataclass(frozen=True) class ParameterStudyForX(ParameterStudyAxisTag): pass - @dataclass(frozen=True) class ParameterStudyForY(ParameterStudyAxisTag): pass - # Pack a parameter study of 3 instances for both x and y. -# We are assuming these are distinct parameter studies. packx = pack_for_parameter_study(actx, ParameterStudyForX, (3,), x, x1, x2) packy = pack_for_parameter_study(actx, ParameterStudyForY, (4,), y, y1, y2, y3) -output_x = unpack_parameter_study(packx, ParameterStudyForX) - -print(packx) compiled_rhs = actx.compile(rhs) # Build the function caller # Builds a trace for a single instance of evaluating the RHS and -# then converts it to a program which takes our multiple instances -# of `x` and `y`. -breakpoint() +# then converts it to a program which takes our multiple instances of `x` and `y`. output = compiled_rhs(packx, packy) +breakpoint() +output_2 = compiled_rhs(x,y) -assert output.shape == (3, 4, 15, 5) # Distinct parameter studies. +assert output.shape == (15, 5, 3, 4) # Distinct parameter studies. output_x = unpack_parameter_study(output, ParameterStudyForX) output_y = unpack_parameter_study(output, ParameterStudyForY) assert len(output_x) == 1 # Number of parameter studies involving "x" assert len(output_x[0]) == 3 # Number of inputs in the 0th parameter study # All outputs across every other parameter study. -assert output_x[0][0].shape == (4, 15, 5) +assert output_x[0][0].shape == (15, 5, 4) assert len(output_y) == 1 assert len(output_y[0]) == 4 -assert output_y[0][0].shape == (3, 15, 5) +assert output_y[0][0].shape == (15, 5, 3)