diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index 46dcfdf4..cbe68cc6 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -59,11 +59,15 @@ Union, Sequence, List, + Iterable, + Mapping, ) import numpy as np import pytato as pt +from pytato.array import Array + from pytools import memoize_method from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag @@ -74,7 +78,7 @@ from arraycontext.container import ArrayContainer, is_array_container_type -from arraycontext.context import ArrayT, ArrayContext +from arraycontext.context import ArrayContext from arraycontext.metadata import NameHint from arraycontext import PytatoPyOpenCLArrayContext from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller, @@ -99,9 +103,10 @@ logger = logging.getLogger(__name__) -def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyAxisTag, +def pack_for_parameter_study(actx: ArrayContext, + study_name_tag_type: Type[ParameterStudyAxisTag], newshape: Tuple[int, ...], - *args: ArrayT) -> ArrayT: + *args: Array) -> Array: """ Args is a list of variable names and the realized input data that needs to be packed for a parameter study or uncertainty quantification. @@ -115,7 +120,7 @@ def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyA assert len(args) == np.prod(newshape) orig_shape = args[0].shape - out = actx.np.stack(args, axis=args[0].ndim) + out = actx.np.stack(args, axis=len(args[0].shape)) outshape = tuple(list(orig_shape) + [newshape] ) #if len(newshape) > 1: @@ -123,34 +128,34 @@ def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyA # out = out.reshape(outshape) for i in range(len(orig_shape), len(outshape)): - out = out.with_tagged_axis(i, [study_name_tag(i - len(orig_shape), newshape[i-len(orig_shape)])]) + out = out.with_tagged_axis(i, [study_name_tag_type(i - len(orig_shape), newshape[i-len(orig_shape)])]) return out -def unpack_parameter_study(data: ArrayT, - study_name_tag: ParameterStudyAxisTag) -> Dict[int, - List[ArrayT]]: +def unpack_parameter_study(data: Array, + study_name_tag_type: Type[ParameterStudyAxisTag]) -> Mapping[int, + List[Array]]: """ Split the data array along the axes which vary according to a ParameterStudyAxisTag - whose name tag is an instance study_name_tag. + whose name tag is an instance study_name_tag_type. output[i] corresponds to the values associated with the ith parameter study that - uses the variable name :arg: `study_name_tag`. + uses the variable name :arg: `study_name_tag_type`. """ - ndim: int = len(data.axes) - out: Dict[int, List[ArrayT]] = {} + ndim: int = len(data.shape) + out: Dict[int, List[Array]] = {} study_count = 0 for i in range(ndim): - axis_tags = data.axes[i].tags_of_type(study_name_tag) + axis_tags = data.axes[i].tags_of_type(study_name_tag_type) if axis_tags: # Now we need to split this data. breakpoint() for j in range(data.shape[i]): - tmp: List[slice] = [slice(None)] * ndim + tmp: List[Any] = [slice(None)] * ndim tmp[i] = j - the_slice: Tuple[slice] = tuple(tmp) + the_slice = tuple(tmp) # Needs to be a tuple of slices not list of slices. if study_count in out.keys(): out[study_count].append(data[the_slice]) diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index b95d08e4..aeef7245 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -1,3 +1,4 @@ +from __future__ import annotations """ .. currentmodule:: arraycontext @@ -46,6 +47,7 @@ Sequence, List, Mapping, + Set, ) import numpy as np @@ -84,7 +86,7 @@ DataWrapper, SizeParam, DictOfNamedArrays, AbstractResultWithNamedArrays, Reshape, Concatenate, NamedArray, IndexRemappingBase, Einsum, InputArgumentBase, AdvancedIndexInNoncontiguousAxes, IndexBase, DataInterface, - Axis) + Axis, ShapeType) from pytato.utils import broadcast_binary_op @@ -115,23 +117,30 @@ def __init__(self, actual_input_shapes: Mapping[str, Tuple[int,...]], def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: Array, - new_expr: Array) -> Tuple[Optional[Tuple[int]], - Optional[Tuple[Axis]]]: - shape_to_prepend: Tuple[int] = tuple([]) - new_axes: Tuple[Axis] = tuple([]) + new_expr: Array) -> Tuple[ShapeType, + Tuple[Axis,...]]: + # Initialize with something for the typing. + shape_to_append: ShapeType= (-1,) + new_axes: Tuple[Axis,...] = (Axis(tags=frozenset()),) if curr_expr.shape == new_expr.shape: - return shape_to_prepend, new_axes + return shape_to_append, new_axes # Now we may need to change. changed = False for i in range(len(new_expr.axes)): axis_tags = list(new_expr.axes[i].tags) + already_added = False for j, tag in enumerate(axis_tags): # Should be relatively few tags on each axis $O(1)$. if isinstance(tag, ParameterStudyAxisTag): new_axes = new_axes + (new_expr.axes[i],) - shape_to_prepend = shape_to_prepend + (new_expr.shape[i],) - return shape_to_prepend, new_axes + shape_to_append = shape_to_append + (new_expr.shape[i],) + if already_added: + raise ValueError("An individual axis may only be tagged with one ParameterStudyAxisTag.") + already_added = True + + # Remove initialized extraneous data + return shape_to_append[1:], new_axes[1:] def map_stack(self, expr: Stack) -> Array: @@ -143,50 +152,49 @@ def map_concatenate(self, expr: Concatenate) -> Array: def map_roll(self, expr: Roll) -> Array: new_array = self.rec(expr.array) - prepend_shape, new_axes =self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + _, new_axes =self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) return Roll(array=new_array, shift=expr.shift, - axis=expr.axis + len(new_axes), - axes=new_axes + expr.axes, + axis=expr.axis, + axes=expr.axes + new_axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> Array: new_array = self.rec(expr.array) - prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) - axis_permute = tuple([expr.axis_permutation[i] + len(prepend_shape) for i - in range(len(expr.axis_permutation))]) # Include the axes we are adding to the system. - axis_permute = tuple([i for i in range(len(prepend_shape))]) + axis_permute + axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation) + for i in range(len(postpend_shape))]) return AxisPermutation(array=new_array, axis_permutation=axis_permute, - axes=new_axes + expr.axes, + axes=expr.axes + new_axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: new_array = self.rec(expr.array) - prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) return type(expr)(new_array, indices=self.rec_idx_or_size_tuple(expr.indices), # May need to modify indices - axes=new_axes + expr.axes, + axes=expr.axes + new_axes, tags=expr.tags, non_equality_tags = expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: new_array = self.rec(expr.array) - prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) return Reshape(new_array, - newshape = self.rec_idx_or_size_tuple(prepend_shape + expr.newshape), + newshape = self.rec_idx_or_size_tuple(expr.newshape + postpend_shape), order=expr.order, - axes=new_axes + expr.axes, + axes=expr.axes + new_axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -198,7 +206,7 @@ def map_placeholder(self, expr: Placeholder) -> Array: # We may need to update the size. if expr.shape != self.actual_input_shapes[expr.name]: correct_shape = self.actual_input_shapes[expr.name] - correct_axes = self.actual_input_axes[expr.name] + correct_axes = tuple(self.actual_input_axes[expr.name]) return Placeholder(name=expr.name, shape=self.rec_idx_or_size_tuple(correct_shape), dtype=expr.dtype, @@ -208,30 +216,30 @@ def map_placeholder(self, expr: Placeholder) -> Array: def map_index_lambda(self, expr: IndexLambda) -> Array: # Update bindings first. - new_bindings: Mapping[str, Array] = { name: self.rec(bnd) + new_bindings: Dict[str, Array] = { name: self.rec(bnd) for name, bnd in sorted(expr.bindings.items())} # Determine the new parameter studies that are being conducted. from pytools import unique from pytools.obj_array import flat_obj_array - all_axis_tags: Set[Tag] = set() - studies_by_variable: Mapping[str, Mapping[Tag, bool]] = {} + all_axis_tags: Tuple[ParameterStudyAxisTag,...] = () + studies_by_variable: Dict[str, Dict[UniqueTag, bool]] = {} for name, bnd in sorted(new_bindings.items()): axis_tags_for_bnd: Set[Tag] = set() studies_by_variable[name] = {} for i in range(len(bnd.axes)): axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag)) for tag in axis_tags_for_bnd: - studies_by_variable[name][tag] = 1 - all_axis_tags = all_axis_tags.union(axis_tags_for_bnd) + if isinstance(tag, ParameterStudyAxisTag): + # Defense + studies_by_variable[name][tag] = True + all_axis_tags = all_axis_tags + (tag,) - # Freeze the set now. - all_axis_tags = frozenset(all_axis_tags) active_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) - axes: Optional[Tuple[Axis]] = tuple([]) - study_to_axis_number: Mapping[ParameterStudyAxisTag, int] = {} + axes: Tuple[Axis, ...] = () + study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} count = 0 new_shape = expr.shape @@ -251,7 +259,7 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, study_to_axis_number) return IndexLambda(expr=scalar_expr, - bindings=type(expr.bindings)(new_bindings), + bindings=immutabledict(new_bindings), shape=new_shape, var_to_reduction_descr=expr.var_to_reduction_descr, dtype=expr.dtype, @@ -274,7 +282,7 @@ def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str, index = self.rec(expr.index, studies_by_variable, study_to_axis_number) - new_vars: Tuple[prim.Variable] = tuple([]) + new_vars: Tuple[prim.Variable, ...] = () for key, val in sorted(study_to_axis_number.items(), key=lambda item: item[1]): if key in studies_by_variable[name]: @@ -379,19 +387,25 @@ def _as_dict_of_named_arrays(keys, ary): rec_keyed_map_array_container(_as_dict_of_named_arrays, output_template) - input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys()} - input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} + #input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys() if hasattr(arg_id_to_descr[i], "shape")} + #input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} + input_shapes = {} + input_axes = {} + for i in arg_id_to_descr.keys(): + if hasattr(arg_id_to_descr[i], "shape"): + input_shapes[input_id_to_name_in_program[i]] = arg_id_to_descr[i].shape + input_axes[input_id_to_name_in_program[i]] = arg_id_to_arg[i].axes myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies breakpoint() - dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) breakpoint() - dict_of_named_arrays = myMapper(dict_of_named_arrays) # Update the arrays. # Use the normal compiler now. - compiled_func = self._dag_to_compiled_func(dict_of_named_arrays, + compiled_func = self._dag_to_compiled_func(myMapper(pt_dict_of_named_arrays), # Update the arrays + #pt_dict_of_named_arrays, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, output_template=output_template) @@ -411,14 +425,13 @@ def _cut_if_in_param_study(name, arg) -> Array: """ ndim: int = len(arg.shape) newshape = [] - update_axes = [] + update_axes: Tuple[Axis, ...] = (Axis(tags=frozenset()),) for i in range(ndim): axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) if not axis_tags: - update_axes.append(arg.axes[i]) + update_axes = update_axes + (arg.axes[i],) newshape.append(arg.shape[i]) - - update_axes = tuple(update_axes) + update_axes = update_axes[1:] # remove the first one that was placed there for typing. update_tags: FrozenSet[Tag] = arg.tags return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes, tags=update_tags)