From 9a8a7ba337ba974d41b16250a850b6f6667457f6 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 14 Mar 2023 12:40:44 -0500 Subject: [PATCH 01/14] move _ary_container_key_stringifier to utils.py --- arraycontext/impl/pytato/__init__.py | 4 ++-- arraycontext/impl/pytato/compile.py | 23 +---------------------- arraycontext/impl/pytato/utils.py | 26 ++++++++++++++++++++++++++ test/test_utils.py | 2 +- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index e3ce52a7..86bad821 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -443,8 +443,8 @@ def freeze(self, array): TaggableCLArray, to_tagged_cl_array, ) - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier from arraycontext.impl.pytato.utils import ( + _ary_container_key_stringifier, _normalize_pt_expr, get_cl_axes_from_pt_axes, ) @@ -779,7 +779,7 @@ def freeze(self, array): import pytato as pt from arraycontext.container.traversal import rec_keyed_map_array_container - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier array_as_dict: dict[str, jnp.ndarray | pt.Array] = {} key_to_frozen_subary: dict[str, jnp.ndarray] = {} diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e77c1091..dfcd4bfd 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -110,28 +110,6 @@ class LeafArrayDescriptor(AbstractInputDescriptor): # {{{ utilities -def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str: - """ - Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an - array-container's component's key. Goals of this routine: - - * No two different keys should have the same stringification - * Stringified key must a valid identifier according to :meth:`str.isidentifier` - * (informal) Shorter identifiers are preferred - """ - def _rec_str(key: Any) -> str: - if isinstance(key, str | int): - return str(key) - elif isinstance(key, tuple): - # t in '_actx_t': stands for tuple - return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" - else: - raise NotImplementedError("Key-stringication unimplemented for " - f"'{type(key).__name__}'.") - - return "_".join(_rec_str(key) for key in keys) - - def _get_arg_id_to_arg_and_arg_id_to_descr(args: tuple[Any, ...], kwargs: Mapping[str, Any] ) -> \ @@ -322,6 +300,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. The intermediary pytato DAG for *args* is memoized in *self*. """ + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( args, kwargs) diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 2457e297..f9970b4a 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -221,4 +221,30 @@ def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames: # }}} + +# {{{ compile/outline helpers + +def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str: + """ + Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an + array-container's component's key. Goals of this routine: + + * No two different keys should have the same stringification + * Stringified key must a valid identifier according to :meth:`str.isidentifier` + * (informal) Shorter identifiers are preferred + """ + def _rec_str(key: Any) -> str: + if isinstance(key, str | int): + return str(key) + elif isinstance(key, tuple): + # t in '_actx_t': stands for tuple + return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" + else: + raise NotImplementedError("Key-stringication unimplemented for " + f"'{type(key).__name__}'.") + + return "_".join(_rec_str(key) for key in keys) + +# }}} + # vim: foldmethod=marker diff --git a/test/test_utils.py b/test/test_utils.py index eeef7723..a44c6035 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -39,7 +39,7 @@ # {{{ test_pt_actx_key_stringification_uniqueness def test_pt_actx_key_stringification_uniqueness(): - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier assert (_ary_container_key_stringifier(((3, 2), 3)) != _ary_container_key_stringifier((3, (2, 3)))) From c45d2b7a271865c3d06e22ff9ae79a0b602b0c6f Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 14 Mar 2023 11:31:02 -0500 Subject: [PATCH 02/14] Add outlining pass to array expression --- arraycontext/context.py | 19 +++ arraycontext/impl/pytato/__init__.py | 14 +++ arraycontext/impl/pytato/outline.py | 180 +++++++++++++++++++++++++++ 3 files changed, 213 insertions(+) create mode 100644 arraycontext/impl/pytato/outline.py diff --git a/arraycontext/context.py b/arraycontext/context.py index 210a9b89..6343894e 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -328,6 +328,7 @@ class ArrayContext(ABC): .. automethod:: tag .. automethod:: tag_axis .. automethod:: compile + .. automethod:: outline """ array_types: tuple[type, ...] = () @@ -577,6 +578,24 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: """ return f + def outline(self, + f: Callable[..., Any], + name: str | None = None) -> Callable[..., Any]: + """ + Returns a drop-in-replacement for *f*. The behavior of the returned + callable is specific to the derived class. + + The reason for the existence of such a routine is mainly for + arraycontexts that allow a lazy mode of execution. In such + arraycontexts, the computations within *f* maybe staged to potentially + enable additional compiler transformations. See + :func:`pytato.trace_call` or :func:`jax.named_call` for examples. + + :arg f: the function executing the computation to be staged. + :return: a function with the same signature as *f*. + """ + return f + # undocumented for now @property @abstractmethod diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 86bad821..a07e901f 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -226,6 +226,20 @@ def get_target(self): # }}} + def outline(self, + f: Callable[..., Any], + name: str | None = None, + tags: frozenset[Tag] = frozenset() + ) -> Callable[..., Any]: + from pytato.tags import FunctionIdentifier + + from .outline import OutlinedCall + name = name or getattr(f, "__name__", None) + if name is not None: + tags = tags | {FunctionIdentifier(name)} + + return OutlinedCall(self, f, tags) + # }}} diff --git a/arraycontext/impl/pytato/outline.py b/arraycontext/impl/pytato/outline.py new file mode 100644 index 00000000..b8ac6567 --- /dev/null +++ b/arraycontext/impl/pytato/outline.py @@ -0,0 +1,180 @@ +from __future__ import annotations + + +__doc__ = """ +.. autoclass:: OutlinedCall +""" +__copyright__ = """ +Copyright (C) 2023-5 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import itertools +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from typing import Any + +import numpy as np +from immutables import Map + +import pytato as pt +from pytools.tag import Tag + +from arraycontext.container import is_array_container_type +from arraycontext.container.traversal import rec_keyed_map_array_container +from arraycontext.context import ArrayOrContainer +from arraycontext.impl.pytato import _BasePytatoArrayContext + + +def _get_arg_id_to_arg(args: tuple[Any, ...], + kwargs: Mapping[str, Any] + ) -> Map[tuple[Any, ...], Any]: + """ + Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id + to argument values. See + :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's + representation. + """ + arg_id_to_arg: dict[tuple[Any, ...], Any] = {} + + for kw, arg in itertools.chain(enumerate(args), + kwargs.items()): + if np.isscalar(arg): + # do not make scalars as placeholders since we inline them. + pass + elif is_array_container_type(arg.__class__): + def id_collector(keys, ary): + arg_id = (kw, *keys) # noqa: B023 + arg_id_to_arg[arg_id] = ary + return ary + + rec_keyed_map_array_container(id_collector, arg) + elif isinstance(arg, pt.Array): + arg_id = (kw,) + arg_id_to_arg[arg_id] = arg + else: + raise ValueError("Argument to a compiled operator should be" + " either a scalar, pt.Array or an array container. Got" + f" '{arg}'.") + + return Map(arg_id_to_arg) + + +def _get_placeholder_replacement(arg, kw, arg_id_to_name): + """ + Helper for :class:`OutlinedCall.__call__`. Returns the placeholder version + of an argument to :attr:`OutlinedCall.f`. + """ + if np.isscalar(arg): + return arg + elif isinstance(arg, pt.Array): + name = arg_id_to_name[kw,] + return pt.make_placeholder(name, arg.shape, arg.dtype) + elif is_array_container_type(arg.__class__): + def _rec_to_placeholder(keys, ary): + name = arg_id_to_name[(kw, *keys)] + return pt.make_placeholder(name, + ary.shape, + ary.dtype) + + return rec_keyed_map_array_container(_rec_to_placeholder, arg) + else: + raise NotImplementedError(type(arg)) + + +def _get_input_arg_id_str(arg_id: tuple[Any, ...]) -> str: + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier + return f"_actx_in_{_ary_container_key_stringifier(arg_id)}" + + +def _get_output_arg_id_str(arg_id: tuple[Any, ...]) -> str: + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier + return f"_actx_out_{_ary_container_key_stringifier(arg_id)}" + + +@dataclass(frozen=True) +class OutlinedCall: + actx: _BasePytatoArrayContext + f: Callable[..., Any] + tags: frozenset[Tag] + + def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer: + arg_id_to_arg = _get_arg_id_to_arg(args, kwargs) + input_id_to_name_in_function = {arg_id: _get_input_arg_id_str(arg_id) + for arg_id in arg_id_to_arg} + + pl_args = [_get_placeholder_replacement(arg, iarg, + input_id_to_name_in_function) + for iarg, arg in enumerate(args)] + pl_kwargs = {kw: _get_placeholder_replacement(arg, kw, + input_id_to_name_in_function) + for kw, arg in kwargs.items()} + + output = self.f(*pl_args, **pl_kwargs) + + if isinstance(output, pt.Array): + returns = {"_": output} + ret_type = pt.function.ReturnType.ARRAY + elif is_array_container_type(output.__class__): + returns = {} + + def _unpack_container(key, ary): + key = _get_output_arg_id_str(key) + returns[key] = ary + return ary + + rec_keyed_map_array_container(_unpack_container, output) + ret_type = pt.function.ReturnType.DICT_OF_ARRAYS + else: + raise NotImplementedError(type(output)) + + # pylint-disable-reason: pylint has a hard time with kw_only fields in + # dataclasses + + # pylint: disable=unexpected-keyword-arg + func_def = pt.function.FunctionDefinition( + parameters=frozenset(input_id_to_name_in_function.values()), + return_type=ret_type, + returns=Map(returns), + tags=self.tags, + ) + + call_parameters = {input_id_to_name_in_function[arg_id]: arg + for arg_id, arg in arg_id_to_arg.items()} + call_site_output = func_def(**call_parameters) + + if isinstance(output, pt.Array): + return call_site_output + elif is_array_container_type(output.__class__): + def _pack_into_container(key, ary): + key = _get_output_arg_id_str(key) + return call_site_output[key] + + call_site_output_as_container = rec_keyed_map_array_container( + _pack_into_container, + output) + return call_site_output_as_container + else: + raise NotImplementedError(type(output)) + + +# vim: foldmethod=marker From 5b6b137f965600683fbb5674eaabf3189eabcb4f Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 14 Mar 2023 15:32:18 -0500 Subject: [PATCH 03/14] adds an outlining example --- examples/how_to_outline.py | 88 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 examples/how_to_outline.py diff --git a/examples/how_to_outline.py b/examples/how_to_outline.py new file mode 100644 index 00000000..14fdf9bc --- /dev/null +++ b/examples/how_to_outline.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import dataclasses as dc + +import numpy as np + +import pytato as pt +from pytools.obj_array import make_obj_array + +from arraycontext import ( + Array, + PytatoJAXArrayContext as BasePytatoJAXArrayContext, + dataclass_array_container, + with_container_arithmetic, +) + + +Ncalls = 300 + + +class PytatoJAXArrayContext(BasePytatoJAXArrayContext): + def transform_dag(self, dag): + # Test 1: Test that the number of untransformed call sites are as + # expected + assert pt.analysis.get_num_call_sites(dag) == Ncalls + + dag = pt.tag_all_calls_to_be_inlined(dag) + print("[Pre-concatenation] Number of nodes =", + pt.analysis.get_num_nodes(pt.inline_calls(dag))) + dag = pt.concatenate_calls( + dag, + lambda cs: pt.tags.FunctionIdentifier("foo") in cs.call.function.tags + ) + + # Test 2: Test that only one call-sites is left post concatenation + assert pt.analysis.get_num_call_sites(dag) == 1 + + dag = pt.inline_calls(dag) + print("[Post-concatenation] Number of nodes =", + pt.analysis.get_num_nodes(dag)) + + return dag + + +actx = PytatoJAXArrayContext() + + +@with_container_arithmetic( + bcast_obj_array=True, + eq_comparison=False, + rel_comparison=False, +) +@dataclass_array_container +@dc.dataclass(frozen=True) +class State: + mass: Array | np.ndarray + vel: np.ndarray # np array of Arrays or numpy arrays + + +@actx.outline +def foo(x1, x2): + return (2*x1 + 3*x2 + x1**3 + x2**4 + + actx.np.minimum(2*x1, 4*x2) + + actx.np.maximum(7*x1, 8*x2)) + + +rng = np.random.default_rng(0) +Ndof = 10 +Ndim = 3 + +results = [] + +for _ in range(Ncalls): + Nel = rng.integers(low=4, high=17) + state1_np = State( + mass=rng.random((Nel, Ndof)), + vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), + ) + state2_np = State( + mass=rng.random((Nel, Ndof)), + vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), + ) + + state1 = actx.from_numpy(state1_np) + state2 = actx.from_numpy(state2_np) + results.append(foo(state1, state2)) + +actx.to_numpy(make_obj_array(results)) From 029fe482ca9f41e9d855db10f6e5a219b9d6f7a4 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 14 Jun 2024 16:48:47 -0500 Subject: [PATCH 04/14] cosmetic-ish refactor of OutlinedCall.__call__ includes minor change to handle array containers that contain scalars (e.g., a TracePair of constant diffusion coefs) --- arraycontext/impl/pytato/outline.py | 171 +++++++++++++++++----------- 1 file changed, 105 insertions(+), 66 deletions(-) diff --git a/arraycontext/impl/pytato/outline.py b/arraycontext/impl/pytato/outline.py index b8ac6567..b9afda46 100644 --- a/arraycontext/impl/pytato/outline.py +++ b/arraycontext/impl/pytato/outline.py @@ -63,8 +63,11 @@ def _get_arg_id_to_arg(args: tuple[Any, ...], pass elif is_array_container_type(arg.__class__): def id_collector(keys, ary): - arg_id = (kw, *keys) # noqa: B023 - arg_id_to_arg[arg_id] = ary + if np.isscalar(ary): + pass + else: + arg_id = (kw, *keys) # noqa: B023 + arg_id_to_arg[arg_id] = ary return ary rec_keyed_map_array_container(id_collector, arg) @@ -79,28 +82,6 @@ def id_collector(keys, ary): return Map(arg_id_to_arg) -def _get_placeholder_replacement(arg, kw, arg_id_to_name): - """ - Helper for :class:`OutlinedCall.__call__`. Returns the placeholder version - of an argument to :attr:`OutlinedCall.f`. - """ - if np.isscalar(arg): - return arg - elif isinstance(arg, pt.Array): - name = arg_id_to_name[kw,] - return pt.make_placeholder(name, arg.shape, arg.dtype) - elif is_array_container_type(arg.__class__): - def _rec_to_placeholder(keys, ary): - name = arg_id_to_name[(kw, *keys)] - return pt.make_placeholder(name, - ary.shape, - ary.dtype) - - return rec_keyed_map_array_container(_rec_to_placeholder, arg) - else: - raise NotImplementedError(type(arg)) - - def _get_input_arg_id_str(arg_id: tuple[Any, ...]) -> str: from arraycontext.impl.pytato.utils import _ary_container_key_stringifier return f"_actx_in_{_ary_container_key_stringifier(arg_id)}" @@ -111,6 +92,91 @@ def _get_output_arg_id_str(arg_id: tuple[Any, ...]) -> str: return f"_actx_out_{_ary_container_key_stringifier(arg_id)}" +def _get_arg_id_to_placeholder( + arg_id_to_arg: Mapping[tuple[Any, ...], Any], + ) -> Map[tuple[Any, ...], pt.Placeholder]: + """ + Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder` + for each argument in *arg_id_to_arg*. See + :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's + representation. + """ + return Map({ + arg_id: pt.make_placeholder( + _get_input_arg_id_str(arg_id), + arg.shape, + arg.dtype) + for arg_id, arg in arg_id_to_arg.items()}) + + +def _call_with_placeholders( + f: Callable[..., Any], + args: tuple[Any], + kwargs: Mapping[str, Any], + arg_id_to_placeholder: Mapping[tuple[Any, ...], pt.Placeholder]) -> Any: + """ + Construct placeholders analogous to *args* and *kwargs* and call *f*. + """ + def get_placeholder_replacement(arg, key): + if np.isscalar(arg): + return arg + elif isinstance(arg, pt.Array): + return arg_id_to_placeholder[key] + elif is_array_container_type(arg.__class__): + def _rec_to_placeholder(keys, ary): + return get_placeholder_replacement(ary, key + keys) + + return rec_keyed_map_array_container(_rec_to_placeholder, arg) + else: + raise NotImplementedError(type(arg)) + + pl_args = [get_placeholder_replacement(arg, (iarg,)) + for iarg, arg in enumerate(args)] + pl_kwargs = {kw: get_placeholder_replacement(arg, (kw,)) + for kw, arg in kwargs.items()} + + return f(*pl_args, **pl_kwargs) + + +def _unpack_output( + output: ArrayOrContainer) -> pt.Array | dict[str, pt.Array]: + """Unpack any array containers in *output*.""" + if isinstance(output, pt.Array): + return output + elif is_array_container_type(output.__class__): + unpacked_output = {} + + def _unpack_container(key, ary): + key = _get_output_arg_id_str(key) + unpacked_output[key] = ary + return ary + + rec_keyed_map_array_container(_unpack_container, output) + + return unpacked_output + else: + raise NotImplementedError(type(output)) + + +def _pack_output( + output_template: ArrayOrContainer, + unpacked_output: pt.Array | tuple[pt.Array, ...] | Mapping[str, pt.Array] + ) -> ArrayOrContainer: + """ + Pack *unpacked_output* into array containers according to *output_template*. + """ + if isinstance(output_template, pt.Array): + return unpacked_output + elif is_array_container_type(output_template.__class__): + def _pack_into_container(key, ary): + key = _get_output_arg_id_str(key) + return unpacked_output[key] + + return rec_keyed_map_array_container(_pack_into_container, output_template) + else: + raise NotImplementedError(type(output_template)) + + @dataclass(frozen=True) class OutlinedCall: actx: _BasePytatoArrayContext @@ -119,62 +185,35 @@ class OutlinedCall: def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer: arg_id_to_arg = _get_arg_id_to_arg(args, kwargs) - input_id_to_name_in_function = {arg_id: _get_input_arg_id_str(arg_id) - for arg_id in arg_id_to_arg} - - pl_args = [_get_placeholder_replacement(arg, iarg, - input_id_to_name_in_function) - for iarg, arg in enumerate(args)] - pl_kwargs = {kw: _get_placeholder_replacement(arg, kw, - input_id_to_name_in_function) - for kw, arg in kwargs.items()} - output = self.f(*pl_args, **pl_kwargs) + arg_id_to_placeholder = _get_arg_id_to_placeholder(arg_id_to_arg) - if isinstance(output, pt.Array): - returns = {"_": output} + output = _call_with_placeholders(self.f, args, kwargs, arg_id_to_placeholder) + unpacked_output = _unpack_output(output) + if isinstance(unpacked_output, pt.Array): + unpacked_output = {"_": unpacked_output} ret_type = pt.function.ReturnType.ARRAY - elif is_array_container_type(output.__class__): - returns = {} - - def _unpack_container(key, ary): - key = _get_output_arg_id_str(key) - returns[key] = ary - return ary - - rec_keyed_map_array_container(_unpack_container, output) - ret_type = pt.function.ReturnType.DICT_OF_ARRAYS else: - raise NotImplementedError(type(output)) + ret_type = pt.function.ReturnType.DICT_OF_ARRAYS + + call_bindings = { + placeholder.name: arg_id_to_arg[arg_id] + for arg_id, placeholder in arg_id_to_placeholder.items()} # pylint-disable-reason: pylint has a hard time with kw_only fields in # dataclasses # pylint: disable=unexpected-keyword-arg func_def = pt.function.FunctionDefinition( - parameters=frozenset(input_id_to_name_in_function.values()), + parameters=frozenset(call_bindings.keys()), return_type=ret_type, - returns=Map(returns), + returns=Map(unpacked_output), tags=self.tags, ) - call_parameters = {input_id_to_name_in_function[arg_id]: arg - for arg_id, arg in arg_id_to_arg.items()} - call_site_output = func_def(**call_parameters) - - if isinstance(output, pt.Array): - return call_site_output - elif is_array_container_type(output.__class__): - def _pack_into_container(key, ary): - key = _get_output_arg_id_str(key) - return call_site_output[key] - - call_site_output_as_container = rec_keyed_map_array_container( - _pack_into_container, - output) - return call_site_output_as_container - else: - raise NotImplementedError(type(output)) + call_site_output = func_def(**call_bindings) + + return _pack_output(output, call_site_output) # vim: foldmethod=marker From 67740a2fbb252d02c063b6dfe9da1dcd7a2e012d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 14 Jun 2024 16:57:44 -0500 Subject: [PATCH 05/14] check for non-argument placeholders in outlined function --- arraycontext/impl/pytato/outline.py | 36 +++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/arraycontext/impl/pytato/outline.py b/arraycontext/impl/pytato/outline.py index b9afda46..7f06a470 100644 --- a/arraycontext/impl/pytato/outline.py +++ b/arraycontext/impl/pytato/outline.py @@ -82,9 +82,12 @@ def id_collector(keys, ary): return Map(arg_id_to_arg) -def _get_input_arg_id_str(arg_id: tuple[Any, ...]) -> str: +def _get_input_arg_id_str( + arg_id: tuple[Any, ...], prefix: str | None = None) -> str: + if prefix is None: + prefix = "" from arraycontext.impl.pytato.utils import _ary_container_key_stringifier - return f"_actx_in_{_ary_container_key_stringifier(arg_id)}" + return f"_actx_{prefix}_in_{_ary_container_key_stringifier(arg_id)}" def _get_output_arg_id_str(arg_id: tuple[Any, ...]) -> str: @@ -94,7 +97,7 @@ def _get_output_arg_id_str(arg_id: tuple[Any, ...]) -> str: def _get_arg_id_to_placeholder( arg_id_to_arg: Mapping[tuple[Any, ...], Any], - ) -> Map[tuple[Any, ...], pt.Placeholder]: + prefix: str | None = None) -> Map[tuple[Any, ...], pt.Placeholder]: """ Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder` for each argument in *arg_id_to_arg*. See @@ -103,7 +106,7 @@ def _get_arg_id_to_placeholder( """ return Map({ arg_id: pt.make_placeholder( - _get_input_arg_id_str(arg_id), + _get_input_arg_id_str(arg_id, prefix=prefix), arg.shape, arg.dtype) for arg_id, arg in arg_id_to_arg.items()}) @@ -186,6 +189,31 @@ class OutlinedCall: def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer: arg_id_to_arg = _get_arg_id_to_arg(args, kwargs) + if __debug__: + # Add a prefix to the names to distinguish them from any existing + # placeholders + arg_id_to_prefixed_placeholder = _get_arg_id_to_placeholder( + arg_id_to_arg, prefix="outlined_call") + + prefixed_output = _call_with_placeholders( + self.f, args, kwargs, arg_id_to_prefixed_placeholder) + unpacked_prefixed_output = _unpack_output(prefixed_output) + if isinstance(unpacked_prefixed_output, pt.Array): + unpacked_prefixed_output = {"_": unpacked_prefixed_output} + + prefixed_placeholders = frozenset( + arg_id_to_prefixed_placeholder.values()) + + found_placeholders = frozenset({ + arg for arg in pt.transform.InputGatherer()( + pt.make_dict_of_named_arrays(unpacked_prefixed_output)) + if isinstance(arg, pt.Placeholder)}) + + extra_placeholders = found_placeholders - prefixed_placeholders + assert not extra_placeholders, \ + "Found non-argument placeholder " \ + f"'{next(iter(extra_placeholders)).name}' in outlined function." + arg_id_to_placeholder = _get_arg_id_to_placeholder(arg_id_to_arg) output = _call_with_placeholders(self.f, args, kwargs, arg_id_to_placeholder) From 914d575f0bd8abb8c1f9337cbbd127170cefa9ce Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 8 Mar 2024 16:55:03 -0600 Subject: [PATCH 06/14] drop unused function arguments --- arraycontext/impl/pytato/outline.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/arraycontext/impl/pytato/outline.py b/arraycontext/impl/pytato/outline.py index 7f06a470..3fbccef0 100644 --- a/arraycontext/impl/pytato/outline.py +++ b/arraycontext/impl/pytato/outline.py @@ -224,9 +224,15 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer: else: ret_type = pt.function.ReturnType.DICT_OF_ARRAYS + used_placeholders = frozenset({ + arg for arg in pt.transform.InputGatherer()( + pt.make_dict_of_named_arrays(unpacked_output)) + if isinstance(arg, pt.Placeholder)}) + call_bindings = { placeholder.name: arg_id_to_arg[arg_id] - for arg_id, placeholder in arg_id_to_placeholder.items()} + for arg_id, placeholder in arg_id_to_placeholder.items() + if placeholder in used_placeholders} # pylint-disable-reason: pylint has a hard time with kw_only fields in # dataclasses From 5793c2e7dc6d41ae4203031b7a6ee9f1b770794d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 17 Jan 2025 12:38:45 -0600 Subject: [PATCH 07/14] pass hashable instead of string as id to outline --- arraycontext/context.py | 5 +++-- arraycontext/impl/pytato/__init__.py | 11 ++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index 6343894e..1b5ab16e 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -167,7 +167,7 @@ """ from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping +from collections.abc import Callable, Hashable, Mapping from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload from warnings import warn @@ -580,7 +580,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: def outline(self, f: Callable[..., Any], - name: str | None = None) -> Callable[..., Any]: + *, + id: Hashable | None = None) -> Callable[..., Any]: """ Returns a drop-in-replacement for *f*. The behavior of the returned callable is specific to the derived class. diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index a07e901f..52f80e50 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -53,7 +53,7 @@ import abc import sys -from collections.abc import Callable +from collections.abc import Callable, Hashable from typing import TYPE_CHECKING, Any import numpy as np @@ -228,15 +228,16 @@ def get_target(self): def outline(self, f: Callable[..., Any], - name: str | None = None, + *, + id: Hashable | None = None, tags: frozenset[Tag] = frozenset() ) -> Callable[..., Any]: from pytato.tags import FunctionIdentifier from .outline import OutlinedCall - name = name or getattr(f, "__name__", None) - if name is not None: - tags = tags | {FunctionIdentifier(name)} + id = id or getattr(f, "__name__", None) + if id is not None: + tags = tags | {FunctionIdentifier(id)} return OutlinedCall(self, f, tags) From 8023fc8fc9b8d8dd2f3cf653b5047fb4979e458d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 14 Mar 2024 18:36:57 -0500 Subject: [PATCH 08/14] handle optional arguments that are passed as None explicitly --- arraycontext/impl/pytato/outline.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/outline.py b/arraycontext/impl/pytato/outline.py index 3fbccef0..3df20c7b 100644 --- a/arraycontext/impl/pytato/outline.py +++ b/arraycontext/impl/pytato/outline.py @@ -58,7 +58,9 @@ def _get_arg_id_to_arg(args: tuple[Any, ...], for kw, arg in itertools.chain(enumerate(args), kwargs.items()): - if np.isscalar(arg): + if arg is None: + pass + elif np.isscalar(arg): # do not make scalars as placeholders since we inline them. pass elif is_array_container_type(arg.__class__): @@ -121,7 +123,9 @@ def _call_with_placeholders( Construct placeholders analogous to *args* and *kwargs* and call *f*. """ def get_placeholder_replacement(arg, key): - if np.isscalar(arg): + if arg is None: + return None + elif np.isscalar(arg): return arg elif isinstance(arg, pt.Array): return arg_id_to_placeholder[key] From b3fa7d6e90e23a7065c60eba5137421900f7cc5e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 13 Jun 2024 16:20:22 -0500 Subject: [PATCH 09/14] don't tag NamedArray (they inherit tags from their corresponding _container entry) --- arraycontext/impl/pytato/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 52f80e50..a88f7ce3 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -689,7 +689,7 @@ def preprocess_arg(name, arg): # multiple placeholders with the same name that are not # also the same object are not allowed, and this would produce # a different Placeholder object of the same name. - if (not isinstance(ary, pt.Placeholder) + if (not isinstance(ary, pt.Placeholder | pt.NamedArray) and not ary.tags_of_type(NameHint)): ary = ary.tagged(NameHint(name)) From d68010ada24f351d279f0b37f3aab5639c22ef91 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 6 Jun 2024 13:30:57 -0500 Subject: [PATCH 10/14] change Map -> immutabledict in outlining --- arraycontext/impl/pytato/outline.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/arraycontext/impl/pytato/outline.py b/arraycontext/impl/pytato/outline.py index 3df20c7b..03a3ccdc 100644 --- a/arraycontext/impl/pytato/outline.py +++ b/arraycontext/impl/pytato/outline.py @@ -34,7 +34,7 @@ from typing import Any import numpy as np -from immutables import Map +from immutabledict import immutabledict import pytato as pt from pytools.tag import Tag @@ -47,7 +47,7 @@ def _get_arg_id_to_arg(args: tuple[Any, ...], kwargs: Mapping[str, Any] - ) -> Map[tuple[Any, ...], Any]: + ) -> immutabledict[tuple[Any, ...], Any]: """ Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id to argument values. See @@ -81,7 +81,7 @@ def id_collector(keys, ary): " either a scalar, pt.Array or an array container. Got" f" '{arg}'.") - return Map(arg_id_to_arg) + return immutabledict(arg_id_to_arg) def _get_input_arg_id_str( @@ -99,14 +99,14 @@ def _get_output_arg_id_str(arg_id: tuple[Any, ...]) -> str: def _get_arg_id_to_placeholder( arg_id_to_arg: Mapping[tuple[Any, ...], Any], - prefix: str | None = None) -> Map[tuple[Any, ...], pt.Placeholder]: + prefix: str | None = None) -> immutabledict[tuple[Any, ...], pt.Placeholder]: """ Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder` for each argument in *arg_id_to_arg*. See :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's representation. """ - return Map({ + return immutabledict({ arg_id: pt.make_placeholder( _get_input_arg_id_str(arg_id, prefix=prefix), arg.shape, @@ -245,7 +245,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer: func_def = pt.function.FunctionDefinition( parameters=frozenset(call_bindings.keys()), return_type=ret_type, - returns=Map(unpacked_output), + returns=immutabledict(unpacked_output), tags=self.tags, ) From 6197a7afab4146a61d04474315ac015e04f2c06c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 14 Jun 2024 12:16:15 -0500 Subject: [PATCH 11/14] add FIXME --- arraycontext/context.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/arraycontext/context.py b/arraycontext/context.py index 1b5ab16e..631c2871 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -578,6 +578,15 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: """ return f + # FIXME: Think about making this a standalone function? Would make it easier to + # pass arguments when used as a decorator, e.g.: + # @outline(actx, id=...) + # def func(...): + # vs. + # outline = partial(actx.outline, id=...) + # + # @outline + # def func(...): def outline(self, f: Callable[..., Any], *, From 8e670306fabe276f10da84b9613da64f491adddb Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 9 Sep 2024 21:17:54 -0500 Subject: [PATCH 12/14] forbid calling _normalize_pt_expr on a DAG with function calls for now --- arraycontext/impl/pytato/utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index f9970b4a..919a16fa 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -35,6 +35,7 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any, cast +from pytato.analysis import get_num_call_sites from pytato.array import ( AbstractResultWithNamedArrays, Array, @@ -45,6 +46,7 @@ SizeParam, make_placeholder, ) +from pytato.function import FunctionDefinition from pytato.target.loopy import LoopyPyOpenCLTarget from pytato.transform import ArrayOrNames, CopyMapper from pytools import UniqueNameGenerator, memoize_method @@ -95,7 +97,14 @@ def map_placeholder(self, expr: Placeholder) -> Array: raise ValueError("Placeholders cannot appear in" " DatawrapperToBoundPlaceholderMapper.") + def map_function_definition( + self, expr: FunctionDefinition) -> FunctionDefinition: + raise ValueError("Function definitions cannot appear in" + " DatawrapperToBoundPlaceholderMapper.") + +# FIXME: This strategy doesn't work if the DAG has functions, since function +# definitions can't contain non-argument placeholders def _normalize_pt_expr( expr: DictOfNamedArrays ) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]: @@ -108,6 +117,11 @@ def _normalize_pt_expr( Deterministic naming of placeholders permits more effective caching of equivalent graphs. """ + if get_num_call_sites(expr): + raise NotImplementedError( + "_normalize_pt_expr is not compatible with expressions that " + "contain function calls.") + normalize_mapper = _DatawrapperToBoundPlaceholderMapper() normalized_expr = normalize_mapper(expr) assert isinstance(normalized_expr, AbstractResultWithNamedArrays) From 6d7674b66d7a782830ba9ae06db4447a6285b370 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 24 Oct 2024 10:02:59 -0500 Subject: [PATCH 13/14] inline calls in freeze before _normalize_pt_expr --- arraycontext/impl/pytato/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index a88f7ce3..fb25cbdb 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -522,6 +522,12 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray: pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( key_to_pt_arrays) + + # FIXME: Remove this if/when _normalize_pt_expr gets support for functions + pt_dict_of_named_arrays = pt.tag_all_calls_to_be_inlined( + pt_dict_of_named_arrays) + pt_dict_of_named_arrays = pt.inline_calls(pt_dict_of_named_arrays) + normalized_expr, bound_arguments = _normalize_pt_expr( pt_dict_of_named_arrays) From 92e51f2ad1644b2d6c733a8e439f7c380de50955 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 17 Jan 2025 16:59:42 -0600 Subject: [PATCH 14/14] disable concatenation in outlining example for now --- examples/how_to_outline.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/examples/how_to_outline.py b/examples/how_to_outline.py index 14fdf9bc..bb04968d 100644 --- a/examples/how_to_outline.py +++ b/examples/how_to_outline.py @@ -25,19 +25,21 @@ def transform_dag(self, dag): assert pt.analysis.get_num_call_sites(dag) == Ncalls dag = pt.tag_all_calls_to_be_inlined(dag) - print("[Pre-concatenation] Number of nodes =", - pt.analysis.get_num_nodes(pt.inline_calls(dag))) - dag = pt.concatenate_calls( - dag, - lambda cs: pt.tags.FunctionIdentifier("foo") in cs.call.function.tags - ) - - # Test 2: Test that only one call-sites is left post concatenation - assert pt.analysis.get_num_call_sites(dag) == 1 - + # FIXME: Re-enable this when concatenation is added to pytato + # print("[Pre-concatenation] Number of nodes =", + # pt.analysis.get_num_nodes(pt.inline_calls(dag))) + # dag = pt.concatenate_calls( + # dag, + # lambda cs: pt.tags.FunctionIdentifier("foo") in cs.call.function.tags + # ) + # + # # Test 2: Test that only one call-sites is left post concatenation + # assert pt.analysis.get_num_call_sites(dag) == 1 + # + # dag = pt.inline_calls(dag) + # print("[Post-concatenation] Number of nodes =", + # pt.analysis.get_num_nodes(dag)) dag = pt.inline_calls(dag) - print("[Post-concatenation] Number of nodes =", - pt.analysis.get_num_nodes(dag)) return dag