diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index 78004d2b..cffa27ed 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -58,7 +58,6 @@ import numpy as np -import pytato as pt from pytato.array import Array from arraycontext.context import ArrayContext @@ -98,7 +97,7 @@ def pack_for_parameter_study(actx: ArrayContext, orig_shape = args[0].shape out = actx.np.stack(args, axis=len(args[0].shape)) - outshape = tuple([*list(orig_shape), newshape]) + outshape = *orig_shape, newshape # if len(newshape) > 1: # # Reshape the object diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index 19921f56..b945aae8 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -85,6 +85,9 @@ ) +ArraysT = Tuple[Array, ...] + + @dataclass(frozen=True) class ParameterStudyAxisTag(UniqueTag): """ @@ -100,6 +103,9 @@ class ParameterStudyAxisTag(UniqueTag): axis_size: int +StudiesT = Tuple[ParameterStudyAxisTag, ...] + + class ExpansionMapper(CopyMapper): def __init__(self, actual_input_shapes: Mapping[str, ShapeType], @@ -108,8 +114,7 @@ def __init__(self, actual_input_shapes: Mapping[str, ShapeType], self.actual_input_shapes = actual_input_shapes self.actual_input_axes = actual_input_axes - def single_predecessor_updates(self, - curr_expr: Array, + def single_predecessor_updates(self, curr_expr: Array, new_expr: Array) -> Tuple[ShapeType, AxesT]: # Initialize with something for the typing. @@ -125,8 +130,8 @@ def single_predecessor_updates(self, for _j, tag in enumerate(axis_tags): # Should be relatively few tags on each axis $O(1)$. if isinstance(tag, ParameterStudyAxisTag): - new_axes = new_axes + (new_expr.axes[i],) - shape_to_append = shape_to_append + (new_expr.shape[i],) + new_axes = *new_axes, new_expr.axes[i], + shape_to_append = *shape_to_append, new_expr.shape[i], if already_added: raise ValueError("An individual axis may only be " + "tagged with one ParameterStudyAxisTag.") @@ -137,8 +142,7 @@ def single_predecessor_updates(self, def map_roll(self, expr: Roll) -> Array: new_array = self.rec(expr.array) - _, new_axes = self.single_predecessor_updates(expr.array, - new_array) + _, new_axes = self.single_predecessor_updates(expr.array, new_array) return Roll(array=new_array, shift=expr.shift, axis=expr.axis, @@ -148,8 +152,7 @@ def map_roll(self, expr: Roll) -> Array: def map_axis_permutation(self, expr: AxisPermutation) -> Array: new_array = self.rec(expr.array) - postpend_shape, new_axes = self.single_predecessor_updates(expr.array, - new_array) + postpend_shape, new_axes = self.single_predecessor_updates(expr.array, new_array) # Include the axes we are adding to the system. axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation) for i in range(len(postpend_shape))]) @@ -161,9 +164,9 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array: non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: + breakpoint() new_array = self.rec(expr.array) - _, new_axes = self.single_predecessor_updates(expr.array, - new_array) + _, new_axes = self.single_predecessor_updates(expr.array, new_array) return type(expr)(new_array, indices=self.rec_idx_or_size_tuple(expr.indices), # May need to modify indices @@ -173,10 +176,9 @@ def _map_index_base(self, expr: IndexBase) -> Array: def map_reshape(self, expr: Reshape) -> Array: new_array = self.rec(expr.array) - postpend_shape, new_axes = self.single_predecessor_updates(expr.array, - new_array) + postpend_shape, new_axes = self.single_predecessor_updates(expr.array, new_array) return Reshape(new_array, - newshape=self.rec_idx_or_size_tuple(expr.newshape + + newshape=self.rec_idx_or_size_tuple(expr.newshape + \ postpend_shape), order=expr.order, axes=expr.axes + new_axes, @@ -202,13 +204,14 @@ def map_placeholder(self, expr: Placeholder) -> Array: # {{{ Operations with multiple predecessors. def _studies_from_multiple_pred(self, - new_arrays: Tuple[Array, ...]) -> Tuple[AxesT, - Set[ParameterStudyAxisTag], - Dict[Array, Tuple[ParameterStudyAxisTag, ...]]]: + new_arrays: ArraysT) -> Tuple[AxesT, + Set[ParameterStudyAxisTag], + Dict[Array, + StudiesT]]: - new_axes_for_end: Tuple[Axis, ...] = () + new_axes_for_end: AxesT = () cur_studies: Set[ParameterStudyAxisTag] = set() - studies_by_array: Dict[Array, Tuple[ParameterStudyAxisTag, ...]] = {} + studies_by_array: Dict[Array, StudiesT] = {} for _ind, array in enumerate(new_arrays): for axis in array.axes: @@ -224,7 +227,7 @@ def _studies_from_multiple_pred(self, if axis_tags[0] not in cur_studies: cur_studies.add(axis_tags[0]) - new_axes_for_end = new_axes_for_end + (axis,) + new_axes_for_end = *new_axes_for_end, axis return new_axes_for_end, cur_studies, studies_by_array @@ -246,17 +249,17 @@ def map_concatenate(self, expr: Concatenate) -> Array: tags=expr.tags, non_equality_tags=expr.non_equality_tags) - def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Tuple[Array, ...], - AxesT]: + def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[ArraysT, + AxesT]: - sing_inst_in_shape = expr.arrays[0].shape + one_inst_in_shape = expr.arrays[0].shape new_arrays = tuple(self.rec(arr) for arr in expr.arrays) _, cur_studies, studies_by_array = self._studies_from_multiple_pred(new_arrays) study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} - new_shape_of_predecessors = sing_inst_in_shape + new_shape_of_predecessors = one_inst_in_shape new_axes = expr.axes for study in cur_studies: @@ -264,9 +267,9 @@ def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Tuple[ # Just defensive programming # The active studies are added to the end of the bindings. study_to_axis_number[study] = len(new_shape_of_predecessors) - new_shape_of_predecessors = new_shape_of_predecessors + \ + new_shape_of_predecessors = *new_shape_of_predecessors, \ (study.axis_size,) - new_axes = new_axes + (Axis(tags=frozenset((study,))),) + new_axes = *new_axes, Axis(tags=frozenset((study,))), # This assumes that the axis only has 1 tag, # because there should be no dependence across instances. @@ -275,41 +278,40 @@ def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Tuple[ # Now we need to update the expressions. # Now that we have the appropriate shape, # we need to update the input arrays to match. - + cp_map = CopyMapper() - corrected_new_arrays: Tuple[Array, ...] = () + corrected_new_arrays: ArraysT = () for _, array in enumerate(new_arrays): tmp = cp_map(array) # Get a copy of the array. if len(array.axes) < len(new_axes): # We need to grow the array to the new size. for study in cur_studies: if study not in studies_by_array[array]: - build: Tuple[Array, ...] = tuple([cp_map(tmp) for + build: ArraysT = tuple([cp_map(tmp) for _ in range(study.axis_size)]) tmp = Stack(arrays=build, axis=len(tmp.axes), - axes=tmp.axes + - (Axis(tags=frozenset((study,))),), + axes=(*tmp.axes, Axis(tags=frozenset((study,)))), tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) elif len(array.axes) > len(new_axes): - raise ValueError("Input array is too big. " + \ - f"Expected at most: {len(new_axes)} " + \ + raise ValueError("Input array is too big. " + + f"Expected at most: {len(new_axes)} " + f"Found: {len(array.axes)} axes.") # Now we need to correct to the appropriate shape with an axis permutation. # These are known to be in the right place. - permute: Tuple[int, ...] = tuple([i for i in range(len(sing_inst_in_shape))]) + permute: Tuple[int, ...] = tuple([i for i in range(len(one_inst_in_shape))]) for _, axis in enumerate(tmp.axes): axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) if axis_tags: assert len(axis_tags) == 1 - permute = permute + (study_to_axis_number[axis_tags[0]],) + permute = *permute, study_to_axis_number[axis_tags[0]], assert len(permute) == len(new_shape_of_predecessors) - corrected_new_arrays = corrected_new_arrays + \ - (AxisPermutation(tmp, permute, tags=tmp.tags, + corrected_new_arrays = *corrected_new_arrays, \ + AxisPermutation(tmp, permute, tags=tmp.tags, axes=tmp.axes, - non_equality_tags=tmp.non_equality_tags),) + non_equality_tags=tmp.non_equality_tags), return corrected_new_arrays, new_axes @@ -322,18 +324,18 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: # Determine the new parameter studies that are being conducted. from pytools import unique - all_axis_tags: Tuple[ParameterStudyAxisTag, ...] = () - studies_by_variable: Dict[str, Dict[UniqueTag, bool]] = {} + all_axis_tags: StudiesT = () + varname_to_studies: Dict[str, Dict[UniqueTag, bool]] = {} for name, bnd in sorted(new_bindings.items()): axis_tags_for_bnd: Set[Tag] = set() - studies_by_variable[name] = {} + varname_to_studies[name] = {} for i in range(len(bnd.axes)): axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag)) for tag in axis_tags_for_bnd: if isinstance(tag, ParameterStudyAxisTag): # Defense - studies_by_variable[name][tag] = True - all_axis_tags = all_axis_tags + (tag,) + varname_to_studies[name][tag] = True + all_axis_tags = *all_axis_tags, tag, cur_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} @@ -346,13 +348,13 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: # Just defensive programming # The active studies are added to the end of the bindings. study_to_axis_number[study] = len(new_shape) - new_shape = new_shape + (study.axis_size,) - new_axes = new_axes + (Axis(tags=frozenset((study,))),) + new_shape = *new_shape, study.axis_size, + new_axes = *new_axes, Axis(tags=frozenset((study,))), # This assumes that the axis only has 1 tag, # because there should be no dependence across instances. # Now we need to update the expressions. - scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, + scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies, study_to_axis_number) return IndexLambda(expr=scalar_expr, @@ -380,7 +382,7 @@ def map_einsum(self, expr: Einsum) -> Array: # Just defensive programming # The active studies are added to the end. study_to_axis_number[study] = len(new_shape) - new_shape = new_shape + (study.axis_size,) + new_shape = *new_shape, study.axis_size, for ind, array in enumerate(new_arrays): for _, axis in enumerate(array.axes): @@ -403,25 +405,25 @@ def map_einsum(self, expr: Einsum) -> Array: class ParamAxisExpander(IdentityMapper): def map_subscript(self, expr: prim.Subscript, - studies_by_variable: Mapping[str, - Mapping[ParameterStudyAxisTag, bool]], + varname_to_studies: Mapping[str, + Mapping[ParameterStudyAxisTag, bool]], study_to_axis_number: Mapping[ParameterStudyAxisTag, int]): # We know that we are not changing the variable that we are indexing into. # This is stored in the aggregate member of the class Subscript. # We only need to modify the indexing which is stored in the index member. name = expr.aggregate.name - if name in studies_by_variable.keys(): + if name in varname_to_studies.keys(): # These are the single instance information. - index = self.rec(expr.index, studies_by_variable, + index = self.rec(expr.index, varname_to_studies, study_to_axis_number) new_vars: Tuple[prim.Variable, ...] = () for key, num in sorted(study_to_axis_number.items(), key=lambda item: item[1]): - if key in studies_by_variable[name]: - new_vars = new_vars + (prim.Variable(f"_{num}"),) + if key in varname_to_studies[name]: + new_vars = *new_vars, prim.Variable(f"_{num}"), if isinstance(index, tuple): index = index + new_vars @@ -559,7 +561,7 @@ def _cut_if_in_param_study(name, arg) -> Array: for i in range(ndim): axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) if not axis_tags: - update_axes = update_axes + (arg.axes[i],) + update_axes = *update_axes, arg.axes[i], newshape.append(arg.shape[i]) # remove the first one that was placed there for typing. update_axes = update_axes[1:]