From e697300cf96f48b9a4f548a174a80890fe91cccc Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 1 Jul 2024 17:45:41 -0500 Subject: [PATCH 01/25] Add a packer which tags the axes as uncertain. --- arraycontext/impl/pytato/__init__.py | 95 +++++++++++++++++++++++++++- examples/uncertain_prop.py | 49 ++++++++++++++ 2 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 examples/uncertain_prop.py diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index a32e8de0..58201b6c 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -10,6 +10,7 @@ Following :mod:`pytato`-based array context are provided: .. autoclass:: PytatoPyOpenCLArrayContext +.. autoclass:: PytatoPyOpenCLArrayContextUQ .. autoclass:: PytatoJAXArrayContext @@ -50,13 +51,14 @@ import numpy as np from pytools import memoize_method -from pytools.tag import Tag, ToTagSetConvertible, normalize_tags +from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag from arraycontext.container.traversal import ( rec_map_array_container, with_array_context) from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike from arraycontext.metadata import NameHint +from dataclasses import dataclass if TYPE_CHECKING: import pyopencl as cl @@ -684,6 +686,97 @@ def clone(self): # }}} +@dataclass(frozen=True) +class UQAxisTag(UniqueTag): + """ + A tag for acting on axes of arrays. + """ + uq_instance_num: str + + +# {{{ PytatoPyOpenCLArrayContextUQ + + +class PytatoPyOpenCLArrayContextUQ(PytatoPyOpenCLArrayContext): + """ + A derived class for PytatoPyOpenCLArrayContext updated for the + purpose of enabling parameter studies and uncertainty quantification. + + .. automethod:: __init__ + + .. automethod:: transform_dag + + .. automethod:: compile + """ + def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + # TODO: Update to a new compiler potentially + from .compile import LazilyPyOpenCLCompilingFunctionCaller + return LazilyPyOpenCLCompilingFunctionCaller(self, f) + + def transform_dag(self, dag: "pytato.DictOfNamedArrays" + ) -> "pytato.DictOfNamedArrays": + import pytato as pt + # TODO: This gets called before generating the placeholders + dag = pt.transform.materialize_with_mpms(dag) + return dag + + def pack_for_uq(self,*args): + """ + Args is a list of variable names and the realized input data that needs + to be packed for a parameter study or uncertainty quantification. + + Args needs to be in the format + ["v", v0, v1, v2, ..., vN, "w", w0, w1, w2, ..., wM, \dots] + + where "v" and "w" would be the variable names in your program. + If you want to include a constant just pass the var name and then + the value in the next argument. + + Returns a dictionary of {var name: stacked array} + """ + from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace + + assert len(args) > 0 + out = {} + curr_var = str(args[0]) + + num_calls = 0 + + for ind in range(1, len(args)): + val = args[ind] + if isinstance(val, str): + # Done with previous. + if val in out.keys(): + raise ValueError("Repeated definitions of variable: " + str(val) \ + + " Defined Variables: " + list(out.keys())) + out[curr_var] = PytatoFakeNumpyNamespace.stack(self, out[curr_var]) + + if out[curr_var].shape[0] > 1: + # Tag the outer axis as uncertain. + tag_name = str(num_calls) + " var: " + str(curr_var) + out[curr_var] = out[curr_var].with_tagged_axis(0, [UQAxisTag(tag_name)]) + num_calls += 1 + curr_var = val + + elif curr_var in out.keys(): + out[curr_var].append(val) + else: + out[curr_var] = [val] + + # Handle the last variable group. + out[curr_var] = PytatoFakeNumpyNamespace.stack(self, out[curr_var]) + + if out[curr_var].shape[0] > 1: + # Tag the outer axis as uncertain. + tag_name = str(num_calls) + " var: " + str(curr_var) + out[curr_var] = out[curr_var].with_tagged_axis(0, [UQAxisTag(tag_name)]) + num_calls += 1 + + return out + +# }}} + + # {{{ PytatoJAXArrayContext class PytatoJAXArrayContext(_BasePytatoArrayContext): diff --git a/examples/uncertain_prop.py b/examples/uncertain_prop.py new file mode 100644 index 00000000..35fa8e2e --- /dev/null +++ b/examples/uncertain_prop.py @@ -0,0 +1,49 @@ +import arraycontext +from dataclasses import dataclass + +import numpy as np # for the data types +from pytools.tag import Tag + +from arraycontext.impl.pytato.__init__ import (PytatoPyOpenCLArrayContextUQ, + PytatoPyOpenCLArrayContext) + +# The goal of this file is to propagate the uncertainty in an array to the output. + + +my_context = arraycontext.impl.pytato.PytatoPyOpenCLArrayContext +a = my_context.zeros(my_context, shape=(5,5), dtype=np.int32) + 2 + +b = my_context.zeros(my_context, (5,5), np.int32) + 15 + +print(a) +print("========================================================") +print(b) +print("========================================================") +breakpoint() + +# Eq: z = x + y +# Assumptions: x and y are independently uncertain. + + +x = np.random.random((15,5)) +x1 = np.random.random((15,5)) +x2 = np.random.random((15,5)) + +y = np.random.random((15,5)) +y1 = np.random.random((15,5)) +y2 = np.random.random((15,5)) + + +actx = PytatoPyOpenCLArrayContextUQ + +out = actx.pack_for_uq(actx,"x", x, x1, x2, "y", y, y1, y2) +print("===============out======================") +print(out) + +breakpoint() + + + + + + From 8071f8aaae9edc196da3b75e437bd319f5fa9a2b Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 8 Jul 2024 09:22:12 -0500 Subject: [PATCH 02/25] Prototype for interface --- arraycontext/impl/pytato/__init__.py | 34 +++++++++++++++++++++++++++- examples/uncertain_prop.py | 11 ++++----- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 58201b6c..09c3def0 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -720,7 +720,7 @@ def transform_dag(self, dag: "pytato.DictOfNamedArrays" dag = pt.transform.materialize_with_mpms(dag) return dag - def pack_for_uq(self,*args): + def pack_for_uq(self,*args) -> dict: """ Args is a list of variable names and the realized input data that needs to be packed for a parameter study or uncertainty quantification. @@ -774,6 +774,38 @@ def pack_for_uq(self,*args): return out + + def unpack(self, data): + """ + Revert data to a sequence of outputs under the assumption that a specific variable + is held constant. + + ::arg:: data multidimensional array tagged with dimensions that are varying. + UQAxisTag will tag each specific axis that we are going to slice. + """ + + ndim = len(data.axes) + + out = {} + + + for i in range(ndim): + axis_tags = data.axes[i].tags_of_type(UQAxisTag) + if axis_tags: + # Now we need to split this data. + for j in range(len(data.axis[i])): + the_slice = [slice(None)] * ndim + the_slice[i] = j + if i in out.keys(): + out[i].append(data[the_slice]) + else: + out[i] = data[the_slice] + #yield data[the_slice] + + + return out + + # }}} diff --git a/examples/uncertain_prop.py b/examples/uncertain_prop.py index 35fa8e2e..67afeaea 100644 --- a/examples/uncertain_prop.py +++ b/examples/uncertain_prop.py @@ -19,7 +19,6 @@ print("========================================================") print(b) print("========================================================") -breakpoint() # Eq: z = x + y # Assumptions: x and y are independently uncertain. @@ -37,13 +36,13 @@ actx = PytatoPyOpenCLArrayContextUQ out = actx.pack_for_uq(actx,"x", x, x1, x2, "y", y, y1, y2) -print("===============out======================") +print("===============OUT======================") print(out) -breakpoint() - - - +x = out["x"] +y = out["y"] +breakpoint() +x + y From 89ff83c0736fc644bac6d52b6d56b74e3cf9e7ce Mon Sep 17 00:00:00 2001 From: Nick Koskelo Date: Mon, 8 Jul 2024 16:55:24 -0500 Subject: [PATCH 03/25] Move the parameter study definitions to their own folder and update the example. --- arraycontext/__init__.py | 2 +- arraycontext/impl/pytato/__init__.py | 121 --------- arraycontext/parameter_study/__init__.py | 319 +++++++++++++++++++++++ examples/uncertain_prop.py | 54 ++-- 4 files changed, 359 insertions(+), 137 deletions(-) create mode 100644 arraycontext/parameter_study/__init__.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 1d0efb36..327fec66 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -88,7 +88,7 @@ pytest_generate_tests_for_pyopencl_array_context, ) from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag - +from .parameter_study import ParamStudyPytatoPyOpenCLArrayContext, pack_for_parameter_study, unpack_parameter_study __all__ = ( "Array", diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 5ffc5485..78caa5c6 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -701,127 +701,6 @@ def clone(self): # }}} -@dataclass(frozen=True) -class UQAxisTag(UniqueTag): - """ - A tag for acting on axes of arrays. - """ - uq_instance_num: str - - -# {{{ PytatoPyOpenCLArrayContextUQ - - -class PytatoPyOpenCLArrayContextUQ(PytatoPyOpenCLArrayContext): - """ - A derived class for PytatoPyOpenCLArrayContext updated for the - purpose of enabling parameter studies and uncertainty quantification. - - .. automethod:: __init__ - - .. automethod:: transform_dag - - .. automethod:: compile - """ - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: - # TODO: Update to a new compiler potentially - from .compile import LazilyPyOpenCLCompilingFunctionCaller - return LazilyPyOpenCLCompilingFunctionCaller(self, f) - - def transform_dag(self, dag: "pytato.DictOfNamedArrays" - ) -> "pytato.DictOfNamedArrays": - import pytato as pt - # TODO: This gets called before generating the placeholders - dag = pt.transform.materialize_with_mpms(dag) - return dag - - def pack_for_uq(self,*args) -> dict: - """ - Args is a list of variable names and the realized input data that needs - to be packed for a parameter study or uncertainty quantification. - - Args needs to be in the format - ["v", v0, v1, v2, ..., vN, "w", w0, w1, w2, ..., wM, \dots] - - where "v" and "w" would be the variable names in your program. - If you want to include a constant just pass the var name and then - the value in the next argument. - - Returns a dictionary of {var name: stacked array} - """ - from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace - - assert len(args) > 0 - out = {} - curr_var = str(args[0]) - - num_calls = 0 - - for ind in range(1, len(args)): - val = args[ind] - if isinstance(val, str): - # Done with previous. - if val in out.keys(): - raise ValueError("Repeated definitions of variable: " + str(val) \ - + " Defined Variables: " + list(out.keys())) - out[curr_var] = PytatoFakeNumpyNamespace.stack(self, out[curr_var]) - - if out[curr_var].shape[0] > 1: - # Tag the outer axis as uncertain. - tag_name = str(num_calls) + " var: " + str(curr_var) - out[curr_var] = out[curr_var].with_tagged_axis(0, [UQAxisTag(tag_name)]) - num_calls += 1 - curr_var = val - - elif curr_var in out.keys(): - out[curr_var].append(val) - else: - out[curr_var] = [val] - - # Handle the last variable group. - out[curr_var] = PytatoFakeNumpyNamespace.stack(self, out[curr_var]) - - if out[curr_var].shape[0] > 1: - # Tag the outer axis as uncertain. - tag_name = str(num_calls) + " var: " + str(curr_var) - out[curr_var] = out[curr_var].with_tagged_axis(0, [UQAxisTag(tag_name)]) - num_calls += 1 - - return out - - - def unpack(self, data): - """ - Revert data to a sequence of outputs under the assumption that a specific variable - is held constant. - - ::arg:: data multidimensional array tagged with dimensions that are varying. - UQAxisTag will tag each specific axis that we are going to slice. - """ - - ndim = len(data.axes) - - out = {} - - - for i in range(ndim): - axis_tags = data.axes[i].tags_of_type(UQAxisTag) - if axis_tags: - # Now we need to split this data. - for j in range(len(data.axis[i])): - the_slice = [slice(None)] * ndim - the_slice[i] = j - if i in out.keys(): - out[i].append(data[the_slice]) - else: - out[i] = data[the_slice] - #yield data[the_slice] - - - return out - - -# }}} # {{{ PytatoJAXArrayContext diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py new file mode 100644 index 00000000..1c366cf9 --- /dev/null +++ b/arraycontext/parameter_study/__init__.py @@ -0,0 +1,319 @@ +""" +.. currentmodule:: arraycontext + +A :mod:`pytato`-based array context defers the evaluation of an array until its +frozen. The execution contexts for the evaluations are specific to an +:class:`~arraycontext.ArrayContext` type. For ex. +:class:`~arraycontext.PytatoPyOpenCLArrayContext` uses :mod:`pyopencl` to +JIT-compile and execute the array expressions. + +Following :mod:`pytato`-based array context are provided: + +.. autoclass:: ParamStudyPytatoPyOpenCLArrayContext +.. autoclass:: ParamStudyLazyPyOpenCLFunctionCaller + + +Compiling a Python callable (Internal) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. automodule:: arraycontext.impl.pytato.compile +""" +__copyright__ = """ +Copyright (C) 2020-1 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import abc +import sys +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + FrozenSet, + Optional, + Tuple, + Type, + Union, + Sequence, +) + +import numpy as np + +from pytools import memoize_method +from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag + +from arraycontext.container.traversal import rec_map_array_container, with_array_context + +from arraycontext.context import ArrayT, ArrayContext +from arraycontext.metadata import NameHint +from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext +from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace +from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller + + +from dataclasses import dataclass + +if TYPE_CHECKING: + import pyopencl as cl + import pytato + +if getattr(sys, "_BUILDING_SPHINX_DOCS", False): + import pyopencl as cl + +import logging + + +logger = logging.getLogger(__name__) + +@dataclass(frozen=True) +class ParameterStudyAxisTag(UniqueTag): + """ + A tag for acting on axes of arrays. + """ + user_variable_name: str + axis_num: int + axis_size: int + +# {{{ ParamStudyPytatoPyOpenCLArrayContext + + +class ParamStudyPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): + """ + A derived class for PytatoPyOpenCLArrayContext updated for the + purpose of enabling parameter studies and uncertainty quantification. + + .. automethod:: __init__ + + .. automethod:: transform_dag + + .. automethod:: compile + """ + + def transform_dag(self, ary): + # This going to be called before the compiler or freeze. + out = super().transform_dag(ary) + return out + + +# }}} + + +class ParamStudyLazyPyOpenCLFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller): + """ + Record a side-effect-free callable :attr:`f` which is initially designed for + to be called multiple times with different data. This class will update the + signature to allow :attr:`f` to be called once with the data for multiple + instances. + """ + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """ + Returns the result of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s + function application on *args*. + + Before applying :attr:`~BaseLazilyCompilingFunctionCaller.f`, it is compiled + to a :mod:`pytato` DAG that would apply + :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. + The intermediary pytato DAG for *args* is memoized in *self*. + """ + arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( + args, kwargs) + + try: + compiled_f = self.program_cache[arg_id_to_descr] + except KeyError: + pass + else: + return compiled_f(arg_id_to_arg) + + dict_of_named_arrays = {} + output_id_to_name_in_program = {} + input_id_to_name_in_program = { + arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}" + for arg_id in arg_id_to_arg} + + output_template = self.f( + *[_get_f_placeholder_args(arg, iarg, + input_id_to_name_in_program, self.actx) + for iarg, arg in enumerate(args)], + **{kw: _get_f_placeholder_args(arg, kw, + input_id_to_name_in_program, + self.actx) + for kw, arg in kwargs.items()}) + + self.actx._compile_trace_callback(self.f, "post_trace", output_template) + + if (not (is_array_container_type(output_template.__class__) + or isinstance(output_template, pt.Array))): + # TODO: We could possibly just short-circuit this interface if the + # returned type is a scalar. Not sure if it's worth it though. + raise NotImplementedError( + f"Function '{self.f.__name__}' to be compiled " + "did not return an array container or pt.Array," + f" but an instance of '{output_template.__class__}' instead.") + + def _as_dict_of_named_arrays(keys, ary): + name = "_pt_out_" + _ary_container_key_stringifier(keys) + output_id_to_name_in_program[keys] = name + dict_of_named_arrays[name] = ary + return ary + + rec_keyed_map_array_container(_as_dict_of_named_arrays, + output_template) + + compiled_func = self._dag_to_compiled_func( + pt.make_dict_of_named_arrays(dict_of_named_arrays), + input_id_to_name_in_program=input_id_to_name_in_program, + output_id_to_name_in_program=output_id_to_name_in_program, + output_template=output_template) + + self.program_cache[arg_id_to_descr] = compiled_func + return compiled_func(arg_id_to_arg) + + +def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): + """ + Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder` + in :meth:`LazilyCompilingFunctionCaller.__call__`. + + Preprocessing here refers to: + + - Metadata Inference that is supplied via *actx*\'s + :meth:`PytatoPyOpenCLArrayContext.transform_dag`. + """ + import pyopencl.array as cla + + from arraycontext.impl.pyopencl.taggable_cl_array import ( + TaggableCLArray, + to_tagged_cl_array, + ) + if isinstance(ary, pt.Array): + dag = pt.make_dict_of_named_arrays({"_actx_out": ary}) + # Transform the DAG to give metadata inference a chance to do its job + return actx.transform_dag(dag)["_actx_out"].expr + elif isinstance(ary, TaggableCLArray): + return ary + elif isinstance(ary, cla.Array): + from warnings import warn + warn("Passing pyopencl.array.Array to a compiled callable" + " is deprecated and will stop working in 2023." + " Use `to_tagged_cl_array` to convert the array to" + " TaggableCLArray", DeprecationWarning, stacklevel=2) + + return to_tagged_cl_array(ary, + axes=None, + tags=frozenset()) + else: + raise NotImplementedError(type(ary)) + + +def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): + """ + Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the + placeholder version of an argument to + :attr:`BaseLazilyCompilingFunctionCaller.f`. + """ + if np.isscalar(arg): + name = arg_id_to_name[(kw,)] + return pt.make_placeholder(name, (), np.dtype(type(arg))) + elif isinstance(arg, pt.Array): + name = arg_id_to_name[(kw,)] + # Transform the DAG to give metadata inference a chance to do its job + arg = _to_input_for_compiled(arg, actx) + return pt.make_placeholder(name, arg.shape, arg.dtype, + axes=arg.axes, + tags=arg.tags) + elif is_array_container_type(arg.__class__): + def _rec_to_placeholder(keys, ary): + name = arg_id_to_name[(kw, *keys)] + # Transform the DAG to give metadata inference a chance to do its job + ary = _to_input_for_compiled(ary, actx) + return pt.make_placeholder(name, + ary.shape, + ary.dtype, + axes=ary.axes, + tags=ary.tags) + + return rec_keyed_map_array_container(_rec_to_placeholder, arg) + else: + raise NotImplementedError(type(arg)) + + +def pack_for_parameter_study(actx: ArrayContext, yourvarname: str, + newshape: Tuple[int, ...], + *args: ArrayT) -> ArrayT: + """ + Args is a list of variable names and the realized input data that needs + to be packed for a parameter study or uncertainty quantification. + + Args needs to be in the format + ["v", v0, v1, v2, ..., vN, "w", w0, w1, w2, ..., wM, \dots] + + where "v" and "w" would be the variable names in your program. + If you want to include a constant just pass the var name and then + the value in the next argument. + + Returns a dictionary of {var name: stacked array} + """ + + assert len(args) > 0 + assert len(args) == np.prod(newshape) + + out = {} + orig_shape = args[0].shape + out = actx.np.stack(args) + outshape = tuple([newshape] + [val for val in orig_shape]) + breakpoint() + + if len(newshape) > 1: + # Reshape the object + out = out.reshape(outshape) + for i in range(len(newshape)): + out = out.with_tagged_axis(i, [ParameterStudyAxisTag(yourvarname, i, newshape[i])]) + return out + + +def unpack_parameter_study(data: ArrayT, varname: str) -> Sequence[ArrayT]: + """ + Split the data array along the axes which vary according to a ParameterStudyAxisTag + whose variable name is varname. + """ + + ndim = len(data.axes) + out = {} + + for i in range(ndim): + axis_tags = data.axes[i].tags_of_type(ParameterStudyAxisTag) + if axis_tags: + # Now we need to split this data. + breakpoint() + for j in range(len(data.axis[i])): + the_slice = [slice(None)] * ndim + the_slice[i] = j + if i in out.keys(): + out[i].append(data[the_slice]) + else: + out[i] = data[the_slice] + #yield data[the_slice] + + return out diff --git a/examples/uncertain_prop.py b/examples/uncertain_prop.py index 67afeaea..4b1ed838 100644 --- a/examples/uncertain_prop.py +++ b/examples/uncertain_prop.py @@ -4,8 +4,7 @@ import numpy as np # for the data types from pytools.tag import Tag -from arraycontext.impl.pytato.__init__ import (PytatoPyOpenCLArrayContextUQ, - PytatoPyOpenCLArrayContext) +from arraycontext.impl.pytato.__init__ import (PytatoPyOpenCLArrayContext) # The goal of this file is to propagate the uncertainty in an array to the output. @@ -24,25 +23,50 @@ # Assumptions: x and y are independently uncertain. -x = np.random.random((15,5)) -x1 = np.random.random((15,5)) -x2 = np.random.random((15,5)) +base_shape = (15, 5) +x = np.random.random(base_shape) +x1 = np.random.random(base_shape) +x2 = np.random.random(base_shape) -y = np.random.random((15,5)) -y1 = np.random.random((15,5)) -y2 = np.random.random((15,5)) +y = np.random.random(base_shape) +y1 = np.random.random(base_shape) +y2 = np.random.random(base_shape) +y3 = np.random.random(base_shape) -actx = PytatoPyOpenCLArrayContextUQ +from arraycontext.parameter_study import (pack_for_parameter_study, + ParamStudyPytatoPyOpenCLArrayContext, unpack_parameter_study) +import pyopencl as cl -out = actx.pack_for_uq(actx,"x", x, x1, x2, "y", y, y1, y2) -print("===============OUT======================") -print(out) +ctx = cl.create_some_context(interactive=False) +queue = cl.CommandQueue(ctx) -x = out["x"] -y = out["y"] +actx = ParamStudyPytatoPyOpenCLArrayContext(queue) +# Pack a parameter study of 3 instances for both x and y. +# We are assuming these are distinct parameter studies. +packx = pack_for_parameter_study(actx,"x",tuple([3]), x, x1, x2) +packy = pack_for_parameter_study(actx,"y",tuple([4]), y, y1, y2, y3) +output_x = unpack_parameter_study(packx, "x") + +print(packx) breakpoint() -x + y +def rhs(param1, param2): + return param1 + param2 + +compiled_rhs = actx.compile(rhs) # Build the function caller + +# Builds a trace for a single instance of evaluating the RHS and then converts it to +# a program which takes our multiple instances of `x` and `y`. +output = compiled_rhs(packx, packy) + +assert output.shape == (3,4,15,5) # Distinct parameter studies. + +output_x = unpack_parameter_study(output, "x") +output_y = unpack_parameter_study(output, "y") +assert len(output_x) == 3 +assert output_x[0].shape == (4,15,5) +assert len(output_y) == 4 +assert output_y[0].shape == (3,15,5) From 48a0155d04e9809f633dcde0b3cbbd26db64d4e8 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 10 Jul 2024 13:32:23 -0500 Subject: [PATCH 04/25] Update terminology. --- arraycontext/parameter_study/__init__.py | 9 +++------ examples/{uncertain_prop.py => parameter_study.py} | 0 2 files changed, 3 insertions(+), 6 deletions(-) rename examples/{uncertain_prop.py => parameter_study.py} (100%) diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index 1c366cf9..d770c1f4 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -283,7 +283,6 @@ def pack_for_parameter_study(actx: ArrayContext, yourvarname: str, orig_shape = args[0].shape out = actx.np.stack(args) outshape = tuple([newshape] + [val for val in orig_shape]) - breakpoint() if len(newshape) > 1: # Reshape the object @@ -307,13 +306,11 @@ def unpack_parameter_study(data: ArrayT, varname: str) -> Sequence[ArrayT]: if axis_tags: # Now we need to split this data. breakpoint() - for j in range(len(data.axis[i])): + for j in range(data.shape[i]): the_slice = [slice(None)] * ndim the_slice[i] = j - if i in out.keys(): - out[i].append(data[the_slice]) - else: - out[i] = data[the_slice] + the_slice = tuple(the_slice) + out[tuple([i,j])] = data[the_slice] #yield data[the_slice] return out diff --git a/examples/uncertain_prop.py b/examples/parameter_study.py similarity index 100% rename from examples/uncertain_prop.py rename to examples/parameter_study.py From 4106d3202082d423ca499d2a2e3c9bb8b8c44ba7 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 10 Jul 2024 14:03:07 -0500 Subject: [PATCH 05/25] Update interface. --- arraycontext/parameter_study/__init__.py | 26 ++++++++++++++++-------- examples/parameter_study.py | 12 +++++------ 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index d770c1f4..9775e6f5 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -114,6 +114,9 @@ def transform_dag(self, ary): out = super().transform_dag(ary) return out + def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + return ParamStudyLazyPyOpenCLFunctionCaller(self, f) + # }}} @@ -153,10 +156,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: for arg_id in arg_id_to_arg} output_template = self.f( - *[_get_f_placeholder_args(arg, iarg, + *[_get_f_placeholder_args_for_param_study(arg, iarg, input_id_to_name_in_program, self.actx) for iarg, arg in enumerate(args)], - **{kw: _get_f_placeholder_args(arg, kw, + **{kw: _get_f_placeholder_args_for_param_study(arg, kw, input_id_to_name_in_program, self.actx) for kw, arg in kwargs.items()}) @@ -227,14 +230,17 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): raise NotImplementedError(type(ary)) -def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): +def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx): """ - Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the - placeholder version of an argument to - :attr:`BaseLazilyCompilingFunctionCaller.f`. + Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. + Returns the placeholder version of an argument to + :attr:`ParamStudyLazyPyOpenCLFunctionCaller.f`. Note this will modify the + shape of the placeholder to remove any parameter study axes until the trace + can be completed. """ if np.isscalar(arg): name = arg_id_to_name[(kw,)] + breakpoint() return pt.make_placeholder(name, (), np.dtype(type(arg))) elif isinstance(arg, pt.Array): name = arg_id_to_name[(kw,)] @@ -292,7 +298,7 @@ def pack_for_parameter_study(actx: ArrayContext, yourvarname: str, return out -def unpack_parameter_study(data: ArrayT, varname: str) -> Sequence[ArrayT]: +def unpack_parameter_study(data: ArrayT, varname: str) -> Mapping[int, ArrayT]: """ Split the data array along the axes which vary according to a ParameterStudyAxisTag whose variable name is varname. @@ -310,7 +316,11 @@ def unpack_parameter_study(data: ArrayT, varname: str) -> Sequence[ArrayT]: the_slice = [slice(None)] * ndim the_slice[i] = j the_slice = tuple(the_slice) - out[tuple([i,j])] = data[the_slice] + if i in out.keys(): + out[i].append(data[the_slice]) + else: + out[i] = [data[the_slice]] + #yield data[the_slice] return out diff --git a/examples/parameter_study.py b/examples/parameter_study.py index 4b1ed838..2d89797d 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -22,7 +22,7 @@ # Eq: z = x + y # Assumptions: x and y are independently uncertain. - +# Experimental setup base_shape = (15, 5) x = np.random.random(base_shape) x1 = np.random.random(base_shape) @@ -50,8 +50,6 @@ output_x = unpack_parameter_study(packx, "x") print(packx) -breakpoint() - def rhs(param1, param2): return param1 + param2 @@ -66,7 +64,9 @@ def rhs(param1, param2): output_x = unpack_parameter_study(output, "x") output_y = unpack_parameter_study(output, "y") -assert len(output_x) == 3 -assert output_x[0].shape == (4,15,5) -assert len(output_y) == 4 +assert len(output_x) == 1 # Number of parameter studies involving "x" +assert len(output_x[0]) == 3 # Number of inputs in the 0th parameter study +assert output_x[0][0].shape == (4,15,5) # All outputs across every other parameter study. +assert len(output_y) == 1 +assert len(output_y[0]) == 4 assert output_y[0].shape == (3,15,5) From 161445428b304bd4cd650bb01c01b1d8f2c78e1b Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 10 Jul 2024 14:28:41 -0500 Subject: [PATCH 06/25] Update interface. --- arraycontext/impl/pytato/__init__.py | 1 - arraycontext/parameter_study/__init__.py | 82 +++++++++--------------- 2 files changed, 30 insertions(+), 53 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 78caa5c6..5aefbeaa 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -10,7 +10,6 @@ Following :mod:`pytato`-based array context are provided: .. autoclass:: PytatoPyOpenCLArrayContext -.. autoclass:: PytatoPyOpenCLArrayContextUQ .. autoclass:: PytatoJAXArrayContext diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index 9775e6f5..dd25f3ac 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -58,20 +58,22 @@ ) import numpy as np +import pytato as pt from pytools import memoize_method from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag +from dataclasses import dataclass + from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import ArrayT, ArrayContext from arraycontext.metadata import NameHint from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace -from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller - - -from dataclasses import dataclass +from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller, + _get_arg_id_to_arg_and_arg_id_to_descr, + _to_input_for_compiled) if TYPE_CHECKING: import pyopencl as cl @@ -89,6 +91,11 @@ class ParameterStudyAxisTag(UniqueTag): """ A tag for acting on axes of arrays. + To enable multiple parameter studies on the same variable name + specify a different axis number and potentially a different size. + + Currently does not allow multiple variables of different names to be in + the same parameter study. """ user_variable_name: str axis_num: int @@ -109,11 +116,6 @@ class ParamStudyPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): .. automethod:: compile """ - def transform_dag(self, ary): - # This going to be called before the compiler or freeze. - out = super().transform_dag(ary) - return out - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: return ParamStudyLazyPyOpenCLFunctionCaller(self, f) @@ -131,12 +133,12 @@ class ParamStudyLazyPyOpenCLFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller def __call__(self, *args: Any, **kwargs: Any) -> Any: """ - Returns the result of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s + Returns the result of :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`'s function application on *args*. - Before applying :attr:`~BaseLazilyCompilingFunctionCaller.f`, it is compiled + Before applying :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`, it is compiled to a :mod:`pytato` DAG that would apply - :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. + :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f` with *args* in a lazy-sense. The intermediary pytato DAG for *args* is memoized in *self*. """ arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( @@ -147,6 +149,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: except KeyError: pass else: + # On a cache hit we do not need to modify anything. return compiled_f(arg_id_to_arg) dict_of_named_arrays = {} @@ -184,6 +187,11 @@ def _as_dict_of_named_arrays(keys, ary): rec_keyed_map_array_container(_as_dict_of_named_arrays, output_template) + myMapper = ExpansionMapper(dict_of_named_arrays) # Get the dependencies + dict_of_named_arrays = myMapper(dict_of_named_arrays) # Update the arrays. + + # Use the normal compiler now. + compiled_func = self._dag_to_compiled_func( pt.make_dict_of_named_arrays(dict_of_named_arrays), input_id_to_name_in_program=input_id_to_name_in_program, @@ -193,50 +201,17 @@ def _as_dict_of_named_arrays(keys, ary): self.program_cache[arg_id_to_descr] = compiled_func return compiled_func(arg_id_to_arg) - -def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): - """ - Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder` - in :meth:`LazilyCompilingFunctionCaller.__call__`. - - Preprocessing here refers to: - - - Metadata Inference that is supplied via *actx*\'s - :meth:`PytatoPyOpenCLArrayContext.transform_dag`. - """ - import pyopencl.array as cla - - from arraycontext.impl.pyopencl.taggable_cl_array import ( - TaggableCLArray, - to_tagged_cl_array, - ) - if isinstance(ary, pt.Array): - dag = pt.make_dict_of_named_arrays({"_actx_out": ary}) - # Transform the DAG to give metadata inference a chance to do its job - return actx.transform_dag(dag)["_actx_out"].expr - elif isinstance(ary, TaggableCLArray): - return ary - elif isinstance(ary, cla.Array): - from warnings import warn - warn("Passing pyopencl.array.Array to a compiled callable" - " is deprecated and will stop working in 2023." - " Use `to_tagged_cl_array` to convert the array to" - " TaggableCLArray", DeprecationWarning, stacklevel=2) - - return to_tagged_cl_array(ary, - axes=None, - tags=frozenset()) - else: - raise NotImplementedError(type(ary)) - - def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx): """ Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the placeholder version of an argument to - :attr:`ParamStudyLazyPyOpenCLFunctionCaller.f`. Note this will modify the - shape of the placeholder to remove any parameter study axes until the trace + :attr:`ParamStudyLazyPyOpenCLFunctionCaller.f`. + + Note this will modify the shape of the placeholder to + remove any parameter study axes until the trace can be completed. + + They will be added back after the trace is complete. """ if np.isscalar(arg): name = arg_id_to_name[(kw,)] @@ -298,10 +273,13 @@ def pack_for_parameter_study(actx: ArrayContext, yourvarname: str, return out -def unpack_parameter_study(data: ArrayT, varname: str) -> Mapping[int, ArrayT]: +def unpack_parameter_study(data: ArrayT, varname: str) -> Dict[int, Sequence[ArrayT]]: """ Split the data array along the axes which vary according to a ParameterStudyAxisTag whose variable name is varname. + + output[i] corresponds to the values associated with the ith parameter study that + uses the variable name :arg: `varname`. """ ndim = len(data.axes) From e99610bc33f0a8ad0b0c993f05b5980c76f46fae Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 15 Jul 2024 13:14:30 -0500 Subject: [PATCH 07/25] Another iteration on the mapper. --- arraycontext/__init__.py | 3 +- arraycontext/parameter_study/__init__.py | 218 +++---------- arraycontext/parameter_study/transform.py | 371 ++++++++++++++++++++++ examples/parameter_study.py | 94 +++--- 4 files changed, 460 insertions(+), 226 deletions(-) create mode 100644 arraycontext/parameter_study/transform.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 327fec66..f7855170 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -88,7 +88,8 @@ pytest_generate_tests_for_pyopencl_array_context, ) from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag -from .parameter_study import ParamStudyPytatoPyOpenCLArrayContext, pack_for_parameter_study, unpack_parameter_study +from .parameter_study import pack_for_parameter_study, unpack_parameter_study +from .parameter_study.transform import ParamStudyPytatoPyOpenCLArrayContext __all__ = ( "Array", diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index dd25f3ac..f1630674 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -4,19 +4,22 @@ A :mod:`pytato`-based array context defers the evaluation of an array until its frozen. The execution contexts for the evaluations are specific to an :class:`~arraycontext.ArrayContext` type. For ex. -:class:`~arraycontext.PytatoPyOpenCLArrayContext` uses :mod:`pyopencl` to -JIT-compile and execute the array expressions. +:class:`~arraycontext.ParamStudyPytatoPyOpenCLArrayContext` +uses :mod:`pyopencl` to JIT-compile and execute the array expressions. Following :mod:`pytato`-based array context are provided: .. autoclass:: ParamStudyPytatoPyOpenCLArrayContext + +The compiled function is stored as. .. autoclass:: ParamStudyLazyPyOpenCLFunctionCaller -Compiling a Python callable (Internal) +Compiling a Python callable (Internal) for multiple distinct instances of +execution. ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. automodule:: arraycontext.impl.pytato.compile +.. automodule:: arraycontext.parameter_study """ __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -55,6 +58,7 @@ Type, Union, Sequence, + List, ) import numpy as np @@ -65,15 +69,22 @@ from dataclasses import dataclass -from arraycontext.container.traversal import rec_map_array_container, with_array_context +from arraycontext.container.traversal import (rec_map_array_container, + with_array_context, rec_keyed_map_array_container) + +from arraycontext.container import ArrayContainer, is_array_container_type from arraycontext.context import ArrayT, ArrayContext from arraycontext.metadata import NameHint -from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext -from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace +from arraycontext import PytatoPyOpenCLArrayContext from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller, _get_arg_id_to_arg_and_arg_id_to_descr, - _to_input_for_compiled) + _to_input_for_compiled, + _ary_container_key_stringifier) + +from arraycontext.parameter_study.transform import ExpansionMapper, ParameterStudyAxisTag + +# from arraycontext.parameter_study.transform import ExpansionMapper if TYPE_CHECKING: import pyopencl as cl @@ -87,160 +98,8 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class ParameterStudyAxisTag(UniqueTag): - """ - A tag for acting on axes of arrays. - To enable multiple parameter studies on the same variable name - specify a different axis number and potentially a different size. - - Currently does not allow multiple variables of different names to be in - the same parameter study. - """ - user_variable_name: str - axis_num: int - axis_size: int - -# {{{ ParamStudyPytatoPyOpenCLArrayContext - - -class ParamStudyPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): - """ - A derived class for PytatoPyOpenCLArrayContext updated for the - purpose of enabling parameter studies and uncertainty quantification. - - .. automethod:: __init__ - - .. automethod:: transform_dag - .. automethod:: compile - """ - - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: - return ParamStudyLazyPyOpenCLFunctionCaller(self, f) - - -# }}} - - -class ParamStudyLazyPyOpenCLFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller): - """ - Record a side-effect-free callable :attr:`f` which is initially designed for - to be called multiple times with different data. This class will update the - signature to allow :attr:`f` to be called once with the data for multiple - instances. - """ - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - """ - Returns the result of :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`'s - function application on *args*. - - Before applying :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`, it is compiled - to a :mod:`pytato` DAG that would apply - :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f` with *args* in a lazy-sense. - The intermediary pytato DAG for *args* is memoized in *self*. - """ - arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( - args, kwargs) - - try: - compiled_f = self.program_cache[arg_id_to_descr] - except KeyError: - pass - else: - # On a cache hit we do not need to modify anything. - return compiled_f(arg_id_to_arg) - - dict_of_named_arrays = {} - output_id_to_name_in_program = {} - input_id_to_name_in_program = { - arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}" - for arg_id in arg_id_to_arg} - - output_template = self.f( - *[_get_f_placeholder_args_for_param_study(arg, iarg, - input_id_to_name_in_program, self.actx) - for iarg, arg in enumerate(args)], - **{kw: _get_f_placeholder_args_for_param_study(arg, kw, - input_id_to_name_in_program, - self.actx) - for kw, arg in kwargs.items()}) - - self.actx._compile_trace_callback(self.f, "post_trace", output_template) - - if (not (is_array_container_type(output_template.__class__) - or isinstance(output_template, pt.Array))): - # TODO: We could possibly just short-circuit this interface if the - # returned type is a scalar. Not sure if it's worth it though. - raise NotImplementedError( - f"Function '{self.f.__name__}' to be compiled " - "did not return an array container or pt.Array," - f" but an instance of '{output_template.__class__}' instead.") - - def _as_dict_of_named_arrays(keys, ary): - name = "_pt_out_" + _ary_container_key_stringifier(keys) - output_id_to_name_in_program[keys] = name - dict_of_named_arrays[name] = ary - return ary - - rec_keyed_map_array_container(_as_dict_of_named_arrays, - output_template) - - myMapper = ExpansionMapper(dict_of_named_arrays) # Get the dependencies - dict_of_named_arrays = myMapper(dict_of_named_arrays) # Update the arrays. - - # Use the normal compiler now. - - compiled_func = self._dag_to_compiled_func( - pt.make_dict_of_named_arrays(dict_of_named_arrays), - input_id_to_name_in_program=input_id_to_name_in_program, - output_id_to_name_in_program=output_id_to_name_in_program, - output_template=output_template) - - self.program_cache[arg_id_to_descr] = compiled_func - return compiled_func(arg_id_to_arg) - -def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx): - """ - Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. - Returns the placeholder version of an argument to - :attr:`ParamStudyLazyPyOpenCLFunctionCaller.f`. - - Note this will modify the shape of the placeholder to - remove any parameter study axes until the trace - can be completed. - - They will be added back after the trace is complete. - """ - if np.isscalar(arg): - name = arg_id_to_name[(kw,)] - breakpoint() - return pt.make_placeholder(name, (), np.dtype(type(arg))) - elif isinstance(arg, pt.Array): - name = arg_id_to_name[(kw,)] - # Transform the DAG to give metadata inference a chance to do its job - arg = _to_input_for_compiled(arg, actx) - return pt.make_placeholder(name, arg.shape, arg.dtype, - axes=arg.axes, - tags=arg.tags) - elif is_array_container_type(arg.__class__): - def _rec_to_placeholder(keys, ary): - name = arg_id_to_name[(kw, *keys)] - # Transform the DAG to give metadata inference a chance to do its job - ary = _to_input_for_compiled(ary, actx) - return pt.make_placeholder(name, - ary.shape, - ary.dtype, - axes=ary.axes, - tags=ary.tags) - - return rec_keyed_map_array_container(_rec_to_placeholder, arg) - else: - raise NotImplementedError(type(arg)) - - -def pack_for_parameter_study(actx: ArrayContext, yourvarname: str, +def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyAxisTag, newshape: Tuple[int, ...], *args: ArrayT) -> ArrayT: """ @@ -248,57 +107,54 @@ def pack_for_parameter_study(actx: ArrayContext, yourvarname: str, to be packed for a parameter study or uncertainty quantification. Args needs to be in the format - ["v", v0, v1, v2, ..., vN, "w", w0, w1, w2, ..., wM, \dots] - - where "v" and "w" would be the variable names in your program. - If you want to include a constant just pass the var name and then - the value in the next argument. - - Returns a dictionary of {var name: stacked array} + [v0, v1, v2, ..., vN] where N is the total number of instances you want to + try. Note these may be across multiple parameter studies on the same inputs. """ assert len(args) > 0 assert len(args) == np.prod(newshape) - out = {} orig_shape = args[0].shape out = actx.np.stack(args) - outshape = tuple([newshape] + [val for val in orig_shape]) + outshape = tuple([newshape] + list(orig_shape)) if len(newshape) > 1: # Reshape the object out = out.reshape(outshape) for i in range(len(newshape)): - out = out.with_tagged_axis(i, [ParameterStudyAxisTag(yourvarname, i, newshape[i])]) + out = out.with_tagged_axis(i, [study_name_tag(i, newshape[i])]) return out -def unpack_parameter_study(data: ArrayT, varname: str) -> Dict[int, Sequence[ArrayT]]: +def unpack_parameter_study(data: ArrayT, + study_name_tag: ParameterStudyAxisTag) -> Dict[int, + List[ArrayT]]: """ Split the data array along the axes which vary according to a ParameterStudyAxisTag - whose variable name is varname. + whose name tag is an instance study_name_tag. output[i] corresponds to the values associated with the ith parameter study that - uses the variable name :arg: `varname`. + uses the variable name :arg: `study_name_tag`. """ - ndim = len(data.axes) - out = {} + ndim: int = len(data.axes) + out: Dict[int, List[ArrayT]] = {} for i in range(ndim): - axis_tags = data.axes[i].tags_of_type(ParameterStudyAxisTag) + axis_tags = data.axes[i].tags_of_type(study_name_tag) if axis_tags: # Now we need to split this data. breakpoint() for j in range(data.shape[i]): - the_slice = [slice(None)] * ndim - the_slice[i] = j - the_slice = tuple(the_slice) + tmp: List[slice] = [slice(None)] * ndim + tmp[i] = j + the_slice: Tuple[slice] = tuple(tmp) + # Needs to be a tuple of slices not list of slices. if i in out.keys(): out[i].append(data[the_slice]) else: out[i] = [data[the_slice]] - #yield data[the_slice] + # yield data[the_slice] return out diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py new file mode 100644 index 00000000..d15f0568 --- /dev/null +++ b/arraycontext/parameter_study/transform.py @@ -0,0 +1,371 @@ +""" +.. currentmodule:: arraycontext + +Compiling a Python callable (Internal) for multiple distinct instances of +execution. +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. automodule:: arraycontext.parameter_study +""" +__copyright__ = """ +Copyright (C) 2020-1 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import abc +import sys +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + FrozenSet, + Optional, + Tuple, + Type, + Union, + Sequence, + List, + Mapping, +) + +import numpy as np +import pytato as pt +from immutabledict import immutabledict + +from pytools import memoize_method +from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag + +from dataclasses import dataclass + +from arraycontext.container.traversal import (rec_map_array_container, + with_array_context, rec_keyed_map_array_container) + +from arraycontext.container import ArrayContainer, is_array_container_type + +from arraycontext.context import ArrayT, ArrayContext +from arraycontext.metadata import NameHint +from arraycontext import PytatoPyOpenCLArrayContext +from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller, + _get_arg_id_to_arg_and_arg_id_to_descr, + _to_input_for_compiled, + _ary_container_key_stringifier) + + +from pytato.transform import CopyMapper + +from pytato.array import ( + Array, IndexLambda, Placeholder, Stack, Roll, AxisPermutation, + DataWrapper, SizeParam, DictOfNamedArrays, AbstractResultWithNamedArrays, + Reshape, Concatenate, NamedArray, IndexRemappingBase, Einsum, + InputArgumentBase, AdvancedIndexInNoncontiguousAxes, IndexBase, DataInterface, + Axis) + +from pytato.utils import broadcast_binary_op + +@dataclass(frozen=True) +class ParameterStudyAxisTag(UniqueTag): + """ + A tag for acting on axes of arrays. + To enable multiple parameter studies on the same variable name + specify a different axis number and potentially a different size. + + Currently does not allow multiple variables of different names to be in + the same parameter study. + """ + #user_param_study_tag: Tag + axis_num: int + axis_size: int + +class ExpansionMapper(CopyMapper): + + #def __init__(self, dependency_map: Dict[Array,Tag]): + # super().__init__() + # self.depends = dependency_map + def __init__(self, actual_input_shapes: Mapping[str, Tuple[int,...]], + actual_input_axes: Mapping[str, FrozenSet[Axis]]): + super().__init__() + self.actual_input_shapes = actual_input_shapes + self.actual_input_axes = actual_input_axes + + + def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: Array, + new_expr: Array) -> Tuple[Optional[Tuple[int]], + Optional[Tuple[Axis]]]: + shape_to_prepend: Tuple[int] = tuple([]) + new_axes: Tuple[Axis] = tuple([]) + if curr_expr.shape == new_expr.shape: + return shape_to_prepend, new_axes + + # Now we may need to change. + changed = False + for i in range(len(new_expr.axes)): + axis_tags = list(new_expr.axes[i].tags) + for j, tag in enumerate(axis_tags): + # Should be relatively few tags on each axis $O(1)$. + if isinstance(tag, ParameterStudyAxisTag): + new_axes = new_axes + (new_expr.axes[i],) + shape_to_prepend = shape_to_prepend + (new_expr.shape[i],) + return shape_to_prepend, new_axes + + + def map_stack(self, expr: Stack) -> Array: + return super().map_stack(expr) + + def map_concatenate(self, expr: Concatenate) -> Array: + return super().map_concatenate(expr) + + def map_roll(self, expr: Roll) -> Array: + new_array = self.rec(expr.array) + prepend_shape, new_axes =self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + new_array) + return Roll(array=new_array, + shift=expr.shift, + axis=expr.axis + len(new_axes), + axes=new_axes + expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_axis_permutation(self, expr: AxisPermutation) -> Array: + new_array = self.rec(expr.array) + prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + new_array) + breakpoint() + axis_permute = tuple([expr.axis_permutation[i] + len(prepend_shape) for i + in range(len(expr.axis_permutation))]) + # Include the axes we are adding to the system. + axis_permute = tuple([i for i in range(len(prepend_shape))]) + axis_permute + + + return AxisPermutation(array=new_array, + axis_permutation=axis_permute, + axes=new_axes + expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def _map_index_base(self, expr: IndexBase) -> Array: + new_array = self.rec(expr.array) + prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + new_array) + return type(expr)(new_array, + indices=self.rec_idx_or_size_tuple(expr.indices), + # May need to modify indices + axes=new_axes + expr.axes, + tags=expr.tags, + non_equality_tags = expr.non_equality_tags) + + def map_reshape(self, expr: Reshape) -> Array: + new_array = self.rec(expr.array) + prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + new_array) + return Reshape(new_array, + newshape = self.rec_idx_or_size_tuple(prepend_shape + expr.newshape), + order=expr.order, + axes=new_axes + expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_placeholder(self, expr: Placeholder) -> Array: + # This is where we could introduce extra axes. + breakpoint() + correct_shape = expr.shape + correct_axes = expr.axes + if expr.name in self.actual_input_shapes.keys(): + # We may need to update the size. + if expr.shape != self.actual_input_shapes[expr.name]: + correct_shape = self.actual_input_shapes[expr.name] + correct_axes = self.actual_input_axes[expr.name] + return Placeholder(name=expr.name, + shape=self.rec_idx_or_size_tuple(correct_shape), + dtype=expr.dtype, + axes=correct_axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + def map_index_lambda(self, expr: IndexLambda) -> Array: + # TODO: Fix + return super().map_index_lambda(expr) + +# {{{ ParamStudyPytatoPyOpenCLArrayContext + + +class ParamStudyPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): + """ + A derived class for PytatoPyOpenCLArrayContext updated for the + purpose of enabling parameter studies and uncertainty quantification. + + .. automethod:: __init__ + + .. automethod:: transform_dag + + .. automethod:: compile + """ + + def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + return ParamStudyLazyPyOpenCLFunctionCaller(self, f) + + +# }}} + + + +class ParamStudyLazyPyOpenCLFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller): + """ + Record a side-effect-free callable :attr:`f` which is initially designed for + to be called multiple times with different data. This class will update the + signature to allow :attr:`f` to be called once with the data for multiple + instances. + """ + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """ + Returns the result of :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`'s + function application on *args*. + + Before applying :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`, it is compiled + to a :mod:`pytato` DAG that would apply + :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f` with *args* in a lazy-sense. + The intermediary pytato DAG for *args* is memoized in *self*. + """ + arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( + args, kwargs) + + try: + compiled_f = self.program_cache[arg_id_to_descr] + except KeyError: + pass + else: + # On a cache hit we do not need to modify anything. + return compiled_f(arg_id_to_arg) + + dict_of_named_arrays = {} + output_id_to_name_in_program = {} + input_id_to_name_in_program = { + arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}" + for arg_id in arg_id_to_arg} + + output_template = self.f( + *[_get_f_placeholder_args_for_param_study(arg, iarg, + input_id_to_name_in_program, self.actx) + for iarg, arg in enumerate(args)], + **{kw: _get_f_placeholder_args_for_param_study(arg, kw, + input_id_to_name_in_program, + self.actx) + for kw, arg in kwargs.items()}) + + self.actx._compile_trace_callback(self.f, "post_trace", output_template) + + if (not (is_array_container_type(output_template.__class__) + or isinstance(output_template, pt.Array))): + # TODO: We could possibly just short-circuit this interface if the + # returned type is a scalar. Not sure if it's worth it though. + raise NotImplementedError( + f"Function '{self.f.__name__}' to be compiled " + "did not return an array container or pt.Array," + f" but an instance of '{output_template.__class__}' instead.") + + def _as_dict_of_named_arrays(keys, ary): + name = "_pt_out_" + _ary_container_key_stringifier(keys) + output_id_to_name_in_program[keys] = name + dict_of_named_arrays[name] = ary + return ary + + rec_keyed_map_array_container(_as_dict_of_named_arrays, + output_template) + + breakpoint() + input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys()} + input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} + myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies + dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) + + dict_of_named_arrays = myMapper(dict_of_named_arrays) # Update the arrays. + breakpoint() + + # Use the normal compiler now. + + compiled_func = self._dag_to_compiled_func( + pt.make_dict_of_named_arrays(dict_of_named_arrays), + input_id_to_name_in_program=input_id_to_name_in_program, + output_id_to_name_in_program=output_id_to_name_in_program, + output_template=output_template) + + self.program_cache[arg_id_to_descr] = compiled_func + return compiled_func(arg_id_to_arg) + + +def _cut_if_in_param_study(name, arg) -> Array: + """ + Helper to split a place holder into the base instance shape + if it is tagged with a `ParameterStudyAxisTag` + to ensure the survival of the information those tags will be converted + to temporary Array Tags of the same type. The placeholder will not + have the axes marked with a `ParameterStudyAxisTag` tag. + """ + ndim: int = len(arg.shape) + newshape = [] + update_tags: set = set() + update_axes = [] + for i in range(ndim): + axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) + if axis_tags: + # We need to remove those tags. + update_tags.add(axis_tags) + else: + update_axes.append(arg.axes[i]) + newshape.append(arg.shape[i]) + + update_tags.update(arg.tags) + update_axes = tuple(update_axes) + update_tags = list(update_tags)[0] # Return just the frozenset. + return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes, + tags=update_tags) + +def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx): + """ + Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. + Returns the placeholder version of an argument to + :attr:`ParamStudyLazyPyOpenCLFunctionCaller.f`. + + Note this will modify the shape of the placeholder to + remove any parameter study axes until the trace + can be completed. + + They will be added back after the trace is complete. + """ + if np.isscalar(arg): + name = arg_id_to_name[(kw,)] + return pt.make_placeholder(name, (), np.dtype(type(arg))) + elif isinstance(arg, pt.Array): + name = arg_id_to_name[(kw,)] + # Transform the DAG to give metadata inference a chance to do its job + arg = _to_input_for_compiled(arg, actx) + return _cut_if_in_param_study(name, arg) + elif is_array_container_type(arg.__class__): + def _rec_to_placeholder(keys, ary): + name = arg_id_to_name[(kw, *keys)] + # Transform the DAG to give metadata inference a chance to do its job + ary = _to_input_for_compiled(ary, actx) + return _cut_if_in_param_study(name, ary) + + return rec_keyed_map_array_container(_rec_to_placeholder, arg) + else: + raise NotImplementedError(type(arg)) diff --git a/examples/parameter_study.py b/examples/parameter_study.py index 2d89797d..cf790149 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -1,72 +1,78 @@ -import arraycontext from dataclasses import dataclass -import numpy as np # for the data types +import numpy as np # for the data types from pytools.tag import Tag -from arraycontext.impl.pytato.__init__ import (PytatoPyOpenCLArrayContext) - -# The goal of this file is to propagate the uncertainty in an array to the output. - +from arraycontext.parameter_study import (pack_for_parameter_study, + unpack_parameter_study) +from arraycontext.parameter_study.transform import (ParamStudyPytatoPyOpenCLArrayContext, + ParameterStudyAxisTag) -my_context = arraycontext.impl.pytato.PytatoPyOpenCLArrayContext -a = my_context.zeros(my_context, shape=(5,5), dtype=np.int32) + 2 +import pyopencl as cl +import pytato as pt -b = my_context.zeros(my_context, (5,5), np.int32) + 15 +ctx = cl.create_some_context(interactive=False) +queue = cl.CommandQueue(ctx) +actx = ParamStudyPytatoPyOpenCLArrayContext(queue) -print(a) -print("========================================================") -print(b) -print("========================================================") # Eq: z = x + y -# Assumptions: x and y are independently uncertain. +# Assumptions: x and y are undergoing independent parameter studies. +base_shape = (15, 5) +def rhs(param1, param2): + return pt.transpose(param1) + return pt.roll(param1, shift=3, axis=1) + return param1.reshape(np.prod(base_shape)) + return param1 + param2 + # Experimental setup -base_shape = (15, 5) -x = np.random.random(base_shape) -x1 = np.random.random(base_shape) -x2 = np.random.random(base_shape) +seed = 12345 +rng = np.random.default_rng(seed) +x = actx.from_numpy(rng.random(base_shape)) +x1 = actx.from_numpy(rng.random(base_shape)) +x2 = actx.from_numpy(rng.random(base_shape)) -y = np.random.random(base_shape) -y1 = np.random.random(base_shape) -y2 = np.random.random(base_shape) -y3 = np.random.random(base_shape) +y = actx.from_numpy(rng.random(base_shape)) +y1 = actx.from_numpy(rng.random(base_shape)) +y2 = actx.from_numpy(rng.random(base_shape)) +y3 = actx.from_numpy(rng.random(base_shape)) -from arraycontext.parameter_study import (pack_for_parameter_study, - ParamStudyPytatoPyOpenCLArrayContext, unpack_parameter_study) -import pyopencl as cl +@dataclass(frozen=True) +class ParameterStudyForX(ParameterStudyAxisTag): + pass -ctx = cl.create_some_context(interactive=False) -queue = cl.CommandQueue(ctx) -actx = ParamStudyPytatoPyOpenCLArrayContext(queue) +@dataclass(frozen=True) +class ParameterStudyForY(ParameterStudyAxisTag): + pass + # Pack a parameter study of 3 instances for both x and y. # We are assuming these are distinct parameter studies. -packx = pack_for_parameter_study(actx,"x",tuple([3]), x, x1, x2) -packy = pack_for_parameter_study(actx,"y",tuple([4]), y, y1, y2, y3) -output_x = unpack_parameter_study(packx, "x") +packx = pack_for_parameter_study(actx, ParameterStudyForX, (3,), x, x1, x2) +packy = pack_for_parameter_study(actx, ParameterStudyForY, (4,), y, y1, y2, y3) +output_x = unpack_parameter_study(packx, ParameterStudyForX) print(packx) -def rhs(param1, param2): - return param1 + param2 - -compiled_rhs = actx.compile(rhs) # Build the function caller +compiled_rhs = actx.compile(rhs) # Build the function caller -# Builds a trace for a single instance of evaluating the RHS and then converts it to -# a program which takes our multiple instances of `x` and `y`. +# Builds a trace for a single instance of evaluating the RHS and +# then converts it to a program which takes our multiple instances +# of `x` and `y`. +breakpoint() output = compiled_rhs(packx, packy) -assert output.shape == (3,4,15,5) # Distinct parameter studies. +assert output.shape == (3, 4, 15, 5) # Distinct parameter studies. -output_x = unpack_parameter_study(output, "x") -output_y = unpack_parameter_study(output, "y") -assert len(output_x) == 1 # Number of parameter studies involving "x" -assert len(output_x[0]) == 3 # Number of inputs in the 0th parameter study -assert output_x[0][0].shape == (4,15,5) # All outputs across every other parameter study. +output_x = unpack_parameter_study(output, ParameterStudyForX) +output_y = unpack_parameter_study(output, ParameterStudyForY) +assert len(output_x) == 1 # Number of parameter studies involving "x" +assert len(output_x[0]) == 3 # Number of inputs in the 0th parameter study +# All outputs across every other parameter study. +assert output_x[0][0].shape == (4, 15, 5) assert len(output_y) == 1 assert len(output_y[0]) == 4 -assert output_y[0].shape == (3,15,5) +assert output_y[0][0].shape == (3, 15, 5) From 95e74ef073b9b63360c0ea8b1cabcf38f014fad1 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 15 Jul 2024 17:46:29 -0500 Subject: [PATCH 08/25] Add some test cases and a start on the index lambda transform. --- arraycontext/parameter_study/transform.py | 48 ++++- examples/parameter_study.py | 1 + test/test_pytato_parameter_study.py | 215 ++++++++++++++++++++++ 3 files changed, 261 insertions(+), 3 deletions(-) create mode 100644 test/test_pytato_parameter_study.py diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index d15f0568..4ff17e23 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -186,7 +186,6 @@ def map_reshape(self, expr: Reshape) -> Array: def map_placeholder(self, expr: Placeholder) -> Array: # This is where we could introduce extra axes. - breakpoint() correct_shape = expr.shape correct_axes = expr.axes if expr.name in self.actual_input_shapes.keys(): @@ -202,6 +201,51 @@ def map_placeholder(self, expr: Placeholder) -> Array: non_equality_tags=expr.non_equality_tags) def map_index_lambda(self, expr: IndexLambda) -> Array: # TODO: Fix + # Update bindings first. + new_bindings: Mapping[str, Array] = { name: self.rec(bnd) + for name, bnd in sorted(expr.bindings.items())} + + # Determine the new parameter studies that are being conducted. + from pytools import unique + from pytools.obj_array import flat_obj_array + + all_axis_tags: Set[Tag] = set() + for name, bnd in sorted(new_bindings.items()): + axis_tags_for_bnd: Set[Tag] = set() + for i in range(len(bnd.axes)): + axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag)) + breakpoint() + all_axis_tags = all_axis_tags.union(axis_tags_for_bnd) + + # Freeze the set now. + all_axis_tags = frozenset(all_axis_tags) + + + breakpoint() + active_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) + axes: Optional[Tuple[Axis]] = tuple([]) + study_to_axis_number: Mapping[ParameterStudyAxisTag, int] = {} + + + count = 0 + new_shape: List[int] = [0 for i in + range(len(active_studies) + len(expr.shape))] + new_axes: List[Axis] = [Axis() for i in range(len(new_shape))] + + for study in active_studies: + if isinstance(study, ParameterStudyAxisTag): + # Just defensive programming + study_to_axis_number[type(study)] = count + new_shape[count] = study.axis_size # We are recording the size of each parameter study. + new_axes[count] = new_axes[count].tagged([study]) + count += 1 + + for i in range(len(expr.shape)): + new_shape[count] = expr.shape[i] + count += 1 + new_shape: Tuple[int] = tuple(new_shape) + + breakpoint() return super().map_index_lambda(expr) # {{{ ParamStudyPytatoPyOpenCLArrayContext @@ -291,14 +335,12 @@ def _as_dict_of_named_arrays(keys, ary): rec_keyed_map_array_container(_as_dict_of_named_arrays, output_template) - breakpoint() input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys()} input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) dict_of_named_arrays = myMapper(dict_of_named_arrays) # Update the arrays. - breakpoint() # Use the normal compiler now. diff --git a/examples/parameter_study.py b/examples/parameter_study.py index cf790149..ef9087bf 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -20,6 +20,7 @@ # Assumptions: x and y are undergoing independent parameter studies. base_shape = (15, 5) def rhs(param1, param2): + return param1 + param2 return pt.transpose(param1) return pt.roll(param1, shift=3, axis=1) return param1.reshape(np.prod(base_shape)) diff --git a/test/test_pytato_parameter_study.py b/test/test_pytato_parameter_study.py new file mode 100644 index 00000000..c4f9a4ed --- /dev/null +++ b/test/test_pytato_parameter_study.py @@ -0,0 +1,215 @@ +""" PytatoArrayContext specific tests on the Parameter Study Module""" + +__copyright__ = "Copyright (C) 2021 University of Illinois Board of Trustees" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import logging + +import pytest + +from pytools.tag import Tag + +from arraycontext import ( + PytatoPyOpenCLArrayContext, + pytest_generate_tests_for_array_contexts, +) +from arraycontext.parameter_study.transform import ( + ParamStudyPytatoPyOpenCLArrayContext, + ParameterStudyAxisTag +) +from arraycontext.parameter_study import ( + pack_for_parameter_study, + unpack_parameter_study, +) + +from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory + + +logger = logging.getLogger(__name__) + + +# {{{ pytato-array context fixture + +class _PytatoPyOpenCLArrayContextForTests(ParamStudyPytatoPyOpenCLArrayContext): + """Like :class:`PytatoPyOpenCLArrayContext`, but applies no program + transformations whatsoever. Only to be used for testing internal to + :mod:`arraycontext`. + """ + + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytatoPyOpenCLArrayContextForTestsFactory( + _PytestPytatoPyOpenCLArrayContextFactory): + actx_class = _PytatoPyOpenCLArrayContextForTests + + +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + _PytatoPyOpenCLArrayContextForTestsFactory, + ]) + +# }}} + + +# {{{ dummy tag types + +class FooTag(Tag): + """ + Foo + """ + + +class BarTag(Tag): + """ + Bar + """ + + +class BazTag(Tag): + """ + Baz + """ + +class ParamStudy1(ParameterStudyAxisTag): + """ + 1st parameter study. + """ + +class ParamStudy2(ParameterStudyAxisTag): + """ + 2bd parameter study. + """ +# }}} + + +# {{{ Expansion Mapper specific tests. + +def test_pack_for_parameter_study(actx_factory): + + actx = actx_factory() + + from arraycontext.impl.pytato import _BasePytatoArrayContext + if not isinstance(actx, ParamStudyPytatoPyOpenCLArrayContext): + pytest.skip("only parameter study array contexts are supported") + + import numpy as np + seed = 12345 + rng = np.random.default_rng(seed) + + base_shape = (15, 5) + x0 = actx.from_numpy(rng.random(base_shape)) + x1 = actx.from_numpy(rng.random(base_shape)) + x2 = actx.from_numpy(rng.random(base_shape)) + x3 = actx.from_numpy(rng.random(base_shape)) + + + y0 = actx.from_numpy(rng.random(base_shape)) + y1 = actx.from_numpy(rng.random(base_shape)) + y2 = actx.from_numpy(rng.random(base_shape)) + y3 = actx.from_numpy(rng.random(base_shape)) + y4 = actx.from_numpy(rng.random(base_shape)) + + def rhs(a,b): + return a + b + + pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) + assert pack_x.shape == (4,15,5) + + pack_y = pack_for_parameter_study(actx, ParamStudy2, (5,), y0,y1, y2,y3,y4) + assert pack_y.shape == (5,15,5) + + for i in range(3): + axis_tags = pack_x.axes[i].tags_of_type(ParamStudy1) + second_tags = pack_x.axes[i].tags_of_type(ParamStudy2) + if i == 0: + assert axis_tags + else: + assert not axis_tags + assert not second_tags + +def test_unpack_parameter_study(actx_factory): + + actx = actx_factory() + + from arraycontext.impl.pytato import _BasePytatoArrayContext + if not isinstance(actx, ParamStudyPytatoPyOpenCLArrayContext): + pytest.skip("only parameter study array contexts are supported") + + import numpy as np + seed = 12345 + rng = np.random.default_rng(seed) + + base_shape = (15, 5) + x0 = actx.from_numpy(rng.random(base_shape)) + x1 = actx.from_numpy(rng.random(base_shape)) + x2 = actx.from_numpy(rng.random(base_shape)) + x3 = actx.from_numpy(rng.random(base_shape)) + + + y0 = actx.from_numpy(rng.random(base_shape)) + y1 = actx.from_numpy(rng.random(base_shape)) + y2 = actx.from_numpy(rng.random(base_shape)) + y3 = actx.from_numpy(rng.random(base_shape)) + y4 = actx.from_numpy(rng.random(base_shape)) + + def rhs(a,b): + return a + b + + pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) + assert pack_x.shape == (4,15,5) + + pack_y = pack_for_parameter_study(actx, ParamStudy2, (5,), y0,y1, y2,y3,y4) + assert pack_y.shape == (5,15,5) + + compiled_rhs = actx.compile(rhs) + + output = compiled_rhs(pack_x, pack_y) + + assert output.shape(4,5,15,5) + + output_x = unpack_parameter_study(output, ParamStudy1) + assert len(output_x) == 1 # Only 1 study associated with this variable. + assert len(output_x[0]) == 4 # 4 inputs for the parameter study. + for i in range(len(output_x[0])): + assert output_x[0][i].shape == (5, 15, 5) + + + output_y = unpack_parameter_study(output, ParamStudy2) + assert len(output_y) == 1 # Only 1 study associated with this variable. + assert len(output_y[0]) == 5 # 5 inputs for the parameter study. + for i in range(len(output_y[0])): + assert output_y[0][i].shape == (4, 15, 5) + + +# }}} + + + +if __name__ == "__main__": + import sys + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + pytest.main([__file__]) + +# vim: fdm=marker From 804ed42cd1a69558bb6c0c2ae0459e3a8d7c18d8 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 16 Jul 2024 16:26:17 -0500 Subject: [PATCH 09/25] Update the Expansion Mapper for index lambda and move the packer to pack the items in the later axes as opposed to prepending the new axes. --- arraycontext/impl/pytato/__init__.py | 5 +- arraycontext/impl/pytato/compile.py | 5 +- arraycontext/parameter_study/__init__.py | 27 +++--- arraycontext/parameter_study/transform.py | 100 +++++++++++++++------- examples/parameter_study.py | 35 +++----- 5 files changed, 105 insertions(+), 67 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 5aefbeaa..fac98e8f 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -563,7 +563,10 @@ def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: evt, out_dict = pt_prg(self.queue, allocator=self.allocator, **bound_arguments) - evt.wait() + if isinstance(evt, list): + [_evt.wait() for _evt in evt] + else: + evt.wait() assert len(set(out_dict) & set(key_to_frozen_subary)) == 0 key_to_frozen_subary = { diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 523ef3a4..4d3a14d8 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -646,7 +646,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: # FIXME Kernels (for now) allocate tons of memory in temporaries. If we # race too far ahead with enqueuing, there is a distinct risk of # running out of memory. This mitigates that risk a bit, for now. - evt.wait() + if isinstance(evt, list): + [_evt.wait() for _evt in evt] + else: + evt.wait() def to_output_template(keys, _): name_in_program = self.output_id_to_name_in_program[keys] diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index f1630674..46dcfdf4 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -115,14 +115,15 @@ def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyA assert len(args) == np.prod(newshape) orig_shape = args[0].shape - out = actx.np.stack(args) - outshape = tuple([newshape] + list(orig_shape)) - - if len(newshape) > 1: - # Reshape the object - out = out.reshape(outshape) - for i in range(len(newshape)): - out = out.with_tagged_axis(i, [study_name_tag(i, newshape[i])]) + out = actx.np.stack(args, axis=args[0].ndim) + outshape = tuple(list(orig_shape) + [newshape] ) + + #if len(newshape) > 1: + # # Reshape the object + # out = out.reshape(outshape) + + for i in range(len(orig_shape), len(outshape)): + out = out.with_tagged_axis(i, [study_name_tag(i - len(orig_shape), newshape[i-len(orig_shape)])]) return out @@ -140,6 +141,7 @@ def unpack_parameter_study(data: ArrayT, ndim: int = len(data.axes) out: Dict[int, List[ArrayT]] = {} + study_count = 0 for i in range(ndim): axis_tags = data.axes[i].tags_of_type(study_name_tag) if axis_tags: @@ -150,11 +152,12 @@ def unpack_parameter_study(data: ArrayT, tmp[i] = j the_slice: Tuple[slice] = tuple(tmp) # Needs to be a tuple of slices not list of slices. - if i in out.keys(): - out[i].append(data[the_slice]) + if study_count in out.keys(): + out[study_count].append(data[the_slice]) else: - out[i] = [data[the_slice]] - + out[study_count] = [data[the_slice]] + if study_count in out.keys(): + study_count += 1 # yield data[the_slice] return out diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index 4ff17e23..b95d08e4 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -50,8 +50,14 @@ import numpy as np import pytato as pt +import loopy as lp from immutabledict import immutabledict + +from pytato.scalar_expr import IdentityMapper +import pymbolic.primitives as prim + + from pytools import memoize_method from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag @@ -129,6 +135,7 @@ def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: A def map_stack(self, expr: Stack) -> Array: + # TODO: Fix return super().map_stack(expr) def map_concatenate(self, expr: Concatenate) -> Array: @@ -149,7 +156,6 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array: new_array = self.rec(expr.array) prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) - breakpoint() axis_permute = tuple([expr.axis_permutation[i] + len(prepend_shape) for i in range(len(expr.axis_permutation))]) # Include the axes we are adding to the system. @@ -199,8 +205,8 @@ def map_placeholder(self, expr: Placeholder) -> Array: axes=correct_axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_index_lambda(self, expr: IndexLambda) -> Array: - # TODO: Fix # Update bindings first. new_bindings: Mapping[str, Array] = { name: self.rec(bnd) for name, bnd in sorted(expr.bindings.items())} @@ -210,43 +216,77 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: from pytools.obj_array import flat_obj_array all_axis_tags: Set[Tag] = set() + studies_by_variable: Mapping[str, Mapping[Tag, bool]] = {} for name, bnd in sorted(new_bindings.items()): axis_tags_for_bnd: Set[Tag] = set() + studies_by_variable[name] = {} for i in range(len(bnd.axes)): axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag)) - breakpoint() + for tag in axis_tags_for_bnd: + studies_by_variable[name][tag] = 1 all_axis_tags = all_axis_tags.union(axis_tags_for_bnd) # Freeze the set now. all_axis_tags = frozenset(all_axis_tags) - - breakpoint() active_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) axes: Optional[Tuple[Axis]] = tuple([]) study_to_axis_number: Mapping[ParameterStudyAxisTag, int] = {} - count = 0 - new_shape: List[int] = [0 for i in - range(len(active_studies) + len(expr.shape))] - new_axes: List[Axis] = [Axis() for i in range(len(new_shape))] + new_shape = expr.shape + new_axes = expr.axes for study in active_studies: if isinstance(study, ParameterStudyAxisTag): # Just defensive programming - study_to_axis_number[type(study)] = count - new_shape[count] = study.axis_size # We are recording the size of each parameter study. - new_axes[count] = new_axes[count].tagged([study]) - count += 1 + # The active studies are added to the end of the bindings. + study_to_axis_number[study] = len(new_shape) + new_shape = new_shape + (study.axis_size,) + new_axes = new_axes + (Axis(tags=frozenset((study,))),) + # This assumes that the axis only has 1 tag, + # because there should be no dependence across instances. + + # Now we need to update the expressions. + scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, study_to_axis_number) + + return IndexLambda(expr=scalar_expr, + bindings=type(expr.bindings)(new_bindings), + shape=new_shape, + var_to_reduction_descr=expr.var_to_reduction_descr, + dtype=expr.dtype, + axes=new_axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) - for i in range(len(expr.shape)): - new_shape[count] = expr.shape[i] - count += 1 - new_shape: Tuple[int] = tuple(new_shape) - breakpoint() - return super().map_index_lambda(expr) +class ParamAxisExpander(IdentityMapper): + + def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str, Mapping[ParameterStudyAxisTag, bool]], + study_to_axis_number: Mapping[ParameterStudyAxisTag, int]): + # We know that we are not changing the variable that we are indexing into. + # This is stored in the aggregate member of the class Subscript. + + # We only need to modify the indexing which is stored in the index member. + name = expr.aggregate.name + if name in studies_by_variable.keys(): + # These are the single instance information. + index = self.rec(expr.index, studies_by_variable, + study_to_axis_number) + + new_vars: Tuple[prim.Variable] = tuple([]) + + for key, val in sorted(study_to_axis_number.items(), key=lambda item: item[1]): + if key in studies_by_variable[name]: + new_vars = new_vars + (prim.Variable(f"_{study_to_axis_number[key]}"),) + + if isinstance(index, tuple): + index = index + new_vars + else: + index = tuple(index) + new_vars + return type(expr)(aggregate=expr.aggregate, index=index) + return expr + # {{{ ParamStudyPytatoPyOpenCLArrayContext @@ -267,6 +307,10 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: return ParamStudyLazyPyOpenCLFunctionCaller(self, f) + def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit: + # Update in a subclass if you want. + return t_unit + # }}} @@ -338,18 +382,21 @@ def _as_dict_of_named_arrays(keys, ary): input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys()} input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies + breakpoint() + dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) + breakpoint() dict_of_named_arrays = myMapper(dict_of_named_arrays) # Update the arrays. # Use the normal compiler now. - - compiled_func = self._dag_to_compiled_func( - pt.make_dict_of_named_arrays(dict_of_named_arrays), + + compiled_func = self._dag_to_compiled_func(dict_of_named_arrays, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, output_template=output_template) + breakpoint() self.program_cache[arg_id_to_descr] = compiled_func return compiled_func(arg_id_to_arg) @@ -364,20 +411,15 @@ def _cut_if_in_param_study(name, arg) -> Array: """ ndim: int = len(arg.shape) newshape = [] - update_tags: set = set() update_axes = [] for i in range(ndim): axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) - if axis_tags: - # We need to remove those tags. - update_tags.add(axis_tags) - else: + if not axis_tags: update_axes.append(arg.axes[i]) newshape.append(arg.shape[i]) - update_tags.update(arg.tags) update_axes = tuple(update_axes) - update_tags = list(update_tags)[0] # Return just the frozenset. + update_tags: FrozenSet[Tag] = arg.tags return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes, tags=update_tags) diff --git a/examples/parameter_study.py b/examples/parameter_study.py index ef9087bf..78adc2c5 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -15,19 +15,8 @@ queue = cl.CommandQueue(ctx) actx = ParamStudyPytatoPyOpenCLArrayContext(queue) - -# Eq: z = x + y -# Assumptions: x and y are undergoing independent parameter studies. -base_shape = (15, 5) -def rhs(param1, param2): - return param1 + param2 - return pt.transpose(param1) - return pt.roll(param1, shift=3, axis=1) - return param1.reshape(np.prod(base_shape)) - return param1 + param2 - - # Experimental setup +base_shape = (15, 5) seed = 12345 rng = np.random.default_rng(seed) x = actx.from_numpy(rng.random(base_shape)) @@ -39,41 +28,39 @@ def rhs(param1, param2): y2 = actx.from_numpy(rng.random(base_shape)) y3 = actx.from_numpy(rng.random(base_shape)) +# Eq: z = x + y +# Assumptions: x and y are undergoing independent parameter studies. +def rhs(param1, param2): + return param1 + param2 @dataclass(frozen=True) class ParameterStudyForX(ParameterStudyAxisTag): pass - @dataclass(frozen=True) class ParameterStudyForY(ParameterStudyAxisTag): pass - # Pack a parameter study of 3 instances for both x and y. -# We are assuming these are distinct parameter studies. packx = pack_for_parameter_study(actx, ParameterStudyForX, (3,), x, x1, x2) packy = pack_for_parameter_study(actx, ParameterStudyForY, (4,), y, y1, y2, y3) -output_x = unpack_parameter_study(packx, ParameterStudyForX) - -print(packx) compiled_rhs = actx.compile(rhs) # Build the function caller # Builds a trace for a single instance of evaluating the RHS and -# then converts it to a program which takes our multiple instances -# of `x` and `y`. -breakpoint() +# then converts it to a program which takes our multiple instances of `x` and `y`. output = compiled_rhs(packx, packy) +breakpoint() +output_2 = compiled_rhs(x,y) -assert output.shape == (3, 4, 15, 5) # Distinct parameter studies. +assert output.shape == (15, 5, 3, 4) # Distinct parameter studies. output_x = unpack_parameter_study(output, ParameterStudyForX) output_y = unpack_parameter_study(output, ParameterStudyForY) assert len(output_x) == 1 # Number of parameter studies involving "x" assert len(output_x[0]) == 3 # Number of inputs in the 0th parameter study # All outputs across every other parameter study. -assert output_x[0][0].shape == (4, 15, 5) +assert output_x[0][0].shape == (15, 5, 4) assert len(output_y) == 1 assert len(output_y[0]) == 4 -assert output_y[0][0].shape == (3, 15, 5) +assert output_y[0][0].shape == (15, 5, 3) From 607db791fc7a036ecd1979f40a6a353952e89b0e Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 16 Jul 2024 18:25:01 -0500 Subject: [PATCH 10/25] Correct most of the type annotations. --- arraycontext/parameter_study/__init__.py | 35 ++++---- arraycontext/parameter_study/transform.py | 97 +++++++++++++---------- 2 files changed, 75 insertions(+), 57 deletions(-) diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index 46dcfdf4..cbe68cc6 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -59,11 +59,15 @@ Union, Sequence, List, + Iterable, + Mapping, ) import numpy as np import pytato as pt +from pytato.array import Array + from pytools import memoize_method from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag @@ -74,7 +78,7 @@ from arraycontext.container import ArrayContainer, is_array_container_type -from arraycontext.context import ArrayT, ArrayContext +from arraycontext.context import ArrayContext from arraycontext.metadata import NameHint from arraycontext import PytatoPyOpenCLArrayContext from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller, @@ -99,9 +103,10 @@ logger = logging.getLogger(__name__) -def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyAxisTag, +def pack_for_parameter_study(actx: ArrayContext, + study_name_tag_type: Type[ParameterStudyAxisTag], newshape: Tuple[int, ...], - *args: ArrayT) -> ArrayT: + *args: Array) -> Array: """ Args is a list of variable names and the realized input data that needs to be packed for a parameter study or uncertainty quantification. @@ -115,7 +120,7 @@ def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyA assert len(args) == np.prod(newshape) orig_shape = args[0].shape - out = actx.np.stack(args, axis=args[0].ndim) + out = actx.np.stack(args, axis=len(args[0].shape)) outshape = tuple(list(orig_shape) + [newshape] ) #if len(newshape) > 1: @@ -123,34 +128,34 @@ def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyA # out = out.reshape(outshape) for i in range(len(orig_shape), len(outshape)): - out = out.with_tagged_axis(i, [study_name_tag(i - len(orig_shape), newshape[i-len(orig_shape)])]) + out = out.with_tagged_axis(i, [study_name_tag_type(i - len(orig_shape), newshape[i-len(orig_shape)])]) return out -def unpack_parameter_study(data: ArrayT, - study_name_tag: ParameterStudyAxisTag) -> Dict[int, - List[ArrayT]]: +def unpack_parameter_study(data: Array, + study_name_tag_type: Type[ParameterStudyAxisTag]) -> Mapping[int, + List[Array]]: """ Split the data array along the axes which vary according to a ParameterStudyAxisTag - whose name tag is an instance study_name_tag. + whose name tag is an instance study_name_tag_type. output[i] corresponds to the values associated with the ith parameter study that - uses the variable name :arg: `study_name_tag`. + uses the variable name :arg: `study_name_tag_type`. """ - ndim: int = len(data.axes) - out: Dict[int, List[ArrayT]] = {} + ndim: int = len(data.shape) + out: Dict[int, List[Array]] = {} study_count = 0 for i in range(ndim): - axis_tags = data.axes[i].tags_of_type(study_name_tag) + axis_tags = data.axes[i].tags_of_type(study_name_tag_type) if axis_tags: # Now we need to split this data. breakpoint() for j in range(data.shape[i]): - tmp: List[slice] = [slice(None)] * ndim + tmp: List[Any] = [slice(None)] * ndim tmp[i] = j - the_slice: Tuple[slice] = tuple(tmp) + the_slice = tuple(tmp) # Needs to be a tuple of slices not list of slices. if study_count in out.keys(): out[study_count].append(data[the_slice]) diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index b95d08e4..aeef7245 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -1,3 +1,4 @@ +from __future__ import annotations """ .. currentmodule:: arraycontext @@ -46,6 +47,7 @@ Sequence, List, Mapping, + Set, ) import numpy as np @@ -84,7 +86,7 @@ DataWrapper, SizeParam, DictOfNamedArrays, AbstractResultWithNamedArrays, Reshape, Concatenate, NamedArray, IndexRemappingBase, Einsum, InputArgumentBase, AdvancedIndexInNoncontiguousAxes, IndexBase, DataInterface, - Axis) + Axis, ShapeType) from pytato.utils import broadcast_binary_op @@ -115,23 +117,30 @@ def __init__(self, actual_input_shapes: Mapping[str, Tuple[int,...]], def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: Array, - new_expr: Array) -> Tuple[Optional[Tuple[int]], - Optional[Tuple[Axis]]]: - shape_to_prepend: Tuple[int] = tuple([]) - new_axes: Tuple[Axis] = tuple([]) + new_expr: Array) -> Tuple[ShapeType, + Tuple[Axis,...]]: + # Initialize with something for the typing. + shape_to_append: ShapeType= (-1,) + new_axes: Tuple[Axis,...] = (Axis(tags=frozenset()),) if curr_expr.shape == new_expr.shape: - return shape_to_prepend, new_axes + return shape_to_append, new_axes # Now we may need to change. changed = False for i in range(len(new_expr.axes)): axis_tags = list(new_expr.axes[i].tags) + already_added = False for j, tag in enumerate(axis_tags): # Should be relatively few tags on each axis $O(1)$. if isinstance(tag, ParameterStudyAxisTag): new_axes = new_axes + (new_expr.axes[i],) - shape_to_prepend = shape_to_prepend + (new_expr.shape[i],) - return shape_to_prepend, new_axes + shape_to_append = shape_to_append + (new_expr.shape[i],) + if already_added: + raise ValueError("An individual axis may only be tagged with one ParameterStudyAxisTag.") + already_added = True + + # Remove initialized extraneous data + return shape_to_append[1:], new_axes[1:] def map_stack(self, expr: Stack) -> Array: @@ -143,50 +152,49 @@ def map_concatenate(self, expr: Concatenate) -> Array: def map_roll(self, expr: Roll) -> Array: new_array = self.rec(expr.array) - prepend_shape, new_axes =self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + _, new_axes =self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) return Roll(array=new_array, shift=expr.shift, - axis=expr.axis + len(new_axes), - axes=new_axes + expr.axes, + axis=expr.axis, + axes=expr.axes + new_axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> Array: new_array = self.rec(expr.array) - prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) - axis_permute = tuple([expr.axis_permutation[i] + len(prepend_shape) for i - in range(len(expr.axis_permutation))]) # Include the axes we are adding to the system. - axis_permute = tuple([i for i in range(len(prepend_shape))]) + axis_permute + axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation) + for i in range(len(postpend_shape))]) return AxisPermutation(array=new_array, axis_permutation=axis_permute, - axes=new_axes + expr.axes, + axes=expr.axes + new_axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: new_array = self.rec(expr.array) - prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) return type(expr)(new_array, indices=self.rec_idx_or_size_tuple(expr.indices), # May need to modify indices - axes=new_axes + expr.axes, + axes=expr.axes + new_axes, tags=expr.tags, non_equality_tags = expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: new_array = self.rec(expr.array) - prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) return Reshape(new_array, - newshape = self.rec_idx_or_size_tuple(prepend_shape + expr.newshape), + newshape = self.rec_idx_or_size_tuple(expr.newshape + postpend_shape), order=expr.order, - axes=new_axes + expr.axes, + axes=expr.axes + new_axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -198,7 +206,7 @@ def map_placeholder(self, expr: Placeholder) -> Array: # We may need to update the size. if expr.shape != self.actual_input_shapes[expr.name]: correct_shape = self.actual_input_shapes[expr.name] - correct_axes = self.actual_input_axes[expr.name] + correct_axes = tuple(self.actual_input_axes[expr.name]) return Placeholder(name=expr.name, shape=self.rec_idx_or_size_tuple(correct_shape), dtype=expr.dtype, @@ -208,30 +216,30 @@ def map_placeholder(self, expr: Placeholder) -> Array: def map_index_lambda(self, expr: IndexLambda) -> Array: # Update bindings first. - new_bindings: Mapping[str, Array] = { name: self.rec(bnd) + new_bindings: Dict[str, Array] = { name: self.rec(bnd) for name, bnd in sorted(expr.bindings.items())} # Determine the new parameter studies that are being conducted. from pytools import unique from pytools.obj_array import flat_obj_array - all_axis_tags: Set[Tag] = set() - studies_by_variable: Mapping[str, Mapping[Tag, bool]] = {} + all_axis_tags: Tuple[ParameterStudyAxisTag,...] = () + studies_by_variable: Dict[str, Dict[UniqueTag, bool]] = {} for name, bnd in sorted(new_bindings.items()): axis_tags_for_bnd: Set[Tag] = set() studies_by_variable[name] = {} for i in range(len(bnd.axes)): axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag)) for tag in axis_tags_for_bnd: - studies_by_variable[name][tag] = 1 - all_axis_tags = all_axis_tags.union(axis_tags_for_bnd) + if isinstance(tag, ParameterStudyAxisTag): + # Defense + studies_by_variable[name][tag] = True + all_axis_tags = all_axis_tags + (tag,) - # Freeze the set now. - all_axis_tags = frozenset(all_axis_tags) active_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) - axes: Optional[Tuple[Axis]] = tuple([]) - study_to_axis_number: Mapping[ParameterStudyAxisTag, int] = {} + axes: Tuple[Axis, ...] = () + study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} count = 0 new_shape = expr.shape @@ -251,7 +259,7 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, study_to_axis_number) return IndexLambda(expr=scalar_expr, - bindings=type(expr.bindings)(new_bindings), + bindings=immutabledict(new_bindings), shape=new_shape, var_to_reduction_descr=expr.var_to_reduction_descr, dtype=expr.dtype, @@ -274,7 +282,7 @@ def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str, index = self.rec(expr.index, studies_by_variable, study_to_axis_number) - new_vars: Tuple[prim.Variable] = tuple([]) + new_vars: Tuple[prim.Variable, ...] = () for key, val in sorted(study_to_axis_number.items(), key=lambda item: item[1]): if key in studies_by_variable[name]: @@ -379,19 +387,25 @@ def _as_dict_of_named_arrays(keys, ary): rec_keyed_map_array_container(_as_dict_of_named_arrays, output_template) - input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys()} - input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} + #input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys() if hasattr(arg_id_to_descr[i], "shape")} + #input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} + input_shapes = {} + input_axes = {} + for i in arg_id_to_descr.keys(): + if hasattr(arg_id_to_descr[i], "shape"): + input_shapes[input_id_to_name_in_program[i]] = arg_id_to_descr[i].shape + input_axes[input_id_to_name_in_program[i]] = arg_id_to_arg[i].axes myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies breakpoint() - dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) breakpoint() - dict_of_named_arrays = myMapper(dict_of_named_arrays) # Update the arrays. # Use the normal compiler now. - compiled_func = self._dag_to_compiled_func(dict_of_named_arrays, + compiled_func = self._dag_to_compiled_func(myMapper(pt_dict_of_named_arrays), # Update the arrays + #pt_dict_of_named_arrays, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, output_template=output_template) @@ -411,14 +425,13 @@ def _cut_if_in_param_study(name, arg) -> Array: """ ndim: int = len(arg.shape) newshape = [] - update_axes = [] + update_axes: Tuple[Axis, ...] = (Axis(tags=frozenset()),) for i in range(ndim): axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) if not axis_tags: - update_axes.append(arg.axes[i]) + update_axes = update_axes + (arg.axes[i],) newshape.append(arg.shape[i]) - - update_axes = tuple(update_axes) + update_axes = update_axes[1:] # remove the first one that was placed there for typing. update_tags: FrozenSet[Tag] = arg.tags return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes, tags=update_tags) From d69cc00a37e0020a67505756587afc35c991cad8 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 17 Jul 2024 16:00:49 -0500 Subject: [PATCH 11/25] Mypy update. --- arraycontext/parameter_study/__init__.py | 3 +-- arraycontext/parameter_study/transform.py | 14 +++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index cbe68cc6..46db6770 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -45,7 +45,6 @@ THE SOFTWARE. """ -import abc import sys from typing import ( TYPE_CHECKING, @@ -69,7 +68,7 @@ from pytato.array import Array from pytools import memoize_method -from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag +from pytools.tag import UniqueTag from dataclasses import dataclass diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index aeef7245..3fbc1080 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -71,12 +71,12 @@ from arraycontext.container import ArrayContainer, is_array_container_type from arraycontext.context import ArrayT, ArrayContext -from arraycontext.metadata import NameHint from arraycontext import PytatoPyOpenCLArrayContext from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller, _get_arg_id_to_arg_and_arg_id_to_descr, _to_input_for_compiled, - _ary_container_key_stringifier) + _ary_container_key_stringifier, + LeafArrayDescriptor,) from pytato.transform import CopyMapper @@ -109,7 +109,7 @@ class ExpansionMapper(CopyMapper): #def __init__(self, dependency_map: Dict[Array,Tag]): # super().__init__() # self.depends = dependency_map - def __init__(self, actual_input_shapes: Mapping[str, Tuple[int,...]], + def __init__(self, actual_input_shapes: Mapping[str,ShapeType], actual_input_axes: Mapping[str, FrozenSet[Axis]]): super().__init__() self.actual_input_shapes = actual_input_shapes @@ -391,10 +391,10 @@ def _as_dict_of_named_arrays(keys, ary): #input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} input_shapes = {} input_axes = {} - for i in arg_id_to_descr.keys(): - if hasattr(arg_id_to_descr[i], "shape"): - input_shapes[input_id_to_name_in_program[i]] = arg_id_to_descr[i].shape - input_axes[input_id_to_name_in_program[i]] = arg_id_to_arg[i].axes + for key,val in arg_id_to_descr.items(): + if isinstance(val, LeafArrayDescriptor): + input_shapes[input_id_to_name_in_program[key]] = val.shape + input_axes[input_id_to_name_in_program[key]] = arg_id_to_arg[key].axes myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies breakpoint() From 904adfc8d690fb1b4d27a45d64c67fb726068917 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 18 Jul 2024 14:04:13 -0500 Subject: [PATCH 12/25] Add in the mapper for the stack operation. --- arraycontext/parameter_study/transform.py | 229 ++++++++++++++-------- examples/parameter_study.py | 29 ++- 2 files changed, 167 insertions(+), 91 deletions(-) diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index 3fbc1080..a95b541a 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -1,4 +1,6 @@ from __future__ import annotations + + """ .. currentmodule:: arraycontext @@ -32,63 +34,52 @@ THE SOFTWARE. """ -import abc -import sys +from dataclasses import dataclass from typing import ( - TYPE_CHECKING, Any, Callable, Dict, FrozenSet, - Optional, - Tuple, - Type, - Union, - Sequence, - List, Mapping, + Sequence, Set, + Tuple, ) import numpy as np -import pytato as pt -import loopy as lp +import pymbolic.primitives as prim from immutabledict import immutabledict - +import loopy as lp +import pytato as pt +from pytato.array import ( + Array, + Axis, + AxisPermutation, + Concatenate, + IndexBase, + IndexLambda, + Placeholder, + Reshape, + Roll, + ShapeType, + Stack, +) from pytato.scalar_expr import IdentityMapper -import pymbolic.primitives as prim - - -from pytools import memoize_method -from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag - -from dataclasses import dataclass - -from arraycontext.container.traversal import (rec_map_array_container, - with_array_context, rec_keyed_map_array_container) - -from arraycontext.container import ArrayContainer, is_array_container_type - -from arraycontext.context import ArrayT, ArrayContext -from arraycontext import PytatoPyOpenCLArrayContext -from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller, - _get_arg_id_to_arg_and_arg_id_to_descr, - _to_input_for_compiled, - _ary_container_key_stringifier, - LeafArrayDescriptor,) - - from pytato.transform import CopyMapper +from pytools.tag import Tag, UniqueTag -from pytato.array import ( - Array, IndexLambda, Placeholder, Stack, Roll, AxisPermutation, - DataWrapper, SizeParam, DictOfNamedArrays, AbstractResultWithNamedArrays, - Reshape, Concatenate, NamedArray, IndexRemappingBase, Einsum, - InputArgumentBase, AdvancedIndexInNoncontiguousAxes, IndexBase, DataInterface, - Axis, ShapeType) +from arraycontext import PytatoPyOpenCLArrayContext +from arraycontext.container import is_array_container_type +from arraycontext.container.traversal import rec_keyed_map_array_container +from arraycontext.impl.pytato.compile import ( + LazilyPyOpenCLCompilingFunctionCaller, + LeafArrayDescriptor, + _ary_container_key_stringifier, + _get_arg_id_to_arg_and_arg_id_to_descr, + _to_input_for_compiled, +) -from pytato.utils import broadcast_binary_op @dataclass(frozen=True) class ParameterStudyAxisTag(UniqueTag): @@ -100,31 +91,31 @@ class ParameterStudyAxisTag(UniqueTag): Currently does not allow multiple variables of different names to be in the same parameter study. """ - #user_param_study_tag: Tag + # user_param_study_tag: Tag axis_num: int axis_size: int + class ExpansionMapper(CopyMapper): - #def __init__(self, dependency_map: Dict[Array,Tag]): + # def __init__(self, dependency_map: Dict[Array,Tag]): # super().__init__() # self.depends = dependency_map - def __init__(self, actual_input_shapes: Mapping[str,ShapeType], + def __init__(self, actual_input_shapes: Mapping[str, ShapeType], actual_input_axes: Mapping[str, FrozenSet[Axis]]): super().__init__() self.actual_input_shapes = actual_input_shapes self.actual_input_axes = actual_input_axes - def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: Array, new_expr: Array) -> Tuple[ShapeType, - Tuple[Axis,...]]: + Tuple[Axis, ...]]: # Initialize with something for the typing. - shape_to_append: ShapeType= (-1,) - new_axes: Tuple[Axis,...] = (Axis(tags=frozenset()),) + shape_to_append: ShapeType = (-1,) + new_axes: Tuple[Axis, ...] = (Axis(tags=frozenset()),) if curr_expr.shape == new_expr.shape: return shape_to_append, new_axes - + # Now we may need to change. changed = False for i in range(len(new_expr.axes)): @@ -142,21 +133,13 @@ def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: A # Remove initialized extraneous data return shape_to_append[1:], new_axes[1:] - - def map_stack(self, expr: Stack) -> Array: - # TODO: Fix - return super().map_stack(expr) - - def map_concatenate(self, expr: Concatenate) -> Array: - return super().map_concatenate(expr) - def map_roll(self, expr: Roll) -> Array: new_array = self.rec(expr.array) - _, new_axes =self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + _, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, new_array) return Roll(array=new_array, shift=expr.shift, - axis=expr.axis, + axis=expr.axis, axes=expr.axes + new_axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -167,8 +150,7 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array: new_array) # Include the axes we are adding to the system. axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation) - for i in range(len(postpend_shape))]) - + for i in range(len(postpend_shape))]) return AxisPermutation(array=new_array, axis_permutation=axis_permute, @@ -185,14 +167,14 @@ def _map_index_base(self, expr: IndexBase) -> Array: # May need to modify indices axes=expr.axes + new_axes, tags=expr.tags, - non_equality_tags = expr.non_equality_tags) + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: new_array = self.rec(expr.array) postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, - new_array) + new_array) return Reshape(new_array, - newshape = self.rec_idx_or_size_tuple(expr.newshape + postpend_shape), + newshape=self.rec_idx_or_size_tuple(expr.newshape + postpend_shape), order=expr.order, axes=expr.axes + new_axes, tags=expr.tags, @@ -214,16 +196,101 @@ def map_placeholder(self, expr: Placeholder) -> Array: tags=expr.tags, non_equality_tags=expr.non_equality_tags) + # {{{ Operations with multiple predecessors. + + def map_stack(self, expr: Stack) -> Array: + # TODO: Fix + single_instance_input_shape = expr.arrays[0].shape + new_arrays = tuple(self.rec(arr) for arr in expr.arrays) + + new_axes_for_end: Tuple[Axis,...] = () + active_studies: Set[ParameterStudyAxisTag] = set() + studies_by_array: Dict[Array, Tuple[ParameterStudyAxisTag,...]] = {} + + + for ind, array in enumerate(new_arrays): + for axis in array.axes: + axis_tags = axis.tags_of_type(ParameterStudyAxisTag) + if axis_tags: + axis_tags = list(axis_tags) + assert len(axis_tags) == 1 + if array in studies_by_array.keys(): + studies_by_array[array] = studies_by_array[array] + (axis_tags[0],) + else: + studies_by_array[array] = (axis_tags[0],) + + + if axis_tags[0] not in active_studies: + active_studies.add(axis_tags[0]) + new_axes_for_end = new_axes_for_end + (axis,) + breakpoint() + + study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} + + new_shape_of_predecessors = single_instance_input_shape + new_axes = expr.axes + + for study in active_studies: + if isinstance(study, ParameterStudyAxisTag): + # Just defensive programming + # The active studies are added to the end of the bindings. + study_to_axis_number[study] = len(new_shape_of_predecessors) + new_shape_of_predecessors = new_shape_of_predecessors + (study.axis_size,) + new_axes = new_axes + (Axis(tags=frozenset((study,))),) + # This assumes that the axis only has 1 tag, + # because there should be no dependence across instances. + + # This is going to be expensive. + + # Now we need to update the expressions. + # Now that we have the appropriate shape, we need to update the input arrays to match. + cp_map = CopyMapper() + corrected_new_arrays: Tuple[Array, ...] = () + for ind, array in enumerate(new_arrays): + tmp = cp_map(array) # Get a copy of the array. + if len(array.axes) < len(new_axes): + # We need to grow the array to the new size. + for study in active_studies: + if study not in studies_by_array[array]: + build:List[Array] = [cp_map(tmp) for _ in range(study.axis_size)] + tmp = Stack(arrays=tuple(build), axis=len(tmp.axes), + axes=tmp.axes + (Axis(tags=frozenset((study,))),), + tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) + elif len(array.axes) > len(new_axes): + raise ValueError(f"Input array is too big. Expected at most: {len(new_axes)} Found: {len(array.axes)} axes.") + + # Now we need to correct to the appropriate shape with an axis permutation. + # These are known to be in the right place. + permute: Tuple[Axis,...] = tuple([i for i in range(len(single_instance_input_shape))]) + + for iaxis, axis in enumerate(tmp.axes): + axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) + if axis_tags: + assert len(axis_tags) == 1 + permute = permute + (study_to_axis_number[axis_tags[0]],) + assert len(permute) == len(new_shape_of_predecessors) + corrected_new_arrays = corrected_new_arrays + (AxisPermutation(tmp, permute, tags=tmp.tags, + axes=tmp.axes, non_equality_tags=tmp.non_equality_tags),) + + + out = Stack(arrays=corrected_new_arrays, axis=expr.axis, axes=expr.axes + new_axes_for_end, + tags=expr.tags, non_equality_tags=expr.non_equality_tags) + breakpoint() + + return out + + def map_concatenate(self, expr: Concatenate) -> Array: + return super().map_concatenate(expr) + def map_index_lambda(self, expr: IndexLambda) -> Array: # Update bindings first. - new_bindings: Dict[str, Array] = { name: self.rec(bnd) + new_bindings: Dict[str, Array] = {name: self.rec(bnd) for name, bnd in sorted(expr.bindings.items())} # Determine the new parameter studies that are being conducted. from pytools import unique - from pytools.obj_array import flat_obj_array - all_axis_tags: Tuple[ParameterStudyAxisTag,...] = () + all_axis_tags: Tuple[ParameterStudyAxisTag, ...] = () studies_by_variable: Dict[str, Dict[UniqueTag, bool]] = {} for name, bnd in sorted(new_bindings.items()): axis_tags_for_bnd: Set[Tag] = set() @@ -233,15 +300,12 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: for tag in axis_tags_for_bnd: if isinstance(tag, ParameterStudyAxisTag): # Defense - studies_by_variable[name][tag] = True + studies_by_variable[name][tag] = True all_axis_tags = all_axis_tags + (tag,) - active_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) - axes: Tuple[Axis, ...] = () study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} - count = 0 new_shape = expr.shape new_axes = expr.axes @@ -267,6 +331,8 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: tags=expr.tags, non_equality_tags=expr.non_equality_tags) + # }}} Operations with multiple predecessors. + class ParamAxisExpander(IdentityMapper): @@ -281,7 +347,7 @@ def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str, # These are the single instance information. index = self.rec(expr.index, studies_by_variable, study_to_axis_number) - + new_vars: Tuple[prim.Variable, ...] = () for key, val in sorted(study_to_axis_number.items(), key=lambda item: item[1]): @@ -314,7 +380,6 @@ class ParamStudyPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: return ParamStudyLazyPyOpenCLFunctionCaller(self, f) - def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit: # Update in a subclass if you want. return t_unit @@ -322,7 +387,6 @@ def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationU # }}} - class ParamStudyLazyPyOpenCLFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller): """ Record a side-effect-free callable :attr:`f` which is initially designed for @@ -387,15 +451,15 @@ def _as_dict_of_named_arrays(keys, ary): rec_keyed_map_array_container(_as_dict_of_named_arrays, output_template) - #input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys() if hasattr(arg_id_to_descr[i], "shape")} - #input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} + # input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys() if hasattr(arg_id_to_descr[i], "shape")} + # input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} input_shapes = {} input_axes = {} - for key,val in arg_id_to_descr.items(): + for key, val in arg_id_to_descr.items(): if isinstance(val, LeafArrayDescriptor): input_shapes[input_id_to_name_in_program[key]] = val.shape input_axes[input_id_to_name_in_program[key]] = arg_id_to_arg[key].axes - myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies + myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies breakpoint() pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) @@ -403,9 +467,9 @@ def _as_dict_of_named_arrays(keys, ary): breakpoint() # Use the normal compiler now. - - compiled_func = self._dag_to_compiled_func(myMapper(pt_dict_of_named_arrays), # Update the arrays - #pt_dict_of_named_arrays, + + compiled_func = self._dag_to_compiled_func(myMapper(pt_dict_of_named_arrays), # Update the arrays + # pt_dict_of_named_arrays, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, output_template=output_template) @@ -431,11 +495,12 @@ def _cut_if_in_param_study(name, arg) -> Array: if not axis_tags: update_axes = update_axes + (arg.axes[i],) newshape.append(arg.shape[i]) - update_axes = update_axes[1:] # remove the first one that was placed there for typing. + update_axes = update_axes[1:] # remove the first one that was placed there for typing. update_tags: FrozenSet[Tag] = arg.tags return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes, tags=update_tags) + def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx): """ Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. diff --git a/examples/parameter_study.py b/examples/parameter_study.py index 78adc2c5..683a7a83 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -1,15 +1,18 @@ from dataclasses import dataclass import numpy as np # for the data types -from pytools.tag import Tag - -from arraycontext.parameter_study import (pack_for_parameter_study, - unpack_parameter_study) -from arraycontext.parameter_study.transform import (ParamStudyPytatoPyOpenCLArrayContext, - ParameterStudyAxisTag) import pyopencl as cl -import pytato as pt + +from arraycontext.parameter_study import ( + pack_for_parameter_study, + unpack_parameter_study, +) +from arraycontext.parameter_study.transform import ( + ParameterStudyAxisTag, + ParamStudyPytatoPyOpenCLArrayContext, +) + ctx = cl.create_some_context(interactive=False) queue = cl.CommandQueue(ctx) @@ -28,20 +31,28 @@ y2 = actx.from_numpy(rng.random(base_shape)) y3 = actx.from_numpy(rng.random(base_shape)) + # Eq: z = x + y # Assumptions: x and y are undergoing independent parameter studies. def rhs(param1, param2): + import pytato as pt + return pt.stack([param1, param2],axis=0) + return param1.stack(param1) return param1 + param2 + @dataclass(frozen=True) class ParameterStudyForX(ParameterStudyAxisTag): pass + @dataclass(frozen=True) class ParameterStudyForY(ParameterStudyAxisTag): pass -# Pack a parameter study of 3 instances for both x and y. +# Pack a parameter study of 3 instances for x and and 4 instances for y. + + packx = pack_for_parameter_study(actx, ParameterStudyForX, (3,), x, x1, x2) packy = pack_for_parameter_study(actx, ParameterStudyForY, (4,), y, y1, y2, y3) @@ -50,8 +61,8 @@ class ParameterStudyForY(ParameterStudyAxisTag): # Builds a trace for a single instance of evaluating the RHS and # then converts it to a program which takes our multiple instances of `x` and `y`. output = compiled_rhs(packx, packy) +output_2 = compiled_rhs(x, y) breakpoint() -output_2 = compiled_rhs(x,y) assert output.shape == (15, 5, 3, 4) # Distinct parameter studies. From 6a1e536e06299810ccddabd7649d792be48c193d Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 19 Jul 2024 11:15:01 -0500 Subject: [PATCH 13/25] Update on concatenate and einsum operations. --- arraycontext/parameter_study/transform.py | 143 ++++++++++++++++++---- 1 file changed, 120 insertions(+), 23 deletions(-) diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index a95b541a..dc1d43fa 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -54,9 +54,12 @@ import pytato as pt from pytato.array import ( Array, + AxesT, Axis, AxisPermutation, Concatenate, + Einsum, + EinsumElementwiseAxis, IndexBase, IndexLambda, Placeholder, @@ -109,10 +112,10 @@ def __init__(self, actual_input_shapes: Mapping[str, ShapeType], def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: Array, new_expr: Array) -> Tuple[ShapeType, - Tuple[Axis, ...]]: + AxesT]: # Initialize with something for the typing. shape_to_append: ShapeType = (-1,) - new_axes: Tuple[Axis, ...] = (Axis(tags=frozenset()),) + new_axes: AxesT = (Axis(tags=frozenset()),) if curr_expr.shape == new_expr.shape: return shape_to_append, new_axes @@ -198,15 +201,14 @@ def map_placeholder(self, expr: Placeholder) -> Array: # {{{ Operations with multiple predecessors. - def map_stack(self, expr: Stack) -> Array: - # TODO: Fix - single_instance_input_shape = expr.arrays[0].shape - new_arrays = tuple(self.rec(arr) for arr in expr.arrays) + def _get_active_studies_from_multiple_predecessors(self, new_arrays: Tuple[Array, ...]) -> Tuple[Tuple[Axis, ...], + Set[ParameterStudyAxisTag], + Dict[Array, + Tuple[ParameterStudyAxisTag, ...]]]: - new_axes_for_end: Tuple[Axis,...] = () + new_axes_for_end: Tuple[Axis, ...] = () active_studies: Set[ParameterStudyAxisTag] = set() - studies_by_array: Dict[Array, Tuple[ParameterStudyAxisTag,...]] = {} - + studies_by_array: Dict[Array, Tuple[ParameterStudyAxisTag, ...]] = {} for ind, array in enumerate(new_arrays): for axis in array.axes: @@ -219,11 +221,18 @@ def map_stack(self, expr: Stack) -> Array: else: studies_by_array[array] = (axis_tags[0],) - if axis_tags[0] not in active_studies: active_studies.add(axis_tags[0]) new_axes_for_end = new_axes_for_end + (axis,) - breakpoint() + + return new_axes_for_end, active_studies, studies_by_array + + def map_stack(self, expr: Stack) -> Array: + # TODO: Fix + single_instance_input_shape = expr.arrays[0].shape + new_arrays = tuple(self.rec(arr) for arr in expr.arrays) + + new_axes_for_end, active_studies, studies_by_array = self._get_active_studies_from_multiple_predecessors(new_arrays) study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} @@ -247,13 +256,13 @@ def map_stack(self, expr: Stack) -> Array: cp_map = CopyMapper() corrected_new_arrays: Tuple[Array, ...] = () for ind, array in enumerate(new_arrays): - tmp = cp_map(array) # Get a copy of the array. + tmp = cp_map(array) # Get a copy of the array. if len(array.axes) < len(new_axes): # We need to grow the array to the new size. for study in active_studies: if study not in studies_by_array[array]: - build:List[Array] = [cp_map(tmp) for _ in range(study.axis_size)] - tmp = Stack(arrays=tuple(build), axis=len(tmp.axes), + build: Tuple[Array, ...] = tuple([cp_map(tmp) for _ in range(study.axis_size)]) + tmp = Stack(arrays=build, axis=len(tmp.axes), axes=tmp.axes + (Axis(tags=frozenset((study,))),), tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) elif len(array.axes) > len(new_axes): @@ -261,8 +270,8 @@ def map_stack(self, expr: Stack) -> Array: # Now we need to correct to the appropriate shape with an axis permutation. # These are known to be in the right place. - permute: Tuple[Axis,...] = tuple([i for i in range(len(single_instance_input_shape))]) - + permute: Tuple[int, ...] = tuple([i for i in range(len(single_instance_input_shape))]) + for iaxis, axis in enumerate(tmp.axes): axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) if axis_tags: @@ -272,15 +281,67 @@ def map_stack(self, expr: Stack) -> Array: corrected_new_arrays = corrected_new_arrays + (AxisPermutation(tmp, permute, tags=tmp.tags, axes=tmp.axes, non_equality_tags=tmp.non_equality_tags),) - - out = Stack(arrays=corrected_new_arrays, axis=expr.axis, axes=expr.axes + new_axes_for_end, + return Stack(arrays=corrected_new_arrays, axis=expr.axis, axes=expr.axes + new_axes_for_end, tags=expr.tags, non_equality_tags=expr.non_equality_tags) - breakpoint() - - return out def map_concatenate(self, expr: Concatenate) -> Array: - return super().map_concatenate(expr) + single_instance_input_shape = expr.arrays[0].shape + # Note that one of the axes within the first single_instance_input_shape + # will not match in size across all inputs. + + new_arrays = tuple(self.rec(arr) for arr in expr.arrays) + + new_axes_for_end, active_studies, studies_by_array = self._get_active_studies_from_multiple_predecessors(new_arrays) + + study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} + + new_shape_of_predecessors = single_instance_input_shape + new_axes = expr.axes + + for study in active_studies: + if isinstance(study, ParameterStudyAxisTag): + # Just defensive programming + # The active studies are added to the end of the bindings. + study_to_axis_number[study] = len(new_shape_of_predecessors) + new_shape_of_predecessors = new_shape_of_predecessors + (study.axis_size,) + new_axes = new_axes + (Axis(tags=frozenset((study,))),) + # This assumes that the axis only has 1 tag, + # because there should be no dependence across instances. + + # This is going to be expensive. + + # Now we need to update the expressions. + # Now that we have the appropriate shape, we need to update the input arrays to match. + cp_map = CopyMapper() + corrected_new_arrays: Tuple[Array, ...] = () + for ind, array in enumerate(new_arrays): + tmp = cp_map(array) # Get a copy of the array. + if len(array.axes) < len(new_axes): + # We need to grow the array to the new size. + for study in active_studies: + if study not in studies_by_array[array]: + build: Tuple[Array, ...] = tuple([cp_map(tmp) for _ in range(study.axis_size)]) + tmp = Stack(arrays=build, axis=len(tmp.axes), + axes=tmp.axes + (Axis(tags=frozenset((study,))),), + tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) + elif len(array.axes) > len(new_axes): + raise ValueError(f"Input array is too big. Expected at most: {len(new_axes)} Found: {len(array.axes)} axes.") + + # Now we need to correct to the appropriate shape with an axis permutation. + # These are known to be in the right place. + permute: Tuple[int, ...] = tuple([i for i in range(len(single_instance_input_shape))]) + + for iaxis, axis in enumerate(tmp.axes): + axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) + if axis_tags: + assert len(axis_tags) == 1 + permute = permute + (study_to_axis_number[axis_tags[0]],) + assert len(permute) == len(new_shape_of_predecessors) + corrected_new_arrays = corrected_new_arrays + (AxisPermutation(tmp, permute, tags=tmp.tags, + axes=tmp.axes, non_equality_tags=tmp.non_equality_tags),) + + return Concatenate(arrays=corrected_new_arrays, axis=expr.axis, axes=expr.axes + new_axes_for_end, + tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_index_lambda(self, expr: IndexLambda) -> Array: # Update bindings first. @@ -331,6 +392,42 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_einsum(self, expr: Einsum) -> Array: + + new_arrays = tuple([self.rec(arg) for arg in expr.args]) + new_axes_for_end, active_studies, studies_by_array = self._get_active_studies_from_multiple_predecessors(new_arrays) + + # Access Descriptors hold the Einsum notation. + new_access_descriptors = list(expr.access_descriptors) + study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} + + new_shape = expr.shape + + for study in active_studies: + if isinstance(study, ParameterStudyAxisTag): + # Just defensive programming + # The active studies are added to the end. + study_to_axis_number[study] = len(new_shape) + new_shape = new_shape + (study.axis_size,) + + for ind, array in enumerate(new_arrays): + for _, axis in enumerate(array.axes): + axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) + if axis_tags: + assert len(axis_tags) == 1 + new_access_descriptors[ind] = new_access_descriptors[ind] + \ + (EinsumElementwiseAxis(dim=study_to_axis_number[axis_tags[0]]),) + + out = Einsum(tuple(new_access_descriptors), new_arrays, axes=expr.axes + new_axes_for_end, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + index_to_access_descr=expr.index_to_access_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + breakpoint() + return out + + return super().map_einsum(expr) + # }}} Operations with multiple predecessors. @@ -489,7 +586,7 @@ def _cut_if_in_param_study(name, arg) -> Array: """ ndim: int = len(arg.shape) newshape = [] - update_axes: Tuple[Axis, ...] = (Axis(tags=frozenset()),) + update_axes: AxesT = (Axis(tags=frozenset()),) for i in range(ndim): axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) if not axis_tags: From 506dbb9c1cf13b74d5ec63f62b6db742af6ee98f Mon Sep 17 00:00:00 2001 From: Nick Date: Sun, 21 Jul 2024 22:36:59 -0500 Subject: [PATCH 14/25] Update the packing and unpacking tests to match the decision to have the new axes at the end. --- arraycontext/parameter_study/__init__.py | 49 +++++++---------------- arraycontext/parameter_study/transform.py | 4 +- examples/parameter_study.py | 3 -- test/test_pytato_parameter_study.py | 30 ++++---------- 4 files changed, 23 insertions(+), 63 deletions(-) diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index 46db6770..78004d2b 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -49,45 +49,23 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Dict, - FrozenSet, - Optional, - Tuple, - Type, - Union, - Sequence, List, - Iterable, Mapping, + Tuple, + Type, ) import numpy as np -import pytato as pt +import pytato as pt from pytato.array import Array -from pytools import memoize_method -from pytools.tag import UniqueTag - -from dataclasses import dataclass - -from arraycontext.container.traversal import (rec_map_array_container, - with_array_context, rec_keyed_map_array_container) - -from arraycontext.container import ArrayContainer, is_array_container_type - from arraycontext.context import ArrayContext -from arraycontext.metadata import NameHint -from arraycontext import PytatoPyOpenCLArrayContext -from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller, - _get_arg_id_to_arg_and_arg_id_to_descr, - _to_input_for_compiled, - _ary_container_key_stringifier) +from arraycontext.parameter_study.transform import ParameterStudyAxisTag -from arraycontext.parameter_study.transform import ExpansionMapper, ParameterStudyAxisTag -# from arraycontext.parameter_study.transform import ExpansionMapper +ParamStudyTagT = Type[ParameterStudyAxisTag] if TYPE_CHECKING: import pyopencl as cl @@ -103,7 +81,7 @@ def pack_for_parameter_study(actx: ArrayContext, - study_name_tag_type: Type[ParameterStudyAxisTag], + study_name_tag_type: ParamStudyTagT, newshape: Tuple[int, ...], *args: Array) -> Array: """ @@ -120,23 +98,24 @@ def pack_for_parameter_study(actx: ArrayContext, orig_shape = args[0].shape out = actx.np.stack(args, axis=len(args[0].shape)) - outshape = tuple(list(orig_shape) + [newshape] ) + outshape = tuple([*list(orig_shape), newshape]) - #if len(newshape) > 1: + # if len(newshape) > 1: # # Reshape the object # out = out.reshape(outshape) - + for i in range(len(orig_shape), len(outshape)): - out = out.with_tagged_axis(i, [study_name_tag_type(i - len(orig_shape), newshape[i-len(orig_shape)])]) + out = out.with_tagged_axis(i, [study_name_tag_type(i - len(orig_shape), + newshape[i-len(orig_shape)])]) return out def unpack_parameter_study(data: Array, - study_name_tag_type: Type[ParameterStudyAxisTag]) -> Mapping[int, + study_name_tag_type: ParamStudyTagT) -> Mapping[int, List[Array]]: """ - Split the data array along the axes which vary according to a ParameterStudyAxisTag - whose name tag is an instance study_name_tag_type. + Split the data array along the axes which vary according to + a ParameterStudyAxisTag whose name tag is an instance study_name_tag_type. output[i] corresponds to the values associated with the ith parameter study that uses the variable name :arg: `study_name_tag_type`. diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index dc1d43fa..003abbf3 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -556,7 +556,7 @@ def _as_dict_of_named_arrays(keys, ary): if isinstance(val, LeafArrayDescriptor): input_shapes[input_id_to_name_in_program[key]] = val.shape input_axes[input_id_to_name_in_program[key]] = arg_id_to_arg[key].axes - myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies + my_expansion_map = ExpansionMapper(input_shapes, input_axes) # Get the dependencies breakpoint() pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) @@ -565,7 +565,7 @@ def _as_dict_of_named_arrays(keys, ary): # Use the normal compiler now. - compiled_func = self._dag_to_compiled_func(myMapper(pt_dict_of_named_arrays), # Update the arrays + compiled_func = self._dag_to_compiled_func(my_expansion_map(pt_dict_of_named_arrays), # pt_dict_of_named_arrays, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, diff --git a/examples/parameter_study.py b/examples/parameter_study.py index 683a7a83..83acecad 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -35,9 +35,6 @@ # Eq: z = x + y # Assumptions: x and y are undergoing independent parameter studies. def rhs(param1, param2): - import pytato as pt - return pt.stack([param1, param2],axis=0) - return param1.stack(param1) return param1 + param2 diff --git a/test/test_pytato_parameter_study.py b/test/test_pytato_parameter_study.py index c4f9a4ed..5f112a78 100644 --- a/test/test_pytato_parameter_study.py +++ b/test/test_pytato_parameter_study.py @@ -73,23 +73,6 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( # {{{ dummy tag types -class FooTag(Tag): - """ - Foo - """ - - -class BarTag(Tag): - """ - Bar - """ - - -class BazTag(Tag): - """ - Baz - """ - class ParamStudy1(ParameterStudyAxisTag): """ 1st parameter study. @@ -132,16 +115,17 @@ def test_pack_for_parameter_study(actx_factory): def rhs(a,b): return a + b + # Adding to the end. pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) - assert pack_x.shape == (4,15,5) + assert pack_x.shape == (15,5,4) pack_y = pack_for_parameter_study(actx, ParamStudy2, (5,), y0,y1, y2,y3,y4) - assert pack_y.shape == (5,15,5) + assert pack_y.shape == (15,5,5) for i in range(3): axis_tags = pack_x.axes[i].tags_of_type(ParamStudy1) second_tags = pack_x.axes[i].tags_of_type(ParamStudy2) - if i == 0: + if i == 2: assert axis_tags else: assert not axis_tags @@ -176,16 +160,16 @@ def rhs(a,b): return a + b pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) - assert pack_x.shape == (4,15,5) + assert pack_x.shape == (15,5,4) pack_y = pack_for_parameter_study(actx, ParamStudy2, (5,), y0,y1, y2,y3,y4) - assert pack_y.shape == (5,15,5) + assert pack_y.shape == (15,5,5) compiled_rhs = actx.compile(rhs) output = compiled_rhs(pack_x, pack_y) - assert output.shape(4,5,15,5) + assert output.shape(15,5,4,5) output_x = unpack_parameter_study(output, ParamStudy1) assert len(output_x) == 1 # Only 1 study associated with this variable. From 81b39bc1f058adae0ee72f0e2ec1e5b373fb1222 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 22 Jul 2024 08:24:56 -0500 Subject: [PATCH 15/25] Add an advection example. --- arraycontext/parameter_study/transform.py | 236 ++++++++++------------ examples/advection.py | 81 ++++++++ 2 files changed, 184 insertions(+), 133 deletions(-) create mode 100644 examples/advection.py diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index 003abbf3..19921f56 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -44,6 +44,7 @@ Sequence, Set, Tuple, + Union, ) import numpy as np @@ -101,16 +102,14 @@ class ParameterStudyAxisTag(UniqueTag): class ExpansionMapper(CopyMapper): - # def __init__(self, dependency_map: Dict[Array,Tag]): - # super().__init__() - # self.depends = dependency_map def __init__(self, actual_input_shapes: Mapping[str, ShapeType], actual_input_axes: Mapping[str, FrozenSet[Axis]]): super().__init__() self.actual_input_shapes = actual_input_shapes self.actual_input_axes = actual_input_axes - def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: Array, + def single_predecessor_updates(self, + curr_expr: Array, new_expr: Array) -> Tuple[ShapeType, AxesT]: # Initialize with something for the typing. @@ -120,17 +119,17 @@ def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: A return shape_to_append, new_axes # Now we may need to change. - changed = False for i in range(len(new_expr.axes)): axis_tags = list(new_expr.axes[i].tags) already_added = False - for j, tag in enumerate(axis_tags): + for _j, tag in enumerate(axis_tags): # Should be relatively few tags on each axis $O(1)$. if isinstance(tag, ParameterStudyAxisTag): new_axes = new_axes + (new_expr.axes[i],) shape_to_append = shape_to_append + (new_expr.shape[i],) if already_added: - raise ValueError("An individual axis may only be tagged with one ParameterStudyAxisTag.") + raise ValueError("An individual axis may only be " + + "tagged with one ParameterStudyAxisTag.") already_added = True # Remove initialized extraneous data @@ -138,7 +137,7 @@ def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: A def map_roll(self, expr: Roll) -> Array: new_array = self.rec(expr.array) - _, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + _, new_axes = self.single_predecessor_updates(expr.array, new_array) return Roll(array=new_array, shift=expr.shift, @@ -149,7 +148,7 @@ def map_roll(self, expr: Roll) -> Array: def map_axis_permutation(self, expr: AxisPermutation) -> Array: new_array = self.rec(expr.array) - postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, + postpend_shape, new_axes = self.single_predecessor_updates(expr.array, new_array) # Include the axes we are adding to the system. axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation) @@ -163,8 +162,8 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array: def _map_index_base(self, expr: IndexBase) -> Array: new_array = self.rec(expr.array) - postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, - new_array) + _, new_axes = self.single_predecessor_updates(expr.array, + new_array) return type(expr)(new_array, indices=self.rec_idx_or_size_tuple(expr.indices), # May need to modify indices @@ -174,10 +173,11 @@ def _map_index_base(self, expr: IndexBase) -> Array: def map_reshape(self, expr: Reshape) -> Array: new_array = self.rec(expr.array) - postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array, - new_array) + postpend_shape, new_axes = self.single_predecessor_updates(expr.array, + new_array) return Reshape(new_array, - newshape=self.rec_idx_or_size_tuple(expr.newshape + postpend_shape), + newshape=self.rec_idx_or_size_tuple(expr.newshape + + postpend_shape), order=expr.order, axes=expr.axes + new_axes, tags=expr.tags, @@ -201,109 +201,71 @@ def map_placeholder(self, expr: Placeholder) -> Array: # {{{ Operations with multiple predecessors. - def _get_active_studies_from_multiple_predecessors(self, new_arrays: Tuple[Array, ...]) -> Tuple[Tuple[Axis, ...], - Set[ParameterStudyAxisTag], - Dict[Array, - Tuple[ParameterStudyAxisTag, ...]]]: + def _studies_from_multiple_pred(self, + new_arrays: Tuple[Array, ...]) -> Tuple[AxesT, + Set[ParameterStudyAxisTag], + Dict[Array, Tuple[ParameterStudyAxisTag, ...]]]: new_axes_for_end: Tuple[Axis, ...] = () - active_studies: Set[ParameterStudyAxisTag] = set() + cur_studies: Set[ParameterStudyAxisTag] = set() studies_by_array: Dict[Array, Tuple[ParameterStudyAxisTag, ...]] = {} - for ind, array in enumerate(new_arrays): + for _ind, array in enumerate(new_arrays): for axis in array.axes: axis_tags = axis.tags_of_type(ParameterStudyAxisTag) if axis_tags: axis_tags = list(axis_tags) assert len(axis_tags) == 1 if array in studies_by_array.keys(): - studies_by_array[array] = studies_by_array[array] + (axis_tags[0],) + studies_by_array[array] = studies_by_array[array] + \ + (axis_tags[0],) else: studies_by_array[array] = (axis_tags[0],) - if axis_tags[0] not in active_studies: - active_studies.add(axis_tags[0]) + if axis_tags[0] not in cur_studies: + cur_studies.add(axis_tags[0]) new_axes_for_end = new_axes_for_end + (axis,) - return new_axes_for_end, active_studies, studies_by_array + return new_axes_for_end, cur_studies, studies_by_array def map_stack(self, expr: Stack) -> Array: - # TODO: Fix - single_instance_input_shape = expr.arrays[0].shape - new_arrays = tuple(self.rec(arr) for arr in expr.arrays) - - new_axes_for_end, active_studies, studies_by_array = self._get_active_studies_from_multiple_predecessors(new_arrays) - - study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} - - new_shape_of_predecessors = single_instance_input_shape - new_axes = expr.axes - - for study in active_studies: - if isinstance(study, ParameterStudyAxisTag): - # Just defensive programming - # The active studies are added to the end of the bindings. - study_to_axis_number[study] = len(new_shape_of_predecessors) - new_shape_of_predecessors = new_shape_of_predecessors + (study.axis_size,) - new_axes = new_axes + (Axis(tags=frozenset((study,))),) - # This assumes that the axis only has 1 tag, - # because there should be no dependence across instances. - - # This is going to be expensive. + new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) - # Now we need to update the expressions. - # Now that we have the appropriate shape, we need to update the input arrays to match. - cp_map = CopyMapper() - corrected_new_arrays: Tuple[Array, ...] = () - for ind, array in enumerate(new_arrays): - tmp = cp_map(array) # Get a copy of the array. - if len(array.axes) < len(new_axes): - # We need to grow the array to the new size. - for study in active_studies: - if study not in studies_by_array[array]: - build: Tuple[Array, ...] = tuple([cp_map(tmp) for _ in range(study.axis_size)]) - tmp = Stack(arrays=build, axis=len(tmp.axes), - axes=tmp.axes + (Axis(tags=frozenset((study,))),), - tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) - elif len(array.axes) > len(new_axes): - raise ValueError(f"Input array is too big. Expected at most: {len(new_axes)} Found: {len(array.axes)} axes.") - - # Now we need to correct to the appropriate shape with an axis permutation. - # These are known to be in the right place. - permute: Tuple[int, ...] = tuple([i for i in range(len(single_instance_input_shape))]) + return Stack(arrays=new_arrays, + axis=expr.axis, + axes=expr.axes + new_axes_for_end, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) - for iaxis, axis in enumerate(tmp.axes): - axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) - if axis_tags: - assert len(axis_tags) == 1 - permute = permute + (study_to_axis_number[axis_tags[0]],) - assert len(permute) == len(new_shape_of_predecessors) - corrected_new_arrays = corrected_new_arrays + (AxisPermutation(tmp, permute, tags=tmp.tags, - axes=tmp.axes, non_equality_tags=tmp.non_equality_tags),) + def map_concatenate(self, expr: Concatenate) -> Array: + new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) - return Stack(arrays=corrected_new_arrays, axis=expr.axis, axes=expr.axes + new_axes_for_end, - tags=expr.tags, non_equality_tags=expr.non_equality_tags) + return Concatenate(arrays=new_arrays, + axis=expr.axis, + axes=expr.axes + new_axes_for_end, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) - def map_concatenate(self, expr: Concatenate) -> Array: - single_instance_input_shape = expr.arrays[0].shape - # Note that one of the axes within the first single_instance_input_shape - # will not match in size across all inputs. + def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Tuple[Array, ...], + AxesT]: + sing_inst_in_shape = expr.arrays[0].shape new_arrays = tuple(self.rec(arr) for arr in expr.arrays) - new_axes_for_end, active_studies, studies_by_array = self._get_active_studies_from_multiple_predecessors(new_arrays) + _, cur_studies, studies_by_array = self._studies_from_multiple_pred(new_arrays) study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} - new_shape_of_predecessors = single_instance_input_shape + new_shape_of_predecessors = sing_inst_in_shape new_axes = expr.axes - for study in active_studies: + for study in cur_studies: if isinstance(study, ParameterStudyAxisTag): # Just defensive programming # The active studies are added to the end of the bindings. study_to_axis_number[study] = len(new_shape_of_predecessors) - new_shape_of_predecessors = new_shape_of_predecessors + (study.axis_size,) + new_shape_of_predecessors = new_shape_of_predecessors + \ + (study.axis_size,) new_axes = new_axes + (Axis(tags=frozenset((study,))),) # This assumes that the axis only has 1 tag, # because there should be no dependence across instances. @@ -311,42 +273,51 @@ def map_concatenate(self, expr: Concatenate) -> Array: # This is going to be expensive. # Now we need to update the expressions. - # Now that we have the appropriate shape, we need to update the input arrays to match. + # Now that we have the appropriate shape, + # we need to update the input arrays to match. + cp_map = CopyMapper() corrected_new_arrays: Tuple[Array, ...] = () - for ind, array in enumerate(new_arrays): + for _, array in enumerate(new_arrays): tmp = cp_map(array) # Get a copy of the array. if len(array.axes) < len(new_axes): # We need to grow the array to the new size. - for study in active_studies: + for study in cur_studies: if study not in studies_by_array[array]: - build: Tuple[Array, ...] = tuple([cp_map(tmp) for _ in range(study.axis_size)]) + build: Tuple[Array, ...] = tuple([cp_map(tmp) for + _ in range(study.axis_size)]) tmp = Stack(arrays=build, axis=len(tmp.axes), - axes=tmp.axes + (Axis(tags=frozenset((study,))),), - tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) + axes=tmp.axes + + (Axis(tags=frozenset((study,))),), + tags=tmp.tags, + non_equality_tags=tmp.non_equality_tags) elif len(array.axes) > len(new_axes): - raise ValueError(f"Input array is too big. Expected at most: {len(new_axes)} Found: {len(array.axes)} axes.") + raise ValueError("Input array is too big. " + \ + f"Expected at most: {len(new_axes)} " + \ + f"Found: {len(array.axes)} axes.") # Now we need to correct to the appropriate shape with an axis permutation. # These are known to be in the right place. - permute: Tuple[int, ...] = tuple([i for i in range(len(single_instance_input_shape))]) + permute: Tuple[int, ...] = tuple([i for i in range(len(sing_inst_in_shape))]) - for iaxis, axis in enumerate(tmp.axes): + for _, axis in enumerate(tmp.axes): axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) if axis_tags: assert len(axis_tags) == 1 permute = permute + (study_to_axis_number[axis_tags[0]],) assert len(permute) == len(new_shape_of_predecessors) - corrected_new_arrays = corrected_new_arrays + (AxisPermutation(tmp, permute, tags=tmp.tags, - axes=tmp.axes, non_equality_tags=tmp.non_equality_tags),) + corrected_new_arrays = corrected_new_arrays + \ + (AxisPermutation(tmp, permute, tags=tmp.tags, + axes=tmp.axes, + non_equality_tags=tmp.non_equality_tags),) - return Concatenate(arrays=corrected_new_arrays, axis=expr.axis, axes=expr.axes + new_axes_for_end, - tags=expr.tags, non_equality_tags=expr.non_equality_tags) + return corrected_new_arrays, new_axes def map_index_lambda(self, expr: IndexLambda) -> Array: # Update bindings first. new_bindings: Dict[str, Array] = {name: self.rec(bnd) - for name, bnd in sorted(expr.bindings.items())} + for name, bnd in + sorted(expr.bindings.items())} # Determine the new parameter studies that are being conducted. from pytools import unique @@ -364,13 +335,13 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: studies_by_variable[name][tag] = True all_axis_tags = all_axis_tags + (tag,) - active_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) + cur_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} new_shape = expr.shape new_axes = expr.axes - for study in active_studies: + for study in cur_studies: if isinstance(study, ParameterStudyAxisTag): # Just defensive programming # The active studies are added to the end of the bindings. @@ -381,7 +352,8 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: # because there should be no dependence across instances. # Now we need to update the expressions. - scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, study_to_axis_number) + scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, + study_to_axis_number) return IndexLambda(expr=scalar_expr, bindings=immutabledict(new_bindings), @@ -395,7 +367,7 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: def map_einsum(self, expr: Einsum) -> Array: new_arrays = tuple([self.rec(arg) for arg in expr.args]) - new_axes_for_end, active_studies, studies_by_array = self._get_active_studies_from_multiple_predecessors(new_arrays) + new_axes_for_end, cur_studies, _ = self._studies_from_multiple_pred(new_arrays) # Access Descriptors hold the Einsum notation. new_access_descriptors = list(expr.access_descriptors) @@ -403,7 +375,7 @@ def map_einsum(self, expr: Einsum) -> Array: new_shape = expr.shape - for study in active_studies: + for study in cur_studies: if isinstance(study, ParameterStudyAxisTag): # Just defensive programming # The active studies are added to the end. @@ -418,22 +390,21 @@ def map_einsum(self, expr: Einsum) -> Array: new_access_descriptors[ind] = new_access_descriptors[ind] + \ (EinsumElementwiseAxis(dim=study_to_axis_number[axis_tags[0]]),) - out = Einsum(tuple(new_access_descriptors), new_arrays, axes=expr.axes + new_axes_for_end, + return Einsum(tuple(new_access_descriptors), new_arrays, + axes=expr.axes + new_axes_for_end, redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, index_to_access_descr=expr.index_to_access_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) - breakpoint() - return out - - return super().map_einsum(expr) # }}} Operations with multiple predecessors. class ParamAxisExpander(IdentityMapper): - def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str, Mapping[ParameterStudyAxisTag, bool]], + def map_subscript(self, expr: prim.Subscript, + studies_by_variable: Mapping[str, + Mapping[ParameterStudyAxisTag, bool]], study_to_axis_number: Mapping[ParameterStudyAxisTag, int]): # We know that we are not changing the variable that we are indexing into. # This is stored in the aggregate member of the class Subscript. @@ -447,9 +418,10 @@ def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str, new_vars: Tuple[prim.Variable, ...] = () - for key, val in sorted(study_to_axis_number.items(), key=lambda item: item[1]): + for key, num in sorted(study_to_axis_number.items(), + key=lambda item: item[1]): if key in studies_by_variable[name]: - new_vars = new_vars + (prim.Variable(f"_{study_to_axis_number[key]}"),) + new_vars = new_vars + (prim.Variable(f"_{num}"),) if isinstance(index, tuple): index = index + new_vars @@ -477,7 +449,8 @@ class ParamStudyPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: return ParamStudyLazyPyOpenCLFunctionCaller(self, f) - def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit: + def transform_loopy_program(self, + t_unit: lp.TranslationUnit) -> lp.TranslationUnit: # Update in a subclass if you want. return t_unit @@ -497,10 +470,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: Returns the result of :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`'s function application on *args*. - Before applying :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`, it is compiled - to a :mod:`pytato` DAG that would apply - :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f` with *args* in a lazy-sense. - The intermediary pytato DAG for *args* is memoized in *self*. + Before applying :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`, + it is compiled to a :mod:`pytato` DAG that would apply + :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f` + with *args* in a lazy-sense. The intermediary pytato DAG for *args* is + memoized in *self*. """ arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( args, kwargs) @@ -519,10 +493,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}" for arg_id in arg_id_to_arg} - output_template = self.f( - *[_get_f_placeholder_args_for_param_study(arg, iarg, - input_id_to_name_in_program, self.actx) - for iarg, arg in enumerate(args)], + placeholder_args = [_get_f_placeholder_args_for_param_study(arg, iarg, + input_id_to_name_in_program, self.actx) + for iarg, arg in enumerate(args)] + output_template = self.f(*placeholder_args, **{kw: _get_f_placeholder_args_for_param_study(arg, kw, input_id_to_name_in_program, self.actx) @@ -548,30 +522,25 @@ def _as_dict_of_named_arrays(keys, ary): rec_keyed_map_array_container(_as_dict_of_named_arrays, output_template) - # input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys() if hasattr(arg_id_to_descr[i], "shape")} - # input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()} input_shapes = {} input_axes = {} for key, val in arg_id_to_descr.items(): if isinstance(val, LeafArrayDescriptor): input_shapes[input_id_to_name_in_program[key]] = val.shape input_axes[input_id_to_name_in_program[key]] = arg_id_to_arg[key].axes - my_expansion_map = ExpansionMapper(input_shapes, input_axes) # Get the dependencies - breakpoint() + expand_map = ExpansionMapper(input_shapes, input_axes) + # Get the dependencies - pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays) - - breakpoint() + sing_inst_outs = pt.make_dict_of_named_arrays(dict_of_named_arrays) # Use the normal compiler now. - compiled_func = self._dag_to_compiled_func(my_expansion_map(pt_dict_of_named_arrays), + compiled_func = self._dag_to_compiled_func(expand_map(sing_inst_outs), # pt_dict_of_named_arrays, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, output_template=output_template) - breakpoint() self.program_cache[arg_id_to_descr] = compiled_func return compiled_func(arg_id_to_arg) @@ -582,7 +551,7 @@ def _cut_if_in_param_study(name, arg) -> Array: if it is tagged with a `ParameterStudyAxisTag` to ensure the survival of the information those tags will be converted to temporary Array Tags of the same type. The placeholder will not - have the axes marked with a `ParameterStudyAxisTag` tag. + have the axes marked with a `ParameterStudyAxisTag` tag. """ ndim: int = len(arg.shape) newshape = [] @@ -592,7 +561,8 @@ def _cut_if_in_param_study(name, arg) -> Array: if not axis_tags: update_axes = update_axes + (arg.axes[i],) newshape.append(arg.shape[i]) - update_axes = update_axes[1:] # remove the first one that was placed there for typing. + # remove the first one that was placed there for typing. + update_axes = update_axes[1:] update_tags: FrozenSet[Tag] = arg.tags return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes, tags=update_tags) @@ -600,10 +570,10 @@ def _cut_if_in_param_study(name, arg) -> Array: def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx): """ - Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. + Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the placeholder version of an argument to :attr:`ParamStudyLazyPyOpenCLFunctionCaller.f`. - + Note this will modify the shape of the placeholder to remove any parameter study axes until the trace can be completed. diff --git a/examples/advection.py b/examples/advection.py new file mode 100644 index 00000000..ab4405d5 --- /dev/null +++ b/examples/advection.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass + +import numpy as np # for the data types + +import pyopencl as cl + +from arraycontext.parameter_study import ( + pack_for_parameter_study, + unpack_parameter_study, +) +from arraycontext.parameter_study.transform import ( + ParameterStudyAxisTag, + ParamStudyPytatoPyOpenCLArrayContext, +) + + +ctx = cl.create_some_context(interactive=False) +queue = cl.CommandQueue(ctx) +actx = ParamStudyPytatoPyOpenCLArrayContext(queue) + + + +@dataclass(frozen=True) +class ParameterStudyForX(ParameterStudyAxisTag): + pass + + +@dataclass(frozen=True) +class ParameterStudyForY(ParameterStudyAxisTag): + pass + +def test_one_time_step_advection(): + + from arraycontext.impl.pytato import _BasePytatoArrayContext + if not isinstance(actx, ParamStudyPytatoPyOpenCLArrayContext): + pytest.skip("only parameter study array contexts are supported") + + import numpy as np + seed = 12345 + rng = np.random.default_rng(seed) + + base_shape = np.prod((15, 5)) + x0 = actx.from_numpy(rng.random(base_shape)) + x1 = actx.from_numpy(rng.random(base_shape)) + x2 = actx.from_numpy(rng.random(base_shape)) + x3 = actx.from_numpy(rng.random(base_shape)) + + + speed_shape = (1,) + y0 = actx.from_numpy(rng.random(speed_shape)) + y1 = actx.from_numpy(rng.random(speed_shape)) + y2 = actx.from_numpy(rng.random(speed_shape)) + y3 = actx.from_numpy(rng.random(speed_shape)) + + + ht = 0.0001 + hx = 0.005 + inds = actx.np.arange(base_shape, dtype=int) + Kp1 = actx.np.roll(inds, -1) + Km1 = actx.np.roll(inds, 1) + + def rhs(fields, wave_speed): + # 2nd order in space finite difference + return fields + wave_speed * (-1) * (ht / (2 * hx)) * \ + (fields[Kp1] - fields[Km1]) + + pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) + assert pack_x.shape == (75,4) + + pack_y = pack_for_parameter_study(actx, ParamStudy1, (4,), y0,y1, y2,y3) + assert pack_y.shape == (1,4) + + compiled_rhs = actx.compile(rhs) + + output = compiled_rhs(pack_x, pack_y) + + assert output.shape(75,4) + + output_x = unpack_parameter_study(output, ParamStudy1) + assert len(output_x) == 1 # Only 1 study associated with this variable. + assert len(output_x[0]) == 4 # 4 inputs for the parameter study. From d0e806be5c0d31ef31f54d8e60dacbdd22531766 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 22 Jul 2024 09:35:47 -0500 Subject: [PATCH 16/25] Fix formatting. --- arraycontext/parameter_study/__init__.py | 3 +- arraycontext/parameter_study/transform.py | 108 +++++++++++----------- 2 files changed, 56 insertions(+), 55 deletions(-) diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index 78004d2b..cffa27ed 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -58,7 +58,6 @@ import numpy as np -import pytato as pt from pytato.array import Array from arraycontext.context import ArrayContext @@ -98,7 +97,7 @@ def pack_for_parameter_study(actx: ArrayContext, orig_shape = args[0].shape out = actx.np.stack(args, axis=len(args[0].shape)) - outshape = tuple([*list(orig_shape), newshape]) + outshape = *orig_shape, newshape # if len(newshape) > 1: # # Reshape the object diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py index 19921f56..b945aae8 100644 --- a/arraycontext/parameter_study/transform.py +++ b/arraycontext/parameter_study/transform.py @@ -85,6 +85,9 @@ ) +ArraysT = Tuple[Array, ...] + + @dataclass(frozen=True) class ParameterStudyAxisTag(UniqueTag): """ @@ -100,6 +103,9 @@ class ParameterStudyAxisTag(UniqueTag): axis_size: int +StudiesT = Tuple[ParameterStudyAxisTag, ...] + + class ExpansionMapper(CopyMapper): def __init__(self, actual_input_shapes: Mapping[str, ShapeType], @@ -108,8 +114,7 @@ def __init__(self, actual_input_shapes: Mapping[str, ShapeType], self.actual_input_shapes = actual_input_shapes self.actual_input_axes = actual_input_axes - def single_predecessor_updates(self, - curr_expr: Array, + def single_predecessor_updates(self, curr_expr: Array, new_expr: Array) -> Tuple[ShapeType, AxesT]: # Initialize with something for the typing. @@ -125,8 +130,8 @@ def single_predecessor_updates(self, for _j, tag in enumerate(axis_tags): # Should be relatively few tags on each axis $O(1)$. if isinstance(tag, ParameterStudyAxisTag): - new_axes = new_axes + (new_expr.axes[i],) - shape_to_append = shape_to_append + (new_expr.shape[i],) + new_axes = *new_axes, new_expr.axes[i], + shape_to_append = *shape_to_append, new_expr.shape[i], if already_added: raise ValueError("An individual axis may only be " + "tagged with one ParameterStudyAxisTag.") @@ -137,8 +142,7 @@ def single_predecessor_updates(self, def map_roll(self, expr: Roll) -> Array: new_array = self.rec(expr.array) - _, new_axes = self.single_predecessor_updates(expr.array, - new_array) + _, new_axes = self.single_predecessor_updates(expr.array, new_array) return Roll(array=new_array, shift=expr.shift, axis=expr.axis, @@ -148,8 +152,7 @@ def map_roll(self, expr: Roll) -> Array: def map_axis_permutation(self, expr: AxisPermutation) -> Array: new_array = self.rec(expr.array) - postpend_shape, new_axes = self.single_predecessor_updates(expr.array, - new_array) + postpend_shape, new_axes = self.single_predecessor_updates(expr.array, new_array) # Include the axes we are adding to the system. axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation) for i in range(len(postpend_shape))]) @@ -161,9 +164,9 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array: non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: + breakpoint() new_array = self.rec(expr.array) - _, new_axes = self.single_predecessor_updates(expr.array, - new_array) + _, new_axes = self.single_predecessor_updates(expr.array, new_array) return type(expr)(new_array, indices=self.rec_idx_or_size_tuple(expr.indices), # May need to modify indices @@ -173,10 +176,9 @@ def _map_index_base(self, expr: IndexBase) -> Array: def map_reshape(self, expr: Reshape) -> Array: new_array = self.rec(expr.array) - postpend_shape, new_axes = self.single_predecessor_updates(expr.array, - new_array) + postpend_shape, new_axes = self.single_predecessor_updates(expr.array, new_array) return Reshape(new_array, - newshape=self.rec_idx_or_size_tuple(expr.newshape + + newshape=self.rec_idx_or_size_tuple(expr.newshape + \ postpend_shape), order=expr.order, axes=expr.axes + new_axes, @@ -202,13 +204,14 @@ def map_placeholder(self, expr: Placeholder) -> Array: # {{{ Operations with multiple predecessors. def _studies_from_multiple_pred(self, - new_arrays: Tuple[Array, ...]) -> Tuple[AxesT, - Set[ParameterStudyAxisTag], - Dict[Array, Tuple[ParameterStudyAxisTag, ...]]]: + new_arrays: ArraysT) -> Tuple[AxesT, + Set[ParameterStudyAxisTag], + Dict[Array, + StudiesT]]: - new_axes_for_end: Tuple[Axis, ...] = () + new_axes_for_end: AxesT = () cur_studies: Set[ParameterStudyAxisTag] = set() - studies_by_array: Dict[Array, Tuple[ParameterStudyAxisTag, ...]] = {} + studies_by_array: Dict[Array, StudiesT] = {} for _ind, array in enumerate(new_arrays): for axis in array.axes: @@ -224,7 +227,7 @@ def _studies_from_multiple_pred(self, if axis_tags[0] not in cur_studies: cur_studies.add(axis_tags[0]) - new_axes_for_end = new_axes_for_end + (axis,) + new_axes_for_end = *new_axes_for_end, axis return new_axes_for_end, cur_studies, studies_by_array @@ -246,17 +249,17 @@ def map_concatenate(self, expr: Concatenate) -> Array: tags=expr.tags, non_equality_tags=expr.non_equality_tags) - def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Tuple[Array, ...], - AxesT]: + def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[ArraysT, + AxesT]: - sing_inst_in_shape = expr.arrays[0].shape + one_inst_in_shape = expr.arrays[0].shape new_arrays = tuple(self.rec(arr) for arr in expr.arrays) _, cur_studies, studies_by_array = self._studies_from_multiple_pred(new_arrays) study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} - new_shape_of_predecessors = sing_inst_in_shape + new_shape_of_predecessors = one_inst_in_shape new_axes = expr.axes for study in cur_studies: @@ -264,9 +267,9 @@ def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Tuple[ # Just defensive programming # The active studies are added to the end of the bindings. study_to_axis_number[study] = len(new_shape_of_predecessors) - new_shape_of_predecessors = new_shape_of_predecessors + \ + new_shape_of_predecessors = *new_shape_of_predecessors, \ (study.axis_size,) - new_axes = new_axes + (Axis(tags=frozenset((study,))),) + new_axes = *new_axes, Axis(tags=frozenset((study,))), # This assumes that the axis only has 1 tag, # because there should be no dependence across instances. @@ -275,41 +278,40 @@ def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Tuple[ # Now we need to update the expressions. # Now that we have the appropriate shape, # we need to update the input arrays to match. - + cp_map = CopyMapper() - corrected_new_arrays: Tuple[Array, ...] = () + corrected_new_arrays: ArraysT = () for _, array in enumerate(new_arrays): tmp = cp_map(array) # Get a copy of the array. if len(array.axes) < len(new_axes): # We need to grow the array to the new size. for study in cur_studies: if study not in studies_by_array[array]: - build: Tuple[Array, ...] = tuple([cp_map(tmp) for + build: ArraysT = tuple([cp_map(tmp) for _ in range(study.axis_size)]) tmp = Stack(arrays=build, axis=len(tmp.axes), - axes=tmp.axes + - (Axis(tags=frozenset((study,))),), + axes=(*tmp.axes, Axis(tags=frozenset((study,)))), tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) elif len(array.axes) > len(new_axes): - raise ValueError("Input array is too big. " + \ - f"Expected at most: {len(new_axes)} " + \ + raise ValueError("Input array is too big. " + + f"Expected at most: {len(new_axes)} " + f"Found: {len(array.axes)} axes.") # Now we need to correct to the appropriate shape with an axis permutation. # These are known to be in the right place. - permute: Tuple[int, ...] = tuple([i for i in range(len(sing_inst_in_shape))]) + permute: Tuple[int, ...] = tuple([i for i in range(len(one_inst_in_shape))]) for _, axis in enumerate(tmp.axes): axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) if axis_tags: assert len(axis_tags) == 1 - permute = permute + (study_to_axis_number[axis_tags[0]],) + permute = *permute, study_to_axis_number[axis_tags[0]], assert len(permute) == len(new_shape_of_predecessors) - corrected_new_arrays = corrected_new_arrays + \ - (AxisPermutation(tmp, permute, tags=tmp.tags, + corrected_new_arrays = *corrected_new_arrays, \ + AxisPermutation(tmp, permute, tags=tmp.tags, axes=tmp.axes, - non_equality_tags=tmp.non_equality_tags),) + non_equality_tags=tmp.non_equality_tags), return corrected_new_arrays, new_axes @@ -322,18 +324,18 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: # Determine the new parameter studies that are being conducted. from pytools import unique - all_axis_tags: Tuple[ParameterStudyAxisTag, ...] = () - studies_by_variable: Dict[str, Dict[UniqueTag, bool]] = {} + all_axis_tags: StudiesT = () + varname_to_studies: Dict[str, Dict[UniqueTag, bool]] = {} for name, bnd in sorted(new_bindings.items()): axis_tags_for_bnd: Set[Tag] = set() - studies_by_variable[name] = {} + varname_to_studies[name] = {} for i in range(len(bnd.axes)): axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag)) for tag in axis_tags_for_bnd: if isinstance(tag, ParameterStudyAxisTag): # Defense - studies_by_variable[name][tag] = True - all_axis_tags = all_axis_tags + (tag,) + varname_to_studies[name][tag] = True + all_axis_tags = *all_axis_tags, tag, cur_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} @@ -346,13 +348,13 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: # Just defensive programming # The active studies are added to the end of the bindings. study_to_axis_number[study] = len(new_shape) - new_shape = new_shape + (study.axis_size,) - new_axes = new_axes + (Axis(tags=frozenset((study,))),) + new_shape = *new_shape, study.axis_size, + new_axes = *new_axes, Axis(tags=frozenset((study,))), # This assumes that the axis only has 1 tag, # because there should be no dependence across instances. # Now we need to update the expressions. - scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, + scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies, study_to_axis_number) return IndexLambda(expr=scalar_expr, @@ -380,7 +382,7 @@ def map_einsum(self, expr: Einsum) -> Array: # Just defensive programming # The active studies are added to the end. study_to_axis_number[study] = len(new_shape) - new_shape = new_shape + (study.axis_size,) + new_shape = *new_shape, study.axis_size, for ind, array in enumerate(new_arrays): for _, axis in enumerate(array.axes): @@ -403,25 +405,25 @@ def map_einsum(self, expr: Einsum) -> Array: class ParamAxisExpander(IdentityMapper): def map_subscript(self, expr: prim.Subscript, - studies_by_variable: Mapping[str, - Mapping[ParameterStudyAxisTag, bool]], + varname_to_studies: Mapping[str, + Mapping[ParameterStudyAxisTag, bool]], study_to_axis_number: Mapping[ParameterStudyAxisTag, int]): # We know that we are not changing the variable that we are indexing into. # This is stored in the aggregate member of the class Subscript. # We only need to modify the indexing which is stored in the index member. name = expr.aggregate.name - if name in studies_by_variable.keys(): + if name in varname_to_studies.keys(): # These are the single instance information. - index = self.rec(expr.index, studies_by_variable, + index = self.rec(expr.index, varname_to_studies, study_to_axis_number) new_vars: Tuple[prim.Variable, ...] = () for key, num in sorted(study_to_axis_number.items(), key=lambda item: item[1]): - if key in studies_by_variable[name]: - new_vars = new_vars + (prim.Variable(f"_{num}"),) + if key in varname_to_studies[name]: + new_vars = *new_vars, prim.Variable(f"_{num}"), if isinstance(index, tuple): index = index + new_vars @@ -559,7 +561,7 @@ def _cut_if_in_param_study(name, arg) -> Array: for i in range(ndim): axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) if not axis_tags: - update_axes = update_axes + (arg.axes[i],) + update_axes = *update_axes, arg.axes[i], newshape.append(arg.shape[i]) # remove the first one that was placed there for typing. update_axes = update_axes[1:] From 28e5860d1e21a1965ca86c5cba8465330e6f4e7d Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 30 Jul 2024 11:05:49 -0500 Subject: [PATCH 17/25] Move the actual transform to pytato. --- arraycontext/__init__.py | 1 - arraycontext/parameter_study/__init__.py | 227 +++++++- arraycontext/parameter_study/transform.py | 602 ---------------------- examples/advection.py | 32 +- examples/parameter_study.py | 11 +- 5 files changed, 230 insertions(+), 643 deletions(-) delete mode 100644 arraycontext/parameter_study/transform.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index f7855170..4ac793c8 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -89,7 +89,6 @@ ) from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag from .parameter_study import pack_for_parameter_study, unpack_parameter_study -from .parameter_study.transform import ParamStudyPytatoPyOpenCLArrayContext __all__ = ( "Array", diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index cffa27ed..6b3ea9d6 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -1,3 +1,6 @@ +from future import __annotations__ + + """ .. currentmodule:: arraycontext @@ -49,6 +52,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, List, Mapping, @@ -58,17 +62,28 @@ import numpy as np -from pytato.array import Array +import loopy as lp +from pytato.array import (Array, make_placeholder as make_placeholder, + make_dict_of_named_arrays) + +from pytato.transform.parameter_study import ParameterStudyAxisTag +from pytools.tag import Tag, UniqueTag as UniqueTag from arraycontext.context import ArrayContext -from arraycontext.parameter_study.transform import ParameterStudyAxisTag +from arraycontext.container import ArrayContainer, is_array_container_type +from arraycontext.container.traversal import rec_keyed_map_array_container +from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext +from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller, + _to_input_for_compiled) +ArraysT = Tuple[Array, ...] +StudiesT = Tuple[ParameterStudyAxisTag, ...] ParamStudyTagT = Type[ParameterStudyAxisTag] if TYPE_CHECKING: import pyopencl as cl - import pytato + import pytato as pytato if getattr(sys, "_BUILDING_SPHINX_DOCS", False): import pyopencl as cl @@ -79,33 +94,208 @@ logger = logging.getLogger(__name__) +# {{{ ParamStudyPytatoPyOpenCLArrayContext + + +class ParamStudyPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): + """ + A derived class for PytatoPyOpenCLArrayContext updated for the + purpose of enabling parameter studies and uncertainty quantification. + + .. automethod:: __init__ + + .. automethod:: transform_dag + + .. automethod:: compile + """ + + def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + return ParamStudyLazyPyOpenCLFunctionCaller(self, f) + + def transform_loopy_program(self, + t_unit: lp.TranslationUnit) -> lp.TranslationUnit: + # Update in a subclass if you want. + return t_unit + +# }}} + + +class ParamStudyLazyPyOpenCLFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller): + """ + Record a side-effect-free callable :attr:`f` which is initially designed for + to be called multiple times with different data. This class will update the + signature to allow :attr:`f` to be called once with the data for multiple + instances. + """ + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """ + Returns the result of :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`'s + function application on *args*. + + Before applying :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`, + it is compiled to a :mod:`pytato` DAG that would apply + :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f` + with *args* in a lazy-sense. The intermediary pytato DAG for *args* is + memoized in *self*. + """ + arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( + args, kwargs) + + try: + compiled_f = self.program_cache[arg_id_to_descr] + except KeyError: + pass + else: + # On a cache hit we do not need to modify anything. + return compiled_f(arg_id_to_arg) + + dict_of_named_arrays = {} + output_id_to_name_in_program = {} + input_id_to_name_in_program = { + arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}" + for arg_id in arg_id_to_arg} + + placeholder_args = [_get_f_placeholder_args_for_param_study(arg, iarg, + input_id_to_name_in_program, self.actx) + for iarg, arg in enumerate(args)] + output_template = self.f(*placeholder_args, + **{kw: _get_f_placeholder_args_for_param_study(arg, kw, + input_id_to_name_in_program, + self.actx) + for kw, arg in kwargs.items()}) + + self.actx._compile_trace_callback(self.f, "post_trace", output_template) + + if (not (is_array_container_type(output_template.__class__) + or isinstance(output_template, pt.Array))): + # TODO: We could possibly just short-circuit this interface if the + # returned type is a scalar. Not sure if it's worth it though. + raise NotImplementedError( + f"Function '{self.f.__name__}' to be compiled " + "did not return an array container or pt.Array," + f" but an instance of '{output_template.__class__}' instead.") + + def _as_dict_of_named_arrays(keys, ary): + name = "_pt_out_" + _ary_container_key_stringifier(keys) + output_id_to_name_in_program[keys] = name + dict_of_named_arrays[name] = ary + return ary + + rec_keyed_map_array_container(_as_dict_of_named_arrays, + output_template) + + input_shapes = {} + input_axes = {} + placeholder_name_to_parameter_studies: Dict[str, StudiesT] = {} + for key, val in arg_id_to_descr.items(): + if isinstance(val, LeafArrayDescriptor): + name = input_id_to_name_in_program[key] + for axis in arg_id_to_arg[key].axes: + tags = axis.tags_of_type(ParameterStudyAxisTag) + if tags: + if name in placeholder_name_to_parameter_studies.keys(): + placeholder_name_to_parameter_studies[name].append(tags) + + else: + placeholder_name_to_parameter_studies[name] = tags + + breakpoint() + expand_map = ExpansionMapper(placeholder_name_to_parameter_studies) + # Get the dependencies + + sing_inst_outs = make_dict_of_named_arrays(dict_of_named_arrays) + + # Use the normal compiler now. + + compiled_func = self._dag_to_compiled_func(expand_map(sing_inst_outs), + # pt_dict_of_named_arrays, + input_id_to_name_in_program=input_id_to_name_in_program, + output_id_to_name_in_program=output_id_to_name_in_program, + output_template=output_template) + + breakpoint() + self.program_cache[arg_id_to_descr] = compiled_func + return compiled_func(arg_id_to_arg) + + +def _cut_to_single_instance_size(name, arg) -> Array: + """ + Helper to split a place holder into the base instance shape + if it is tagged with a `ParameterStudyAxisTag` + to ensure the survival of the information those tags will be converted + to temporary Array Tags of the same type. The placeholder will not + have the axes marked with a `ParameterStudyAxisTag` tag. + + We need to cut the extra axes off because we cannot assume + that the operators we use to build the single instance program + will understand what to do with the extra axes. + """ + ndim: int = len(arg.shape) + newshape: ShapeType = () + update_axes: AxesT = () + for i in range(ndim): + axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) + if not axis_tags: + update_axes = (*update_axes, arg.axes[i],) + newshape = (*newshape, arg.shape[i]) + + update_tags: FrozenSet[Tag] = arg.tags + + return make_placeholder(name, newshape, arg.dtype, axes=update_axes, + tags=update_tags) + + +def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx): + """ + Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. + Returns the placeholder version of an argument to + :attr:`ParamStudyLazyPyOpenCLFunctionCaller.f`. + + Note this will modify the shape of the placeholder to + remove any parameter study axes until the trace + can be completed. + + They will be added back after the trace is complete. + """ + if np.isscalar(arg): + name = arg_id_to_name[(kw,)] + return make_placeholder(name, (), np.dtype(type(arg))) + elif isinstance(arg, Array): + name = arg_id_to_name[(kw,)] + # Transform the DAG to give metadata inference a chance to do its job + arg = _to_input_for_compiled(arg, actx) + return _cut_to_single_instance_size(name, arg) + elif is_array_container_type(arg.__class__): + def _rec_to_placeholder(keys, ary): + name = arg_id_to_name[(kw, *keys)] + # Transform the DAG to give metadata inference a chance to do its job + ary = _to_input_for_compiled(ary, actx) + return _cut_to_single_instance_size(name, ary) + + return rec_keyed_map_array_container(_rec_to_placeholder, arg) + else: + raise NotImplementedError(type(arg)) + + def pack_for_parameter_study(actx: ArrayContext, study_name_tag_type: ParamStudyTagT, - newshape: Tuple[int, ...], *args: Array) -> Array: """ - Args is a list of variable names and the realized input data that needs - to be packed for a parameter study or uncertainty quantification. + Args is a list of realized input data that needs to be packed + for a parameter study or uncertainty quantification. - Args needs to be in the format - [v0, v1, v2, ..., vN] where N is the total number of instances you want to - try. Note these may be across multiple parameter studies on the same inputs. + We assume that each input data set has the same shape and + are safely castable to the same datatype. """ assert len(args) > 0 - assert len(args) == np.prod(newshape) orig_shape = args[0].shape out = actx.np.stack(args, axis=len(args[0].shape)) - outshape = *orig_shape, newshape - - # if len(newshape) > 1: - # # Reshape the object - # out = out.reshape(outshape) - for i in range(len(orig_shape), len(outshape)): - out = out.with_tagged_axis(i, [study_name_tag_type(i - len(orig_shape), - newshape[i-len(orig_shape)])]) + for i in range(len(orig_shape), len(out.shape)): + out = out.with_tagged_axis(i, [study_name_tag_type(len(args))]) return out @@ -140,6 +330,5 @@ def unpack_parameter_study(data: Array, out[study_count] = [data[the_slice]] if study_count in out.keys(): study_count += 1 - # yield data[the_slice] return out diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py deleted file mode 100644 index b945aae8..00000000 --- a/arraycontext/parameter_study/transform.py +++ /dev/null @@ -1,602 +0,0 @@ -from __future__ import annotations - - -""" -.. currentmodule:: arraycontext - -Compiling a Python callable (Internal) for multiple distinct instances of -execution. -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. automodule:: arraycontext.parameter_study -""" -__copyright__ = """ -Copyright (C) 2020-1 University of Illinois Board of Trustees -""" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - -from dataclasses import dataclass -from typing import ( - Any, - Callable, - Dict, - FrozenSet, - Mapping, - Sequence, - Set, - Tuple, - Union, -) - -import numpy as np -import pymbolic.primitives as prim -from immutabledict import immutabledict - -import loopy as lp -import pytato as pt -from pytato.array import ( - Array, - AxesT, - Axis, - AxisPermutation, - Concatenate, - Einsum, - EinsumElementwiseAxis, - IndexBase, - IndexLambda, - Placeholder, - Reshape, - Roll, - ShapeType, - Stack, -) -from pytato.scalar_expr import IdentityMapper -from pytato.transform import CopyMapper -from pytools.tag import Tag, UniqueTag - -from arraycontext import PytatoPyOpenCLArrayContext -from arraycontext.container import is_array_container_type -from arraycontext.container.traversal import rec_keyed_map_array_container -from arraycontext.impl.pytato.compile import ( - LazilyPyOpenCLCompilingFunctionCaller, - LeafArrayDescriptor, - _ary_container_key_stringifier, - _get_arg_id_to_arg_and_arg_id_to_descr, - _to_input_for_compiled, -) - - -ArraysT = Tuple[Array, ...] - - -@dataclass(frozen=True) -class ParameterStudyAxisTag(UniqueTag): - """ - A tag for acting on axes of arrays. - To enable multiple parameter studies on the same variable name - specify a different axis number and potentially a different size. - - Currently does not allow multiple variables of different names to be in - the same parameter study. - """ - # user_param_study_tag: Tag - axis_num: int - axis_size: int - - -StudiesT = Tuple[ParameterStudyAxisTag, ...] - - -class ExpansionMapper(CopyMapper): - - def __init__(self, actual_input_shapes: Mapping[str, ShapeType], - actual_input_axes: Mapping[str, FrozenSet[Axis]]): - super().__init__() - self.actual_input_shapes = actual_input_shapes - self.actual_input_axes = actual_input_axes - - def single_predecessor_updates(self, curr_expr: Array, - new_expr: Array) -> Tuple[ShapeType, - AxesT]: - # Initialize with something for the typing. - shape_to_append: ShapeType = (-1,) - new_axes: AxesT = (Axis(tags=frozenset()),) - if curr_expr.shape == new_expr.shape: - return shape_to_append, new_axes - - # Now we may need to change. - for i in range(len(new_expr.axes)): - axis_tags = list(new_expr.axes[i].tags) - already_added = False - for _j, tag in enumerate(axis_tags): - # Should be relatively few tags on each axis $O(1)$. - if isinstance(tag, ParameterStudyAxisTag): - new_axes = *new_axes, new_expr.axes[i], - shape_to_append = *shape_to_append, new_expr.shape[i], - if already_added: - raise ValueError("An individual axis may only be " + - "tagged with one ParameterStudyAxisTag.") - already_added = True - - # Remove initialized extraneous data - return shape_to_append[1:], new_axes[1:] - - def map_roll(self, expr: Roll) -> Array: - new_array = self.rec(expr.array) - _, new_axes = self.single_predecessor_updates(expr.array, new_array) - return Roll(array=new_array, - shift=expr.shift, - axis=expr.axis, - axes=expr.axes + new_axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_axis_permutation(self, expr: AxisPermutation) -> Array: - new_array = self.rec(expr.array) - postpend_shape, new_axes = self.single_predecessor_updates(expr.array, new_array) - # Include the axes we are adding to the system. - axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation) - for i in range(len(postpend_shape))]) - - return AxisPermutation(array=new_array, - axis_permutation=axis_permute, - axes=expr.axes + new_axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def _map_index_base(self, expr: IndexBase) -> Array: - breakpoint() - new_array = self.rec(expr.array) - _, new_axes = self.single_predecessor_updates(expr.array, new_array) - return type(expr)(new_array, - indices=self.rec_idx_or_size_tuple(expr.indices), - # May need to modify indices - axes=expr.axes + new_axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_reshape(self, expr: Reshape) -> Array: - new_array = self.rec(expr.array) - postpend_shape, new_axes = self.single_predecessor_updates(expr.array, new_array) - return Reshape(new_array, - newshape=self.rec_idx_or_size_tuple(expr.newshape + \ - postpend_shape), - order=expr.order, - axes=expr.axes + new_axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_placeholder(self, expr: Placeholder) -> Array: - # This is where we could introduce extra axes. - correct_shape = expr.shape - correct_axes = expr.axes - if expr.name in self.actual_input_shapes.keys(): - # We may need to update the size. - if expr.shape != self.actual_input_shapes[expr.name]: - correct_shape = self.actual_input_shapes[expr.name] - correct_axes = tuple(self.actual_input_axes[expr.name]) - return Placeholder(name=expr.name, - shape=self.rec_idx_or_size_tuple(correct_shape), - dtype=expr.dtype, - axes=correct_axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - # {{{ Operations with multiple predecessors. - - def _studies_from_multiple_pred(self, - new_arrays: ArraysT) -> Tuple[AxesT, - Set[ParameterStudyAxisTag], - Dict[Array, - StudiesT]]: - - new_axes_for_end: AxesT = () - cur_studies: Set[ParameterStudyAxisTag] = set() - studies_by_array: Dict[Array, StudiesT] = {} - - for _ind, array in enumerate(new_arrays): - for axis in array.axes: - axis_tags = axis.tags_of_type(ParameterStudyAxisTag) - if axis_tags: - axis_tags = list(axis_tags) - assert len(axis_tags) == 1 - if array in studies_by_array.keys(): - studies_by_array[array] = studies_by_array[array] + \ - (axis_tags[0],) - else: - studies_by_array[array] = (axis_tags[0],) - - if axis_tags[0] not in cur_studies: - cur_studies.add(axis_tags[0]) - new_axes_for_end = *new_axes_for_end, axis - - return new_axes_for_end, cur_studies, studies_by_array - - def map_stack(self, expr: Stack) -> Array: - new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) - - return Stack(arrays=new_arrays, - axis=expr.axis, - axes=expr.axes + new_axes_for_end, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_concatenate(self, expr: Concatenate) -> Array: - new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) - - return Concatenate(arrays=new_arrays, - axis=expr.axis, - axes=expr.axes + new_axes_for_end, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[ArraysT, - AxesT]: - - one_inst_in_shape = expr.arrays[0].shape - new_arrays = tuple(self.rec(arr) for arr in expr.arrays) - - _, cur_studies, studies_by_array = self._studies_from_multiple_pred(new_arrays) - - study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} - - new_shape_of_predecessors = one_inst_in_shape - new_axes = expr.axes - - for study in cur_studies: - if isinstance(study, ParameterStudyAxisTag): - # Just defensive programming - # The active studies are added to the end of the bindings. - study_to_axis_number[study] = len(new_shape_of_predecessors) - new_shape_of_predecessors = *new_shape_of_predecessors, \ - (study.axis_size,) - new_axes = *new_axes, Axis(tags=frozenset((study,))), - # This assumes that the axis only has 1 tag, - # because there should be no dependence across instances. - - # This is going to be expensive. - - # Now we need to update the expressions. - # Now that we have the appropriate shape, - # we need to update the input arrays to match. - - cp_map = CopyMapper() - corrected_new_arrays: ArraysT = () - for _, array in enumerate(new_arrays): - tmp = cp_map(array) # Get a copy of the array. - if len(array.axes) < len(new_axes): - # We need to grow the array to the new size. - for study in cur_studies: - if study not in studies_by_array[array]: - build: ArraysT = tuple([cp_map(tmp) for - _ in range(study.axis_size)]) - tmp = Stack(arrays=build, axis=len(tmp.axes), - axes=(*tmp.axes, Axis(tags=frozenset((study,)))), - tags=tmp.tags, - non_equality_tags=tmp.non_equality_tags) - elif len(array.axes) > len(new_axes): - raise ValueError("Input array is too big. " + - f"Expected at most: {len(new_axes)} " + - f"Found: {len(array.axes)} axes.") - - # Now we need to correct to the appropriate shape with an axis permutation. - # These are known to be in the right place. - permute: Tuple[int, ...] = tuple([i for i in range(len(one_inst_in_shape))]) - - for _, axis in enumerate(tmp.axes): - axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) - if axis_tags: - assert len(axis_tags) == 1 - permute = *permute, study_to_axis_number[axis_tags[0]], - assert len(permute) == len(new_shape_of_predecessors) - corrected_new_arrays = *corrected_new_arrays, \ - AxisPermutation(tmp, permute, tags=tmp.tags, - axes=tmp.axes, - non_equality_tags=tmp.non_equality_tags), - - return corrected_new_arrays, new_axes - - def map_index_lambda(self, expr: IndexLambda) -> Array: - # Update bindings first. - new_bindings: Dict[str, Array] = {name: self.rec(bnd) - for name, bnd in - sorted(expr.bindings.items())} - - # Determine the new parameter studies that are being conducted. - from pytools import unique - - all_axis_tags: StudiesT = () - varname_to_studies: Dict[str, Dict[UniqueTag, bool]] = {} - for name, bnd in sorted(new_bindings.items()): - axis_tags_for_bnd: Set[Tag] = set() - varname_to_studies[name] = {} - for i in range(len(bnd.axes)): - axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag)) - for tag in axis_tags_for_bnd: - if isinstance(tag, ParameterStudyAxisTag): - # Defense - varname_to_studies[name][tag] = True - all_axis_tags = *all_axis_tags, tag, - - cur_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) - study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} - - new_shape = expr.shape - new_axes = expr.axes - - for study in cur_studies: - if isinstance(study, ParameterStudyAxisTag): - # Just defensive programming - # The active studies are added to the end of the bindings. - study_to_axis_number[study] = len(new_shape) - new_shape = *new_shape, study.axis_size, - new_axes = *new_axes, Axis(tags=frozenset((study,))), - # This assumes that the axis only has 1 tag, - # because there should be no dependence across instances. - - # Now we need to update the expressions. - scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies, - study_to_axis_number) - - return IndexLambda(expr=scalar_expr, - bindings=immutabledict(new_bindings), - shape=new_shape, - var_to_reduction_descr=expr.var_to_reduction_descr, - dtype=expr.dtype, - axes=new_axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_einsum(self, expr: Einsum) -> Array: - - new_arrays = tuple([self.rec(arg) for arg in expr.args]) - new_axes_for_end, cur_studies, _ = self._studies_from_multiple_pred(new_arrays) - - # Access Descriptors hold the Einsum notation. - new_access_descriptors = list(expr.access_descriptors) - study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} - - new_shape = expr.shape - - for study in cur_studies: - if isinstance(study, ParameterStudyAxisTag): - # Just defensive programming - # The active studies are added to the end. - study_to_axis_number[study] = len(new_shape) - new_shape = *new_shape, study.axis_size, - - for ind, array in enumerate(new_arrays): - for _, axis in enumerate(array.axes): - axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) - if axis_tags: - assert len(axis_tags) == 1 - new_access_descriptors[ind] = new_access_descriptors[ind] + \ - (EinsumElementwiseAxis(dim=study_to_axis_number[axis_tags[0]]),) - - return Einsum(tuple(new_access_descriptors), new_arrays, - axes=expr.axes + new_axes_for_end, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - index_to_access_descr=expr.index_to_access_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - # }}} Operations with multiple predecessors. - - -class ParamAxisExpander(IdentityMapper): - - def map_subscript(self, expr: prim.Subscript, - varname_to_studies: Mapping[str, - Mapping[ParameterStudyAxisTag, bool]], - study_to_axis_number: Mapping[ParameterStudyAxisTag, int]): - # We know that we are not changing the variable that we are indexing into. - # This is stored in the aggregate member of the class Subscript. - - # We only need to modify the indexing which is stored in the index member. - name = expr.aggregate.name - if name in varname_to_studies.keys(): - # These are the single instance information. - index = self.rec(expr.index, varname_to_studies, - study_to_axis_number) - - new_vars: Tuple[prim.Variable, ...] = () - - for key, num in sorted(study_to_axis_number.items(), - key=lambda item: item[1]): - if key in varname_to_studies[name]: - new_vars = *new_vars, prim.Variable(f"_{num}"), - - if isinstance(index, tuple): - index = index + new_vars - else: - index = tuple(index) + new_vars - return type(expr)(aggregate=expr.aggregate, index=index) - return expr - - -# {{{ ParamStudyPytatoPyOpenCLArrayContext - - -class ParamStudyPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): - """ - A derived class for PytatoPyOpenCLArrayContext updated for the - purpose of enabling parameter studies and uncertainty quantification. - - .. automethod:: __init__ - - .. automethod:: transform_dag - - .. automethod:: compile - """ - - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: - return ParamStudyLazyPyOpenCLFunctionCaller(self, f) - - def transform_loopy_program(self, - t_unit: lp.TranslationUnit) -> lp.TranslationUnit: - # Update in a subclass if you want. - return t_unit - -# }}} - - -class ParamStudyLazyPyOpenCLFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller): - """ - Record a side-effect-free callable :attr:`f` which is initially designed for - to be called multiple times with different data. This class will update the - signature to allow :attr:`f` to be called once with the data for multiple - instances. - """ - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - """ - Returns the result of :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`'s - function application on *args*. - - Before applying :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`, - it is compiled to a :mod:`pytato` DAG that would apply - :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f` - with *args* in a lazy-sense. The intermediary pytato DAG for *args* is - memoized in *self*. - """ - arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( - args, kwargs) - - try: - compiled_f = self.program_cache[arg_id_to_descr] - except KeyError: - pass - else: - # On a cache hit we do not need to modify anything. - return compiled_f(arg_id_to_arg) - - dict_of_named_arrays = {} - output_id_to_name_in_program = {} - input_id_to_name_in_program = { - arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}" - for arg_id in arg_id_to_arg} - - placeholder_args = [_get_f_placeholder_args_for_param_study(arg, iarg, - input_id_to_name_in_program, self.actx) - for iarg, arg in enumerate(args)] - output_template = self.f(*placeholder_args, - **{kw: _get_f_placeholder_args_for_param_study(arg, kw, - input_id_to_name_in_program, - self.actx) - for kw, arg in kwargs.items()}) - - self.actx._compile_trace_callback(self.f, "post_trace", output_template) - - if (not (is_array_container_type(output_template.__class__) - or isinstance(output_template, pt.Array))): - # TODO: We could possibly just short-circuit this interface if the - # returned type is a scalar. Not sure if it's worth it though. - raise NotImplementedError( - f"Function '{self.f.__name__}' to be compiled " - "did not return an array container or pt.Array," - f" but an instance of '{output_template.__class__}' instead.") - - def _as_dict_of_named_arrays(keys, ary): - name = "_pt_out_" + _ary_container_key_stringifier(keys) - output_id_to_name_in_program[keys] = name - dict_of_named_arrays[name] = ary - return ary - - rec_keyed_map_array_container(_as_dict_of_named_arrays, - output_template) - - input_shapes = {} - input_axes = {} - for key, val in arg_id_to_descr.items(): - if isinstance(val, LeafArrayDescriptor): - input_shapes[input_id_to_name_in_program[key]] = val.shape - input_axes[input_id_to_name_in_program[key]] = arg_id_to_arg[key].axes - expand_map = ExpansionMapper(input_shapes, input_axes) - # Get the dependencies - - sing_inst_outs = pt.make_dict_of_named_arrays(dict_of_named_arrays) - - # Use the normal compiler now. - - compiled_func = self._dag_to_compiled_func(expand_map(sing_inst_outs), - # pt_dict_of_named_arrays, - input_id_to_name_in_program=input_id_to_name_in_program, - output_id_to_name_in_program=output_id_to_name_in_program, - output_template=output_template) - - self.program_cache[arg_id_to_descr] = compiled_func - return compiled_func(arg_id_to_arg) - - -def _cut_if_in_param_study(name, arg) -> Array: - """ - Helper to split a place holder into the base instance shape - if it is tagged with a `ParameterStudyAxisTag` - to ensure the survival of the information those tags will be converted - to temporary Array Tags of the same type. The placeholder will not - have the axes marked with a `ParameterStudyAxisTag` tag. - """ - ndim: int = len(arg.shape) - newshape = [] - update_axes: AxesT = (Axis(tags=frozenset()),) - for i in range(ndim): - axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) - if not axis_tags: - update_axes = *update_axes, arg.axes[i], - newshape.append(arg.shape[i]) - # remove the first one that was placed there for typing. - update_axes = update_axes[1:] - update_tags: FrozenSet[Tag] = arg.tags - return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes, - tags=update_tags) - - -def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx): - """ - Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. - Returns the placeholder version of an argument to - :attr:`ParamStudyLazyPyOpenCLFunctionCaller.f`. - - Note this will modify the shape of the placeholder to - remove any parameter study axes until the trace - can be completed. - - They will be added back after the trace is complete. - """ - if np.isscalar(arg): - name = arg_id_to_name[(kw,)] - return pt.make_placeholder(name, (), np.dtype(type(arg))) - elif isinstance(arg, pt.Array): - name = arg_id_to_name[(kw,)] - # Transform the DAG to give metadata inference a chance to do its job - arg = _to_input_for_compiled(arg, actx) - return _cut_if_in_param_study(name, arg) - elif is_array_container_type(arg.__class__): - def _rec_to_placeholder(keys, ary): - name = arg_id_to_name[(kw, *keys)] - # Transform the DAG to give metadata inference a chance to do its job - ary = _to_input_for_compiled(ary, actx) - return _cut_if_in_param_study(name, ary) - - return rec_keyed_map_array_container(_rec_to_placeholder, arg) - else: - raise NotImplementedError(type(arg)) diff --git a/examples/advection.py b/examples/advection.py index ab4405d5..ec99f478 100644 --- a/examples/advection.py +++ b/examples/advection.py @@ -18,23 +18,15 @@ queue = cl.CommandQueue(ctx) actx = ParamStudyPytatoPyOpenCLArrayContext(queue) - - @dataclass(frozen=True) -class ParameterStudyForX(ParameterStudyAxisTag): - pass +class ParamStudy1(ParameterStudyAxisTag): + """ + 1st parameter study. + """ -@dataclass(frozen=True) -class ParameterStudyForY(ParameterStudyAxisTag): - pass - def test_one_time_step_advection(): - from arraycontext.impl.pytato import _BasePytatoArrayContext - if not isinstance(actx, ParamStudyPytatoPyOpenCLArrayContext): - pytest.skip("only parameter study array contexts are supported") - import numpy as np seed = 12345 rng = np.random.default_rng(seed) @@ -55,9 +47,9 @@ def test_one_time_step_advection(): ht = 0.0001 hx = 0.005 - inds = actx.np.arange(base_shape, dtype=int) - Kp1 = actx.np.roll(inds, -1) - Km1 = actx.np.roll(inds, 1) + inds = np.arange(base_shape, dtype=int) + Kp1 = actx.from_numpy(np.roll(inds, -1)) + Km1 = actx.from_numpy(np.roll(inds, 1)) def rhs(fields, wave_speed): # 2nd order in space finite difference @@ -65,17 +57,25 @@ def rhs(fields, wave_speed): (fields[Kp1] - fields[Km1]) pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) + breakpoint() assert pack_x.shape == (75,4) pack_y = pack_for_parameter_study(actx, ParamStudy1, (4,), y0,y1, y2,y3) + breakpoint() assert pack_y.shape == (1,4) compiled_rhs = actx.compile(rhs) + breakpoint() output = compiled_rhs(pack_x, pack_y) - + breakpoint() assert output.shape(75,4) output_x = unpack_parameter_study(output, ParamStudy1) assert len(output_x) == 1 # Only 1 study associated with this variable. assert len(output_x[0]) == 4 # 4 inputs for the parameter study. + + print("All checks passed") + +# Call it. +test_one_time_step_advection() diff --git a/examples/parameter_study.py b/examples/parameter_study.py index 83acecad..e4c8d766 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -7,14 +7,12 @@ from arraycontext.parameter_study import ( pack_for_parameter_study, unpack_parameter_study, -) -from arraycontext.parameter_study.transform import ( - ParameterStudyAxisTag, ParamStudyPytatoPyOpenCLArrayContext, + ParameterStudyAxisTag, ) -ctx = cl.create_some_context(interactive=False) +ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) actx = ParamStudyPytatoPyOpenCLArrayContext(queue) @@ -35,7 +33,10 @@ # Eq: z = x + y # Assumptions: x and y are undergoing independent parameter studies. def rhs(param1, param2): - return param1 + param2 + import pytato as pt + return pt.matmul(param1, param2.T) + return pt.stack([param1[0], param2[10]], axis=0) + return param1[0] + param2[10] @dataclass(frozen=True) From 478fa3b79be366f944bedb4a34cb4a9c847b326c Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 30 Jul 2024 16:33:13 -0500 Subject: [PATCH 18/25] Update imports to match the location. --- test/test_pytato_parameter_study.py | 95 +++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 27 deletions(-) diff --git a/test/test_pytato_parameter_study.py b/test/test_pytato_parameter_study.py index 5f112a78..0761523c 100644 --- a/test/test_pytato_parameter_study.py +++ b/test/test_pytato_parameter_study.py @@ -26,21 +26,15 @@ import pytest -from pytools.tag import Tag - from arraycontext import ( - PytatoPyOpenCLArrayContext, pytest_generate_tests_for_array_contexts, ) -from arraycontext.parameter_study.transform import ( - ParamStudyPytatoPyOpenCLArrayContext, - ParameterStudyAxisTag -) from arraycontext.parameter_study import ( - pack_for_parameter_study, - unpack_parameter_study, + pack_for_parameter_study, + unpack_parameter_study, + ParameterStudyAxisTag, + ParamStudyPytatoPyOpenCLArrayContext, ) - from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory @@ -78,6 +72,7 @@ class ParamStudy1(ParameterStudyAxisTag): 1st parameter study. """ + class ParamStudy2(ParameterStudyAxisTag): """ 2bd parameter study. @@ -91,7 +86,6 @@ def test_pack_for_parameter_study(actx_factory): actx = actx_factory() - from arraycontext.impl.pytato import _BasePytatoArrayContext if not isinstance(actx, ParamStudyPytatoPyOpenCLArrayContext): pytest.skip("only parameter study array contexts are supported") @@ -104,7 +98,6 @@ def test_pack_for_parameter_study(actx_factory): x1 = actx.from_numpy(rng.random(base_shape)) x2 = actx.from_numpy(rng.random(base_shape)) x3 = actx.from_numpy(rng.random(base_shape)) - y0 = actx.from_numpy(rng.random(base_shape)) y1 = actx.from_numpy(rng.random(base_shape)) @@ -112,15 +105,15 @@ def test_pack_for_parameter_study(actx_factory): y3 = actx.from_numpy(rng.random(base_shape)) y4 = actx.from_numpy(rng.random(base_shape)) - def rhs(a,b): + def rhs(a, b): return a + b # Adding to the end. pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) - assert pack_x.shape == (15,5,4) + assert pack_x.shape == (15, 5, 4) - pack_y = pack_for_parameter_study(actx, ParamStudy2, (5,), y0,y1, y2,y3,y4) - assert pack_y.shape == (15,5,5) + pack_y = pack_for_parameter_study(actx, ParamStudy2, (5,), y0, y1, y2, y3, y4) + assert pack_y.shape == (15, 5, 5) for i in range(3): axis_tags = pack_x.axes[i].tags_of_type(ParamStudy1) @@ -131,11 +124,11 @@ def rhs(a,b): assert not axis_tags assert not second_tags + def test_unpack_parameter_study(actx_factory): actx = actx_factory() - from arraycontext.impl.pytato import _BasePytatoArrayContext if not isinstance(actx, ParamStudyPytatoPyOpenCLArrayContext): pytest.skip("only parameter study array contexts are supported") @@ -148,7 +141,6 @@ def test_unpack_parameter_study(actx_factory): x1 = actx.from_numpy(rng.random(base_shape)) x2 = actx.from_numpy(rng.random(base_shape)) x3 = actx.from_numpy(rng.random(base_shape)) - y0 = actx.from_numpy(rng.random(base_shape)) y1 = actx.from_numpy(rng.random(base_shape)) @@ -156,37 +148,86 @@ def test_unpack_parameter_study(actx_factory): y3 = actx.from_numpy(rng.random(base_shape)) y4 = actx.from_numpy(rng.random(base_shape)) - def rhs(a,b): + def rhs(a, b): return a + b pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) - assert pack_x.shape == (15,5,4) + assert pack_x.shape == (15, 5, 4) - pack_y = pack_for_parameter_study(actx, ParamStudy2, (5,), y0,y1, y2,y3,y4) - assert pack_y.shape == (15,5,5) + pack_y = pack_for_parameter_study(actx, ParamStudy2, (5,), y0, y1, y2, y3, y4) + assert pack_y.shape == (15, 5, 5) compiled_rhs = actx.compile(rhs) output = compiled_rhs(pack_x, pack_y) - assert output.shape(15,5,4,5) + assert output.shape(15, 5, 4, 5) output_x = unpack_parameter_study(output, ParamStudy1) assert len(output_x) == 1 # Only 1 study associated with this variable. - assert len(output_x[0]) == 4 # 4 inputs for the parameter study. + assert len(output_x[0]) == 4 # 4 inputs for the parameter study. for i in range(len(output_x[0])): assert output_x[0][i].shape == (5, 15, 5) - output_y = unpack_parameter_study(output, ParamStudy2) assert len(output_y) == 1 # Only 1 study associated with this variable. - assert len(output_y[0]) == 5 # 5 inputs for the parameter study. + assert len(output_y[0]) == 5 # 5 inputs for the parameter study. for i in range(len(output_y[0])): assert output_y[0][i].shape == (4, 15, 5) -# }}} +def test_one_time_step_advection(actx_factory): + + actx = actx_factory() + + if not isinstance(actx, ParamStudyPytatoPyOpenCLArrayContext): + pytest.skip("only parameter study array contexts are supported") + + import numpy as np + seed = 12345 + rng = np.random.default_rng(seed) + base_shape = np.prod((15, 5)) + x0 = actx.from_numpy(rng.random(base_shape)) + x1 = actx.from_numpy(rng.random(base_shape)) + x2 = actx.from_numpy(rng.random(base_shape)) + x3 = actx.from_numpy(rng.random(base_shape)) + + speed_shape = (1,) + y0 = actx.from_numpy(rng.random(speed_shape)) + y1 = actx.from_numpy(rng.random(speed_shape)) + y2 = actx.from_numpy(rng.random(speed_shape)) + y3 = actx.from_numpy(rng.random(speed_shape)) + + ht = 0.0001 + hx = 0.005 + inds = actx.arange(base_shape) + kp1 = actx.roll(inds, -1) + km1 = actx.roll(inds, 1) + + def rhs(fields, wave_speed): + # 2nd order in space finite difference + return fields + wave_speed * (-1) * (ht / (2 * hx)) * \ + (fields[kp1] - fields[km1]) + + pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) + assert pack_x.shape == (75, 4) + + pack_y = pack_for_parameter_study(actx, ParamStudy1, (4,), y0, y1, y2, y3) + assert pack_y.shape == (1, 4) + + compiled_rhs = actx.compile(rhs) + + output = compiled_rhs(pack_x, pack_y) + + assert output.shape(75, 4) + + output_x = unpack_parameter_study(output, ParamStudy1) + assert len(output_x) == 1 # Only 1 study associated with this variable. + assert len(output_x[0]) == 4 # 4 inputs for the parameter study. + + +# }}} if __name__ == "__main__": From aed5080b9d7272350c79c3b95f974a3a0647990f Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 30 Jul 2024 16:51:41 -0500 Subject: [PATCH 19/25] Update. --- arraycontext/__init__.py | 5 +- arraycontext/impl/pytato/__init__.py | 5 +- arraycontext/parameter_study/__init__.py | 68 ++++++++++++------------ arraycontext/pytest.py | 2 + examples/advection.py | 22 ++++---- examples/parameter_study.py | 4 +- test/test_pytato_parameter_study.py | 4 +- 7 files changed, 57 insertions(+), 53 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 4ac793c8..0ec3b148 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -81,6 +81,7 @@ from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext from .loopy import make_loopy_program +from .parameter_study import pack_for_parameter_study, unpack_parameter_study from .pytest import ( PytestArrayContextFactory, PytestPyOpenCLArrayContextFactory, @@ -88,7 +89,7 @@ pytest_generate_tests_for_pyopencl_array_context, ) from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag -from .parameter_study import pack_for_parameter_study, unpack_parameter_study + __all__ = ( "Array", @@ -132,6 +133,7 @@ "multimap_reduce_array_container", "multimapped_over_array_containers", "outer", + "pack_for_parameter_study", "pytest_generate_tests_for_array_contexts", "pytest_generate_tests_for_pyopencl_array_context", "rec_map_array_container", @@ -145,6 +147,7 @@ "thaw", "to_numpy", "unflatten", + "unpack_parameter_study", "with_array_context", "with_container_arithmetic" ) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 7142a4db..48c8489b 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -59,13 +59,12 @@ import numpy as np from pytools import memoize_method -from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag +from pytools.tag import Tag, ToTagSetConvertible, UniqueTag as UniqueTag, normalize_tags from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike from arraycontext.metadata import NameHint -from dataclasses import dataclass if TYPE_CHECKING: import pyopencl as cl @@ -703,8 +702,6 @@ def clone(self): # }}} - - # {{{ PytatoJAXArrayContext class PytatoJAXArrayContext(_BasePytatoArrayContext): diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index 6b3ea9d6..f03bdff1 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -1,5 +1,3 @@ -from future import __annotations__ - """ .. currentmodule:: arraycontext @@ -48,45 +46,51 @@ THE SOFTWARE. """ -import sys from typing import ( - TYPE_CHECKING, Any, Callable, - Dict, - List, Mapping, - Tuple, Type, ) import numpy as np import loopy as lp -from pytato.array import (Array, make_placeholder as make_placeholder, - make_dict_of_named_arrays) - -from pytato.transform.parameter_study import ParameterStudyAxisTag +from pytato.array import ( + Array, + AxesT, + ShapeType, + make_dict_of_named_arrays, + make_placeholder as make_placeholder, +) +from pytato.transform.parameter_study import ( + ExpansionMapper, + ParameterStudyAxisTag, +) from pytools.tag import Tag, UniqueTag as UniqueTag -from arraycontext.context import ArrayContext -from arraycontext.container import ArrayContainer, is_array_container_type +from arraycontext.container import ( + ArrayContainer as ArrayContainer, + is_array_container_type, +) from arraycontext.container.traversal import rec_keyed_map_array_container -from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext -from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller, - _to_input_for_compiled) +from arraycontext.context import ArrayContext +from arraycontext.impl.pytato import ( + PytatoPyOpenCLArrayContext, + _get_arg_id_to_arg_and_arg_id_to_descr, +) +from arraycontext.impl.pytato.compile import ( + LazilyPyOpenCLCompilingFunctionCaller, + LeafArrayDescriptor, + _ary_container_key_stringifier, + _to_input_for_compiled, +) -ArraysT = Tuple[Array, ...] -StudiesT = Tuple[ParameterStudyAxisTag, ...] +ArraysT = tuple[Array, ...] +StudiesT = tuple[ParameterStudyAxisTag, ...] ParamStudyTagT = Type[ParameterStudyAxisTag] -if TYPE_CHECKING: - import pyopencl as cl - import pytato as pytato - -if getattr(sys, "_BUILDING_SPHINX_DOCS", False): - import pyopencl as cl import logging @@ -168,7 +172,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: self.actx._compile_trace_callback(self.f, "post_trace", output_template) if (not (is_array_container_type(output_template.__class__) - or isinstance(output_template, pt.Array))): + or isinstance(output_template, Array))): # TODO: We could possibly just short-circuit this interface if the # returned type is a scalar. Not sure if it's worth it though. raise NotImplementedError( @@ -185,9 +189,7 @@ def _as_dict_of_named_arrays(keys, ary): rec_keyed_map_array_container(_as_dict_of_named_arrays, output_template) - input_shapes = {} - input_axes = {} - placeholder_name_to_parameter_studies: Dict[str, StudiesT] = {} + placeholder_name_to_parameter_studies: dict[str, StudiesT] = {} for key, val in arg_id_to_descr.items(): if isinstance(val, LeafArrayDescriptor): name = input_id_to_name_in_program[key] @@ -240,7 +242,7 @@ def _cut_to_single_instance_size(name, arg) -> Array: update_axes = (*update_axes, arg.axes[i],) newshape = (*newshape, arg.shape[i]) - update_tags: FrozenSet[Tag] = arg.tags + update_tags: frozenset[Tag] = arg.tags return make_placeholder(name, newshape, arg.dtype, axes=update_axes, tags=update_tags) @@ -282,7 +284,7 @@ def pack_for_parameter_study(actx: ArrayContext, study_name_tag_type: ParamStudyTagT, *args: Array) -> Array: """ - Args is a list of realized input data that needs to be packed + Args is a list of realized input data that needs to be packed for a parameter study or uncertainty quantification. We assume that each input data set has the same shape and @@ -301,7 +303,7 @@ def pack_for_parameter_study(actx: ArrayContext, def unpack_parameter_study(data: Array, study_name_tag_type: ParamStudyTagT) -> Mapping[int, - List[Array]]: + list[Array]]: """ Split the data array along the axes which vary according to a ParameterStudyAxisTag whose name tag is an instance study_name_tag_type. @@ -311,7 +313,7 @@ def unpack_parameter_study(data: Array, """ ndim: int = len(data.shape) - out: Dict[int, List[Array]] = {} + out: dict[int, list[Array]] = {} study_count = 0 for i in range(ndim): @@ -320,7 +322,7 @@ def unpack_parameter_study(data: Array, # Now we need to split this data. breakpoint() for j in range(data.shape[i]): - tmp: List[Any] = [slice(None)] * ndim + tmp: list[Any] = [slice(None)] * ndim tmp[i] = j the_slice = tuple(tmp) # Needs to be a tuple of slices not list of slices. diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index d3d719e5..c05d8d31 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -62,6 +62,7 @@ def __init__(self, device): @classmethod def is_available(cls) -> bool: + return False try: import pyopencl # noqa: F401 return True @@ -133,6 +134,7 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars( class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory): @classmethod def is_available(cls) -> bool: + return True try: import pyopencl # noqa: F401 import pytato # noqa: F401 diff --git a/examples/advection.py b/examples/advection.py index ec99f478..2f572169 100644 --- a/examples/advection.py +++ b/examples/advection.py @@ -18,6 +18,7 @@ queue = cl.CommandQueue(ctx) actx = ParamStudyPytatoPyOpenCLArrayContext(queue) + @dataclass(frozen=True) class ParamStudy1(ParameterStudyAxisTag): """ @@ -27,7 +28,6 @@ class ParamStudy1(ParameterStudyAxisTag): def test_one_time_step_advection(): - import numpy as np seed = 12345 rng = np.random.default_rng(seed) @@ -36,7 +36,6 @@ def test_one_time_step_advection(): x1 = actx.from_numpy(rng.random(base_shape)) x2 = actx.from_numpy(rng.random(base_shape)) x3 = actx.from_numpy(rng.random(base_shape)) - speed_shape = (1,) y0 = actx.from_numpy(rng.random(speed_shape)) @@ -44,38 +43,39 @@ def test_one_time_step_advection(): y2 = actx.from_numpy(rng.random(speed_shape)) y3 = actx.from_numpy(rng.random(speed_shape)) - ht = 0.0001 hx = 0.005 inds = np.arange(base_shape, dtype=int) - Kp1 = actx.from_numpy(np.roll(inds, -1)) - Km1 = actx.from_numpy(np.roll(inds, 1)) + kp1 = actx.from_numpy(np.roll(inds, -1)) + km1 = actx.from_numpy(np.roll(inds, 1)) def rhs(fields, wave_speed): # 2nd order in space finite difference return fields + wave_speed * (-1) * (ht / (2 * hx)) * \ - (fields[Kp1] - fields[Km1]) + (fields[kp1] - fields[km1]) pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) breakpoint() - assert pack_x.shape == (75,4) + assert pack_x.shape == (75, 4) - pack_y = pack_for_parameter_study(actx, ParamStudy1, (4,), y0,y1, y2,y3) + pack_y = pack_for_parameter_study(actx, ParamStudy1, (4,), y0, y1, y2, y3) breakpoint() - assert pack_y.shape == (1,4) + assert pack_y.shape == (1, 4) compiled_rhs = actx.compile(rhs) breakpoint() output = compiled_rhs(pack_x, pack_y) breakpoint() - assert output.shape(75,4) + assert output.shape(75, 4) output_x = unpack_parameter_study(output, ParamStudy1) assert len(output_x) == 1 # Only 1 study associated with this variable. - assert len(output_x[0]) == 4 # 4 inputs for the parameter study. + assert len(output_x[0]) == 4 # 4 inputs for the parameter study. print("All checks passed") # Call it. + + test_one_time_step_advection() diff --git a/examples/parameter_study.py b/examples/parameter_study.py index e4c8d766..f6aac821 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -5,10 +5,10 @@ import pyopencl as cl from arraycontext.parameter_study import ( + ParameterStudyAxisTag, + ParamStudyPytatoPyOpenCLArrayContext, pack_for_parameter_study, unpack_parameter_study, - ParamStudyPytatoPyOpenCLArrayContext, - ParameterStudyAxisTag, ) diff --git a/test/test_pytato_parameter_study.py b/test/test_pytato_parameter_study.py index 0761523c..5f890ebb 100644 --- a/test/test_pytato_parameter_study.py +++ b/test/test_pytato_parameter_study.py @@ -30,10 +30,10 @@ pytest_generate_tests_for_array_contexts, ) from arraycontext.parameter_study import ( - pack_for_parameter_study, - unpack_parameter_study, ParameterStudyAxisTag, ParamStudyPytatoPyOpenCLArrayContext, + pack_for_parameter_study, + unpack_parameter_study, ) from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory From 470a80c7952bbb50706f2f6820585c09ae487e92 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 30 Jul 2024 16:57:59 -0500 Subject: [PATCH 20/25] Fix pylint errors. --- arraycontext/parameter_study/__init__.py | 2 +- examples/advection.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index f03bdff1..61041852 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -77,12 +77,12 @@ from arraycontext.context import ArrayContext from arraycontext.impl.pytato import ( PytatoPyOpenCLArrayContext, - _get_arg_id_to_arg_and_arg_id_to_descr, ) from arraycontext.impl.pytato.compile import ( LazilyPyOpenCLCompilingFunctionCaller, LeafArrayDescriptor, _ary_container_key_stringifier, + _get_arg_id_to_arg_and_arg_id_to_descr, _to_input_for_compiled, ) diff --git a/examples/advection.py b/examples/advection.py index 2f572169..9feed6e9 100644 --- a/examples/advection.py +++ b/examples/advection.py @@ -5,12 +5,10 @@ import pyopencl as cl from arraycontext.parameter_study import ( - pack_for_parameter_study, - unpack_parameter_study, -) -from arraycontext.parameter_study.transform import ( ParameterStudyAxisTag, ParamStudyPytatoPyOpenCLArrayContext, + pack_for_parameter_study, + unpack_parameter_study, ) From 65a9e59879d57c515ddd5ca000c5bcfe3dd4d1b1 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 8 Aug 2024 09:28:54 -0500 Subject: [PATCH 21/25] Add in asserts to confirm that multiple single instance programs in sequence result in the same numerical values as the one multiple instance execution. --- examples/advection.py | 35 +++++++------------ examples/parameter_study.py | 69 +++++++++++++++++++++++++------------ 2 files changed, 60 insertions(+), 44 deletions(-) diff --git a/examples/advection.py b/examples/advection.py index 9feed6e9..5d8f0496 100644 --- a/examples/advection.py +++ b/examples/advection.py @@ -31,15 +31,7 @@ def test_one_time_step_advection(): base_shape = np.prod((15, 5)) x0 = actx.from_numpy(rng.random(base_shape)) - x1 = actx.from_numpy(rng.random(base_shape)) - x2 = actx.from_numpy(rng.random(base_shape)) - x3 = actx.from_numpy(rng.random(base_shape)) - speed_shape = (1,) - y0 = actx.from_numpy(rng.random(speed_shape)) - y1 = actx.from_numpy(rng.random(speed_shape)) - y2 = actx.from_numpy(rng.random(speed_shape)) - y3 = actx.from_numpy(rng.random(speed_shape)) ht = 0.0001 hx = 0.005 @@ -52,24 +44,23 @@ def rhs(fields, wave_speed): return fields + wave_speed * (-1) * (ht / (2 * hx)) * \ (fields[kp1] - fields[km1]) - pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3) - breakpoint() - assert pack_x.shape == (75, 4) - - pack_y = pack_for_parameter_study(actx, ParamStudy1, (4,), y0, y1, y2, y3) - breakpoint() - assert pack_y.shape == (1, 4) + wave_speeds = [actx.from_numpy(np.random.random(1)) for _ in range(255)] + print(type(wave_speeds[0])) + packed_speeds = pack_for_parameter_study(actx, ParamStudy1, *wave_speeds) compiled_rhs = actx.compile(rhs) - breakpoint() - output = compiled_rhs(pack_x, pack_y) - breakpoint() - assert output.shape(75, 4) + output = compiled_rhs(x0, packed_speeds) + output = actx.freeze(output) + + expanded_output = actx.to_numpy(output).T + + # Now for all the single values. + for idx in range(len(wave_speeds)): + out = compiled_rhs(x0, wave_speeds[idx]) + out = actx.freeze(out) + assert np.allclose(expanded_output[idx], actx.to_numpy(out)) - output_x = unpack_parameter_study(output, ParamStudy1) - assert len(output_x) == 1 # Only 1 study associated with this variable. - assert len(output_x[0]) == 4 # 4 inputs for the parameter study. print("All checks passed") diff --git a/examples/parameter_study.py b/examples/parameter_study.py index f6aac821..f9fa8ecf 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -30,14 +30,11 @@ y3 = actx.from_numpy(rng.random(base_shape)) -# Eq: z = x + y +# Eq: z = x @ y.T # Assumptions: x and y are undergoing independent parameter studies. +# x and y are matrices such that x @ y.T works in the single instance case. def rhs(param1, param2): - import pytato as pt - return pt.matmul(param1, param2.T) - return pt.stack([param1[0], param2[10]], axis=0) - return param1[0] + param2[10] - + return param1 @ param2.T @dataclass(frozen=True) class ParameterStudyForX(ParameterStudyAxisTag): @@ -50,9 +47,8 @@ class ParameterStudyForY(ParameterStudyAxisTag): # Pack a parameter study of 3 instances for x and and 4 instances for y. - -packx = pack_for_parameter_study(actx, ParameterStudyForX, (3,), x, x1, x2) -packy = pack_for_parameter_study(actx, ParameterStudyForY, (4,), y, y1, y2, y3) +packx = pack_for_parameter_study(actx, ParameterStudyForX, x, x1, x2) +packy = pack_for_parameter_study(actx, ParameterStudyForY, y, y1, y2, y3) compiled_rhs = actx.compile(rhs) # Build the function caller @@ -60,16 +56,45 @@ class ParameterStudyForY(ParameterStudyAxisTag): # then converts it to a program which takes our multiple instances of `x` and `y`. output = compiled_rhs(packx, packy) output_2 = compiled_rhs(x, y) -breakpoint() - -assert output.shape == (15, 5, 3, 4) # Distinct parameter studies. - -output_x = unpack_parameter_study(output, ParameterStudyForX) -output_y = unpack_parameter_study(output, ParameterStudyForY) -assert len(output_x) == 1 # Number of parameter studies involving "x" -assert len(output_x[0]) == 3 # Number of inputs in the 0th parameter study -# All outputs across every other parameter study. -assert output_x[0][0].shape == (15, 5, 4) -assert len(output_y) == 1 -assert len(output_y[0]) == 4 -assert output_y[0][0].shape == (15, 5, 3) + +numpy_output = actx.to_numpy(output) + +assert numpy_output.shape == (15, 15, 3, 4) + +out = actx.to_numpy(compiled_rhs(x, y)) +assert np.allclose(numpy_output[..., 0, 0], out) + +out = actx.to_numpy(compiled_rhs(x, y1)) +assert np.allclose(numpy_output[..., 0, 1], out) + +out = actx.to_numpy(compiled_rhs(x, y2)) +assert np.allclose(numpy_output[..., 0, 2], out) + +out = actx.to_numpy(compiled_rhs(x, y3)) +assert np.allclose(numpy_output[..., 0, 3], out) + +out = actx.to_numpy(compiled_rhs(x1, y)) +assert np.allclose(numpy_output[..., 1, 0], out) + +out = actx.to_numpy(compiled_rhs(x1, y1)) +assert np.allclose(numpy_output[..., 1, 1], out) + +out = actx.to_numpy(compiled_rhs(x1, y2)) +assert np.allclose(numpy_output[..., 1, 2], out) + +out = actx.to_numpy(compiled_rhs(x1, y3)) +assert np.allclose(numpy_output[..., 1, 3], out) + +out = actx.to_numpy(compiled_rhs(x2, y)) +assert np.allclose(numpy_output[..., 2, 0], out) + +out = actx.to_numpy(compiled_rhs(x2, y1)) +assert np.allclose(numpy_output[..., 2, 1], out) + +out = actx.to_numpy(compiled_rhs(x2, y2)) +assert np.allclose(numpy_output[..., 2, 2], out) + +out = actx.to_numpy(compiled_rhs(x2, y3)) +assert np.allclose(numpy_output[..., 2, 3], out) + +print("All tests passed!") From c0a5fc94eb70c4776c70d4fc2c1ab74e12c9a56e Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 15 Aug 2024 14:58:02 -0500 Subject: [PATCH 22/25] Implement packing for array containers. --- arraycontext/parameter_study/__init__.py | 59 ++++++++++++++++++------ 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index 61041852..af1b3f3d 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -64,17 +64,27 @@ make_placeholder as make_placeholder, ) from pytato.transform.parameter_study import ( - ExpansionMapper, ParameterStudyAxisTag, + ParameterStudyVectorizer, ) from pytools.tag import Tag, UniqueTag as UniqueTag +from arraycontext import ( + get_container_context_recursively as get_container_context_recursively, + rec_map_array_container as rec_map_array_container, + rec_multimap_array_container, +) from arraycontext.container import ( ArrayContainer as ArrayContainer, + deserialize_container as deserialize_container, is_array_container_type, + serialize_container as serialize_container, ) from arraycontext.container.traversal import rec_keyed_map_array_container -from arraycontext.context import ArrayContext +from arraycontext.context import ( + ArrayContext, + ArrayOrContainerT, +) from arraycontext.impl.pytato import ( PytatoPyOpenCLArrayContext, ) @@ -154,6 +164,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: # On a cache hit we do not need to modify anything. return compiled_f(arg_id_to_arg) + with open("calls_to_compile", "a+") as my_file: + my_file.write(str(arg_id_to_descr)) + my_file.write("\n") + dict_of_named_arrays = {} output_id_to_name_in_program = {} input_id_to_name_in_program = { @@ -163,6 +177,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: placeholder_args = [_get_f_placeholder_args_for_param_study(arg, iarg, input_id_to_name_in_program, self.actx) for iarg, arg in enumerate(args)] + breakpoint() output_template = self.f(*placeholder_args, **{kw: _get_f_placeholder_args_for_param_study(arg, kw, input_id_to_name_in_program, @@ -202,21 +217,18 @@ def _as_dict_of_named_arrays(keys, ary): else: placeholder_name_to_parameter_studies[name] = tags - breakpoint() - expand_map = ExpansionMapper(placeholder_name_to_parameter_studies) + vectorize = ParameterStudyVectorizer(placeholder_name_to_parameter_studies) # Get the dependencies - sing_inst_outs = make_dict_of_named_arrays(dict_of_named_arrays) # Use the normal compiler now. - compiled_func = self._dag_to_compiled_func(expand_map(sing_inst_outs), + compiled_func = self._dag_to_compiled_func(vectorize(sing_inst_outs), # pt_dict_of_named_arrays, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, output_template=output_template) - breakpoint() self.program_cache[arg_id_to_descr] = compiled_func return compiled_func(arg_id_to_arg) @@ -282,7 +294,7 @@ def _rec_to_placeholder(keys, ary): def pack_for_parameter_study(actx: ArrayContext, study_name_tag_type: ParamStudyTagT, - *args: Array) -> Array: + *args: ArrayOrContainerT) -> ArrayOrContainerT: """ Args is a list of realized input data that needs to be packed for a parameter study or uncertainty quantification. @@ -293,12 +305,32 @@ def pack_for_parameter_study(actx: ArrayContext, assert len(args) > 0 - orig_shape = args[0].shape - out = actx.np.stack(args, axis=len(args[0].shape)) + def recursive_stack(*args: Array) -> Array: + assert len(args) > 0 - for i in range(len(orig_shape), len(out.shape)): - out = out.with_tagged_axis(i, [study_name_tag_type(len(args))]) - return out + for val in args: + assert not is_array_container_type(type(val)) + assert isinstance(val, Array) + + orig_shape = args[0].shape + out = actx.np.stack(args, axis=len(orig_shape)) + out = out.with_tagged_axis(len(orig_shape), [study_name_tag_type(len(args))]) + + # We have added a new axis. + assert len(orig_shape) + 1 == len(out.shape) + # Assert that it has been tagged. + assert out.axes[-1].tags_of_type(study_name_tag_type) + + return out + + if is_array_container_type(type(args[0])): + # Need to deal with this as a container. + # assert isinstance(get_container_context_recursively(args[0]), type(actx)) + # assert isinstance(actx, get_container_context_recursively(args[0])) + + return rec_multimap_array_container(recursive_stack, *args) + + return recursive_stack(*args) def unpack_parameter_study(data: Array, @@ -320,7 +352,6 @@ def unpack_parameter_study(data: Array, axis_tags = data.axes[i].tags_of_type(study_name_tag_type) if axis_tags: # Now we need to split this data. - breakpoint() for j in range(data.shape[i]): tmp: list[Any] = [slice(None)] * ndim tmp[i] = j From 28e7b7017fdc3499465c2bbab53b5e51c8f28143 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 15 Aug 2024 15:28:15 -0500 Subject: [PATCH 23/25] Add the requirement to use the correct pytato branch. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a4cb4025..96ba927f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy git+https://github.com/inducer/loopy.git#egg=loopy -git+https://github.com/inducer/pytato.git#egg=pytato +git+https://github.com/nkoskelo/pytato.git#egg=pytato@parameter_study From 9357fc9539f77303fed97f276a17d7dbb147a553 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 15 Aug 2024 15:34:42 -0500 Subject: [PATCH 24/25] Update the examples. --- examples/advection.py | 8 ++------ examples/parameter_study.py | 3 ++- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/advection.py b/examples/advection.py index 5d8f0496..937f91cc 100644 --- a/examples/advection.py +++ b/examples/advection.py @@ -8,7 +8,6 @@ ParameterStudyAxisTag, ParamStudyPytatoPyOpenCLArrayContext, pack_for_parameter_study, - unpack_parameter_study, ) @@ -32,7 +31,6 @@ def test_one_time_step_advection(): base_shape = np.prod((15, 5)) x0 = actx.from_numpy(rng.random(base_shape)) - ht = 0.0001 hx = 0.005 inds = np.arange(base_shape, dtype=int) @@ -44,15 +42,14 @@ def rhs(fields, wave_speed): return fields + wave_speed * (-1) * (ht / (2 * hx)) * \ (fields[kp1] - fields[km1]) - wave_speeds = [actx.from_numpy(np.random.random(1)) for _ in range(255)] - print(type(wave_speeds[0])) + wave_speeds = [actx.from_numpy(rng.random(1)) for _ in range(255)] packed_speeds = pack_for_parameter_study(actx, ParamStudy1, *wave_speeds) compiled_rhs = actx.compile(rhs) output = compiled_rhs(x0, packed_speeds) output = actx.freeze(output) - + expanded_output = actx.to_numpy(output).T # Now for all the single values. @@ -61,7 +58,6 @@ def rhs(fields, wave_speed): out = actx.freeze(out) assert np.allclose(expanded_output[idx], actx.to_numpy(out)) - print("All checks passed") # Call it. diff --git a/examples/parameter_study.py b/examples/parameter_study.py index f9fa8ecf..5b5f5ad6 100644 --- a/examples/parameter_study.py +++ b/examples/parameter_study.py @@ -8,7 +8,6 @@ ParameterStudyAxisTag, ParamStudyPytatoPyOpenCLArrayContext, pack_for_parameter_study, - unpack_parameter_study, ) @@ -36,6 +35,7 @@ def rhs(param1, param2): return param1 @ param2.T + @dataclass(frozen=True) class ParameterStudyForX(ParameterStudyAxisTag): pass @@ -47,6 +47,7 @@ class ParameterStudyForY(ParameterStudyAxisTag): # Pack a parameter study of 3 instances for x and and 4 instances for y. + packx = pack_for_parameter_study(actx, ParameterStudyForX, x, x1, x2) packy = pack_for_parameter_study(actx, ParameterStudyForY, y, y1, y2, y3) From ccfc9bc08feb8010c3821449a26fcd3c0f12ded7 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 21 Aug 2024 18:26:35 -0500 Subject: [PATCH 25/25] Trying to get the unpack to work so that you only need to index in with the parameter study and then you get an Array or Container of the same type as if you had done the single instance program. --- arraycontext/parameter_study/__init__.py | 154 +++++++++++++---------- 1 file changed, 87 insertions(+), 67 deletions(-) diff --git a/arraycontext/parameter_study/__init__.py b/arraycontext/parameter_study/__init__.py index af1b3f3d..eace93d1 100644 --- a/arraycontext/parameter_study/__init__.py +++ b/arraycontext/parameter_study/__init__.py @@ -1,12 +1,9 @@ - -""" +__doc__ = """ .. currentmodule:: arraycontext -A :mod:`pytato`-based array context defers the evaluation of an array until its -frozen. The execution contexts for the evaluations are specific to an -:class:`~arraycontext.ArrayContext` type. For ex. -:class:`~arraycontext.ParamStudyPytatoPyOpenCLArrayContext` -uses :mod:`pyopencl` to JIT-compile and execute the array expressions. +A parameter study array context allows a user to pass packed input into his or her single instance program. +These array contexts are derived from the implementations present in :mod:`arraycontext.impl`. +Only :mod:`pytato`-based array contexts have been implemented so far. Following :mod:`pytato`-based array context are provided: @@ -49,7 +46,9 @@ from typing import ( Any, Callable, + Iterable, Mapping, + Optional, Type, ) @@ -59,6 +58,7 @@ from pytato.array import ( Array, AxesT, + AxisPermutation, ShapeType, make_dict_of_named_arrays, make_placeholder as make_placeholder, @@ -79,12 +79,16 @@ deserialize_container as deserialize_container, is_array_container_type, serialize_container as serialize_container, + NotAnArrayContainerError, ) -from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.context import ( ArrayContext, ArrayOrContainerT, ) +from arraycontext.container.traversal import ( + rec_keyed_map_array_container, + rec_map_reduce_array_container, +) from arraycontext.impl.pytato import ( PytatoPyOpenCLArrayContext, ) @@ -126,11 +130,6 @@ class ParamStudyPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: return ParamStudyLazyPyOpenCLFunctionCaller(self, f) - def transform_loopy_program(self, - t_unit: lp.TranslationUnit) -> lp.TranslationUnit: - # Update in a subclass if you want. - return t_unit - # }}} @@ -164,10 +163,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: # On a cache hit we do not need to modify anything. return compiled_f(arg_id_to_arg) - with open("calls_to_compile", "a+") as my_file: - my_file.write(str(arg_id_to_descr)) - my_file.write("\n") - dict_of_named_arrays = {} output_id_to_name_in_program = {} input_id_to_name_in_program = { @@ -233,31 +228,31 @@ def _as_dict_of_named_arrays(keys, ary): return compiled_func(arg_id_to_arg) -def _cut_to_single_instance_size(name, arg) -> Array: +def _cut_to_single_instance_size(name: str, arg: Array) -> Array: """ - Helper to split a place holder into the base instance shape - if it is tagged with a `ParameterStudyAxisTag` - to ensure the survival of the information those tags will be converted - to temporary Array Tags of the same type. The placeholder will not - have the axes marked with a `ParameterStudyAxisTag` tag. - - We need to cut the extra axes off because we cannot assume - that the operators we use to build the single instance program - will understand what to do with the extra axes. + Helper function to create a placeholder of the single instance size. + Axes that are removed are those which are marked with a + :class:`ParameterStudyAxisTag`. + + We need to cut the extra axes off, because we cannot assume that + the operators we use to build the single instance program will + understand what to do with the extra axes. We are doing it after the + call to _to_input_for_compiled in order to ensure that we have an + :class:`Array` for arg. Also this way we allow the metadata materializer + to work. See :function:`~arraycontext.impl.pytato._to_input_for_compiled` for more + information. """ ndim: int = len(arg.shape) - newshape: ShapeType = () - update_axes: AxesT = () + single_inst_shape: ShapeType = () + single_inst_axes: AxesT = () for i in range(ndim): axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag) if not axis_tags: - update_axes = (*update_axes, arg.axes[i],) - newshape = (*newshape, arg.shape[i]) - - update_tags: frozenset[Tag] = arg.tags + single_inst_axes = (*single_inst_axes, arg.axes[i],) + single_inst_shape = (*single_inst_shape, arg.shape[i]) - return make_placeholder(name, newshape, arg.dtype, axes=update_axes, - tags=update_tags) + return make_placeholder(name, single_inst_shape, arg.dtype, axes=single_inst_axes, + tags=arg.tags) def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx): @@ -278,6 +273,7 @@ def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx): elif isinstance(arg, Array): name = arg_id_to_name[(kw,)] # Transform the DAG to give metadata inference a chance to do its job + breakpoint() arg = _to_input_for_compiled(arg, actx) return _cut_to_single_instance_size(name, arg) elif is_array_container_type(arg.__class__): @@ -305,16 +301,21 @@ def pack_for_parameter_study(actx: ArrayContext, assert len(args) > 0 - def recursive_stack(*args: Array) -> Array: + def _recursive_stack(*args: Array) -> Array: assert len(args) > 0 + thawed_args: ArraysT = () for val in args: assert not is_array_container_type(type(val)) - assert isinstance(val, Array) + if not isinstance(val, Array): + thawed_args = (*thawed_args, actx.thaw(val),) + else: + thawed_args = (*thawed_args, val) - orig_shape = args[0].shape - out = actx.np.stack(args, axis=len(orig_shape)) - out = out.with_tagged_axis(len(orig_shape), [study_name_tag_type(len(args))]) + orig_shape = thawed_args[0].shape + out = actx.np.stack(thawed_args, axis=len(orig_shape)) + out = out.with_tagged_axis(len(orig_shape), + [study_name_tag_type(len(args))]) # We have added a new axis. assert len(orig_shape) + 1 == len(out.shape) @@ -328,40 +329,59 @@ def recursive_stack(*args: Array) -> Array: # assert isinstance(get_container_context_recursively(args[0]), type(actx)) # assert isinstance(actx, get_container_context_recursively(args[0])) - return rec_multimap_array_container(recursive_stack, *args) + return rec_multimap_array_container(_recursive_stack, *args) - return recursive_stack(*args) + return _recursive_stack(*args) -def unpack_parameter_study(data: Array, +from arraycontext.container import ( + serialize_container, + deserialize_container, +) +def unpack_parameter_study(data: ArrayOrContainerT, study_name_tag_type: ParamStudyTagT) -> Mapping[int, - list[Array]]: + ArrayOrContainerT]: """ - Split the data array along the axes which vary according to - a ParameterStudyAxisTag whose name tag is an instance study_name_tag_type. - - output[i] corresponds to the values associated with the ith parameter study that - uses the variable name :arg: `study_name_tag_type`. + Recurse through the data structure and split the data along the + axis which corresponds to the input tag name. """ - ndim: int = len(data.shape) - out: dict[int, list[Array]] = {} - study_count = 0 - for i in range(ndim): - axis_tags = data.axes[i].tags_of_type(study_name_tag_type) - if axis_tags: - # Now we need to split this data. - for j in range(data.shape[i]): - tmp: list[Any] = [slice(None)] * ndim - tmp[i] = j - the_slice = tuple(tmp) - # Needs to be a tuple of slices not list of slices. - if study_count in out.keys(): - out[study_count].append(data[the_slice]) - else: - out[study_count] = [data[the_slice]] - if study_count in out.keys(): + def _recursive_split_helper(data: Array) -> Mapping[int, Array]: + """ + Split the data array along the axes which vary according to a + ParameterStudyAxisTag whose name tag is an instance study_name_tag_type. + """ + + ndim: int = len(data.shape) + out: list[Array] = [] + + breakpoint() + study_count = 0 + for i in range(ndim): + axis_tags = data.axes[i].tags_of_type(study_name_tag_type) + if axis_tags: study_count += 1 + # Now we need to split this data. + for j in range(data.shape[i]): + tmp: list[slice | int] = [slice(None)] * ndim + tmp[i] = j + the_slice = tuple(tmp) + # Needs to be a tuple of slices not list of slices. + out.append(data[the_slice]) + + assert study_count == 1 + + return out + + def reduce_func(iterable): + breakpoint() + return deserialize_container(data, iterable) + + if is_array_container_type(data.__class__): + # We need to recurse through the system and emit out the indexed arrays. + breakpoint() + return rec_map_reduce_array_container(_recursive_split_helper, reduce_func, data) + return rec_map_array_container(_recursive_split_helper, data) - return out + return _recursive_split_helper(data)