diff --git a/arraycontext/context.py b/arraycontext/context.py index 210a9b89..631c2871 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -167,7 +167,7 @@ """ from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping +from collections.abc import Callable, Hashable, Mapping from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload from warnings import warn @@ -328,6 +328,7 @@ class ArrayContext(ABC): .. automethod:: tag .. automethod:: tag_axis .. automethod:: compile + .. automethod:: outline """ array_types: tuple[type, ...] = () @@ -577,6 +578,34 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: """ return f + # FIXME: Think about making this a standalone function? Would make it easier to + # pass arguments when used as a decorator, e.g.: + # @outline(actx, id=...) + # def func(...): + # vs. + # outline = partial(actx.outline, id=...) + # + # @outline + # def func(...): + def outline(self, + f: Callable[..., Any], + *, + id: Hashable | None = None) -> Callable[..., Any]: + """ + Returns a drop-in-replacement for *f*. The behavior of the returned + callable is specific to the derived class. + + The reason for the existence of such a routine is mainly for + arraycontexts that allow a lazy mode of execution. In such + arraycontexts, the computations within *f* may be staged to potentially + enable additional compiler transformations. See + :func:`pytato.trace_call` or :func:`jax.named_call` for examples. + + :arg f: the function executing the computation to be staged. + :return: a function with the same signature as *f*.
+ """ + return f + # undocumented for now @property @abstractmethod diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index e3ce52a7..fb25cbdb 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -53,7 +53,7 @@ import abc import sys -from collections.abc import Callable +from collections.abc import Callable, Hashable from typing import TYPE_CHECKING, Any import numpy as np @@ -226,6 +226,21 @@ def get_target(self): # }}} + def outline(self, + f: Callable[..., Any], + *, + id: Hashable | None = None, + tags: frozenset[Tag] = frozenset() + ) -> Callable[..., Any]: + from pytato.tags import FunctionIdentifier + + from .outline import OutlinedCall + id = id or getattr(f, "__name__", None) + if id is not None: + tags = tags | {FunctionIdentifier(id)} + + return OutlinedCall(self, f, tags) + # }}} @@ -443,8 +458,8 @@ def freeze(self, array): TaggableCLArray, to_tagged_cl_array, ) - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier from arraycontext.impl.pytato.utils import ( + _ary_container_key_stringifier, _normalize_pt_expr, get_cl_axes_from_pt_axes, ) @@ -507,6 +522,12 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray: pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( key_to_pt_arrays) + + # FIXME: Remove this if/when _normalize_pt_expr gets support for functions + pt_dict_of_named_arrays = pt.tag_all_calls_to_be_inlined( + pt_dict_of_named_arrays) + pt_dict_of_named_arrays = pt.inline_calls(pt_dict_of_named_arrays) + normalized_expr, bound_arguments = _normalize_pt_expr( pt_dict_of_named_arrays) @@ -674,7 +695,7 @@ def preprocess_arg(name, arg): # multiple placeholders with the same name that are not # also the same object are not allowed, and this would produce # a different Placeholder object of the same name. 
- if (not isinstance(ary, pt.Placeholder) + if (not isinstance(ary, pt.Placeholder | pt.NamedArray) and not ary.tags_of_type(NameHint)): ary = ary.tagged(NameHint(name)) @@ -779,7 +800,7 @@ def freeze(self, array): import pytato as pt from arraycontext.container.traversal import rec_keyed_map_array_container - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier array_as_dict: dict[str, jnp.ndarray | pt.Array] = {} key_to_frozen_subary: dict[str, jnp.ndarray] = {} diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e77c1091..dfcd4bfd 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -110,28 +110,6 @@ class LeafArrayDescriptor(AbstractInputDescriptor): # {{{ utilities -def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str: - """ - Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an - array-container's component's key. Goals of this routine: - - * No two different keys should have the same stringification - * Stringified key must a valid identifier according to :meth:`str.isidentifier` - * (informal) Shorter identifiers are preferred - """ - def _rec_str(key: Any) -> str: - if isinstance(key, str | int): - return str(key) - elif isinstance(key, tuple): - # t in '_actx_t': stands for tuple - return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" - else: - raise NotImplementedError("Key-stringication unimplemented for " - f"'{type(key).__name__}'.") - - return "_".join(_rec_str(key) for key in keys) - - def _get_arg_id_to_arg_and_arg_id_to_descr(args: tuple[Any, ...], kwargs: Mapping[str, Any] ) -> \ @@ -322,6 +300,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. The intermediary pytato DAG for *args* is memoized in *self*. 
""" + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( args, kwargs) diff --git a/arraycontext/impl/pytato/outline.py b/arraycontext/impl/pytato/outline.py new file mode 100644 index 00000000..03a3ccdc --- /dev/null +++ b/arraycontext/impl/pytato/outline.py @@ -0,0 +1,257 @@ +from __future__ import annotations + + +__doc__ = """ +.. autoclass:: OutlinedCall +""" +__copyright__ = """ +Copyright (C) 2023-5 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import itertools +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from typing import Any + +import numpy as np +from immutabledict import immutabledict + +import pytato as pt +from pytools.tag import Tag + +from arraycontext.container import is_array_container_type +from arraycontext.container.traversal import rec_keyed_map_array_container +from arraycontext.context import ArrayOrContainer +from arraycontext.impl.pytato import _BasePytatoArrayContext + + +def _get_arg_id_to_arg(args: tuple[Any, ...], + kwargs: Mapping[str, Any] + ) -> immutabledict[tuple[Any, ...], Any]: + """ + Helper for :meth:`OutlinedCall.__call__`. Extracts mappings from argument id + to argument values. See + :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's + representation. + """ + arg_id_to_arg: dict[tuple[Any, ...], Any] = {} + + for kw, arg in itertools.chain(enumerate(args), + kwargs.items()): + if arg is None: + pass + elif np.isscalar(arg): + # do not make scalars as placeholders since we inline them. + pass + elif is_array_container_type(arg.__class__): + def id_collector(keys, ary): + if np.isscalar(ary): + pass + else: + arg_id = (kw, *keys) # noqa: B023 + arg_id_to_arg[arg_id] = ary + return ary + + rec_keyed_map_array_container(id_collector, arg) + elif isinstance(arg, pt.Array): + arg_id = (kw,) + arg_id_to_arg[arg_id] = arg + else: + raise ValueError("Argument to a compiled operator should be" + " either a scalar, pt.Array or an array container.
Got" + f" '{arg}'.") + + return immutabledict(arg_id_to_arg) + + +def _get_input_arg_id_str( + arg_id: tuple[Any, ...], prefix: str | None = None) -> str: + if prefix is None: + prefix = "" + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier + return f"_actx_{prefix}_in_{_ary_container_key_stringifier(arg_id)}" + + +def _get_output_arg_id_str(arg_id: tuple[Any, ...]) -> str: + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier + return f"_actx_out_{_ary_container_key_stringifier(arg_id)}" + + +def _get_arg_id_to_placeholder( + arg_id_to_arg: Mapping[tuple[Any, ...], Any], + prefix: str | None = None) -> immutabledict[tuple[Any, ...], pt.Placeholder]: + """ + Helper for :meth:`OutlinedCall.__call__`. Constructs a + :class:`pytato.Placeholder` for each argument in *arg_id_to_arg*. See + :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's + representation. + """ + return immutabledict({ + arg_id: pt.make_placeholder( + _get_input_arg_id_str(arg_id, prefix=prefix), + arg.shape, + arg.dtype) + for arg_id, arg in arg_id_to_arg.items()}) + + +def _call_with_placeholders( + f: Callable[..., Any], + args: tuple[Any], + kwargs: Mapping[str, Any], + arg_id_to_placeholder: Mapping[tuple[Any, ...], pt.Placeholder]) -> Any: + """ + Construct placeholders analogous to *args* and *kwargs* and call *f*.
+ """ + def get_placeholder_replacement(arg, key): + if arg is None: + return None + elif np.isscalar(arg): + return arg + elif isinstance(arg, pt.Array): + return arg_id_to_placeholder[key] + elif is_array_container_type(arg.__class__): + def _rec_to_placeholder(keys, ary): + return get_placeholder_replacement(ary, key + keys) + + return rec_keyed_map_array_container(_rec_to_placeholder, arg) + else: + raise NotImplementedError(type(arg)) + + pl_args = [get_placeholder_replacement(arg, (iarg,)) + for iarg, arg in enumerate(args)] + pl_kwargs = {kw: get_placeholder_replacement(arg, (kw,)) + for kw, arg in kwargs.items()} + + return f(*pl_args, **pl_kwargs) + + +def _unpack_output( + output: ArrayOrContainer) -> pt.Array | dict[str, pt.Array]: + """Unpack any array containers in *output*.""" + if isinstance(output, pt.Array): + return output + elif is_array_container_type(output.__class__): + unpacked_output = {} + + def _unpack_container(key, ary): + key = _get_output_arg_id_str(key) + unpacked_output[key] = ary + return ary + + rec_keyed_map_array_container(_unpack_container, output) + + return unpacked_output + else: + raise NotImplementedError(type(output)) + + +def _pack_output( + output_template: ArrayOrContainer, + unpacked_output: pt.Array | tuple[pt.Array, ...] | Mapping[str, pt.Array] + ) -> ArrayOrContainer: + """ + Pack *unpacked_output* into array containers according to *output_template*. 
+ """ + if isinstance(output_template, pt.Array): + return unpacked_output + elif is_array_container_type(output_template.__class__): + def _pack_into_container(key, ary): + key = _get_output_arg_id_str(key) + return unpacked_output[key] + + return rec_keyed_map_array_container(_pack_into_container, output_template) + else: + raise NotImplementedError(type(output_template)) + + +@dataclass(frozen=True) +class OutlinedCall: + actx: _BasePytatoArrayContext + f: Callable[..., Any] + tags: frozenset[Tag] + + def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer: + arg_id_to_arg = _get_arg_id_to_arg(args, kwargs) + + if __debug__: + # Add a prefix to the names to distinguish them from any existing + # placeholders + arg_id_to_prefixed_placeholder = _get_arg_id_to_placeholder( + arg_id_to_arg, prefix="outlined_call") + + prefixed_output = _call_with_placeholders( + self.f, args, kwargs, arg_id_to_prefixed_placeholder) + unpacked_prefixed_output = _unpack_output(prefixed_output) + if isinstance(unpacked_prefixed_output, pt.Array): + unpacked_prefixed_output = {"_": unpacked_prefixed_output} + + prefixed_placeholders = frozenset( + arg_id_to_prefixed_placeholder.values()) + + found_placeholders = frozenset({ + arg for arg in pt.transform.InputGatherer()( + pt.make_dict_of_named_arrays(unpacked_prefixed_output)) + if isinstance(arg, pt.Placeholder)}) + + extra_placeholders = found_placeholders - prefixed_placeholders + assert not extra_placeholders, \ + "Found non-argument placeholder " \ + f"'{next(iter(extra_placeholders)).name}' in outlined function." 
+ + arg_id_to_placeholder = _get_arg_id_to_placeholder(arg_id_to_arg) + + output = _call_with_placeholders(self.f, args, kwargs, arg_id_to_placeholder) + unpacked_output = _unpack_output(output) + if isinstance(unpacked_output, pt.Array): + unpacked_output = {"_": unpacked_output} + ret_type = pt.function.ReturnType.ARRAY + else: + ret_type = pt.function.ReturnType.DICT_OF_ARRAYS + + used_placeholders = frozenset({ + arg for arg in pt.transform.InputGatherer()( + pt.make_dict_of_named_arrays(unpacked_output)) + if isinstance(arg, pt.Placeholder)}) + + call_bindings = { + placeholder.name: arg_id_to_arg[arg_id] + for arg_id, placeholder in arg_id_to_placeholder.items() + if placeholder in used_placeholders} + + # pylint-disable-reason: pylint has a hard time with kw_only fields in + # dataclasses + + # pylint: disable=unexpected-keyword-arg + func_def = pt.function.FunctionDefinition( + parameters=frozenset(call_bindings.keys()), + return_type=ret_type, + returns=immutabledict(unpacked_output), + tags=self.tags, + ) + + call_site_output = func_def(**call_bindings) + + return _pack_output(output, call_site_output) + + +# vim: foldmethod=marker diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 2457e297..919a16fa 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -35,6 +35,7 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any, cast +from pytato.analysis import get_num_call_sites from pytato.array import ( AbstractResultWithNamedArrays, Array, @@ -45,6 +46,7 @@ SizeParam, make_placeholder, ) +from pytato.function import FunctionDefinition from pytato.target.loopy import LoopyPyOpenCLTarget from pytato.transform import ArrayOrNames, CopyMapper from pytools import UniqueNameGenerator, memoize_method @@ -95,7 +97,14 @@ def map_placeholder(self, expr: Placeholder) -> Array: raise ValueError("Placeholders cannot appear in" " DatawrapperToBoundPlaceholderMapper.") + def 
map_function_definition( + self, expr: FunctionDefinition) -> FunctionDefinition: + raise ValueError("Function definitions cannot appear in" + " DatawrapperToBoundPlaceholderMapper.") + +# FIXME: This strategy doesn't work if the DAG has functions, since function +# definitions can't contain non-argument placeholders def _normalize_pt_expr( expr: DictOfNamedArrays ) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]: @@ -108,6 +117,11 @@ def _normalize_pt_expr( Deterministic naming of placeholders permits more effective caching of equivalent graphs. """ + if get_num_call_sites(expr): + raise NotImplementedError( + "_normalize_pt_expr is not compatible with expressions that " + "contain function calls.") + normalize_mapper = _DatawrapperToBoundPlaceholderMapper() normalized_expr = normalize_mapper(expr) assert isinstance(normalized_expr, AbstractResultWithNamedArrays) @@ -221,4 +235,30 @@ def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames: # }}} + +# {{{ compile/outline helpers + +def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str: + """ + Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an + array-container's component's key. 
Goals of this routine: + + * No two different keys should have the same stringification + * Stringified key must be a valid identifier (see :meth:`str.isidentifier`) + * (informal) Shorter identifiers are preferred + """ + def _rec_str(key: Any) -> str: + if isinstance(key, str | int): + return str(key) + elif isinstance(key, tuple): + # t in '_actx_t': stands for tuple + return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" + else: + raise NotImplementedError("Key-stringication unimplemented for " + f"'{type(key).__name__}'.") + + return "_".join(_rec_str(key) for key in keys) + +# }}} + # vim: foldmethod=marker diff --git a/examples/how_to_outline.py b/examples/how_to_outline.py new file mode 100644 index 00000000..bb04968d --- /dev/null +++ b/examples/how_to_outline.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import dataclasses as dc + +import numpy as np + +import pytato as pt +from pytools.obj_array import make_obj_array + +from arraycontext import ( + Array, + PytatoJAXArrayContext as BasePytatoJAXArrayContext, + dataclass_array_container, + with_container_arithmetic, +) + + +Ncalls = 300 + + +class PytatoJAXArrayContext(BasePytatoJAXArrayContext): + def transform_dag(self, dag): + # Test 1: Test that the number of untransformed call sites are as + # expected + assert pt.analysis.get_num_call_sites(dag) == Ncalls + + dag = pt.tag_all_calls_to_be_inlined(dag) + # FIXME: Re-enable this when concatenation is added to pytato + # print("[Pre-concatenation] Number of nodes =", + # pt.analysis.get_num_nodes(pt.inline_calls(dag))) + # dag = pt.concatenate_calls( + # dag, + # lambda cs: pt.tags.FunctionIdentifier("foo") in cs.call.function.tags + # ) + # + # # Test 2: Test that only one call-sites is left post concatenation + # assert pt.analysis.get_num_call_sites(dag) == 1 + # + # dag = pt.inline_calls(dag) + # print("[Post-concatenation] Number of nodes =", + # pt.analysis.get_num_nodes(dag)) + dag = pt.inline_calls(dag) + + return
dag + + +actx = PytatoJAXArrayContext() + + +@with_container_arithmetic( + bcast_obj_array=True, + eq_comparison=False, + rel_comparison=False, +) +@dataclass_array_container +@dc.dataclass(frozen=True) +class State: + mass: Array | np.ndarray + vel: np.ndarray # np array of Arrays or numpy arrays + + +@actx.outline +def foo(x1, x2): + return (2*x1 + 3*x2 + x1**3 + x2**4 + + actx.np.minimum(2*x1, 4*x2) + + actx.np.maximum(7*x1, 8*x2)) + + +rng = np.random.default_rng(0) +Ndof = 10 +Ndim = 3 + +results = [] + +for _ in range(Ncalls): + Nel = rng.integers(low=4, high=17) + state1_np = State( + mass=rng.random((Nel, Ndof)), + vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), + ) + state2_np = State( + mass=rng.random((Nel, Ndof)), + vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), + ) + + state1 = actx.from_numpy(state1_np) + state2 = actx.from_numpy(state2_np) + results.append(foo(state1, state2)) + +actx.to_numpy(make_obj_array(results)) diff --git a/test/test_utils.py b/test/test_utils.py index eeef7723..a44c6035 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -39,7 +39,7 @@ # {{{ test_pt_actx_key_stringification_uniqueness def test_pt_actx_key_stringification_uniqueness(): - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier assert (_ary_container_key_stringifier(((3, 2), 3)) != _ary_container_key_stringifier((3, (2, 3))))