diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 674a229d..58b76a3f 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -46,6 +46,12 @@ serialize_container, ) from .container.arithmetic import ( + Bcast, + Bcast1Level, + Bcast2Levels, + Bcast3Levels, + BcastNLevels, + BcastUntilActxArray, with_container_arithmetic, ) from .container.dataclass import dataclass_array_container @@ -115,6 +121,12 @@ "ArrayOrContainerOrScalarT", "ArrayOrContainerT", "ArrayT", + "Bcast", + "Bcast1Level", + "Bcast2Levels", + "Bcast3Levels", + "BcastNLevels", + "BcastUntilActxArray", "CommonSubexpressionTag", "EagerJAXArrayContext", "ElementwiseMapKernelTag", diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 22572dc8..2e55bce5 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -6,6 +6,21 @@ .. currentmodule:: arraycontext .. autofunction:: with_container_arithmetic +.. autoclass:: Bcast +.. autoclass:: BcastNLevels +.. autoclass:: BcastUntilActxArray + +.. function:: Bcast1 + + Like :class:`BcastNLevels` with *nlevels* set to 1. + +.. function:: Bcast2 + + Like :class:`BcastNLevels` with *nlevels* set to 2. + +.. function:: Bcast3 + + Like :class:`BcastNLevels` with *nlevels* set to 3. """ @@ -34,12 +49,18 @@ """ import enum +from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Any, TypeVar +from dataclasses import FrozenInstanceError +from functools import partial +from numbers import Number +from typing import Any, ClassVar, TypeVar, Union from warnings import warn import numpy as np +from arraycontext.context import ArrayContext, ArrayOrContainer + # {{{ with_container_arithmetic @@ -142,8 +163,9 @@ def __instancecheck__(cls, instance: Any) -> bool: warn( "Broadcasting container against non-object numpy array. " "This was never documented to work and will now stop working in " - "2025. Convert the array to an object array to preserve the " - "current semantics.", DeprecationWarning, stacklevel=3) + "2025. Convert the array to an object array or use " + "variants of arraycontext.Bcast to obtain the desired " + "broadcasting semantics.", DeprecationWarning, stacklevel=3) return True else: return False @@ -153,6 +175,125 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet pass +class Bcast: + """ + A wrapper object to force arithmetic generated by :func:`with_container_arithmetic` + to broadcast *arg* across a container (with the container as the 'outer' structure). + Since array containers are often nested in complex ways, different subclasses + implement different rules on how broadcasting interacts with the hierarchy, + with :class:`BcastNLevels` and :class:`BcastUntilActxArray` representing + the most common. + """ + arg: ArrayOrContainer + + # Accessing this attribute is cheaper than isinstance, so use that + # to distinguish _BcastWithNextOperand and _BcastWithoutNextOperand. + _with_next_operand: ClassVar[bool] + + def __init__(self, arg: ArrayOrContainer) -> None: + object.__setattr__(self, "arg", arg) + + def __setattr__(self, name: str, value: Any) -> None: + raise FrozenInstanceError() + + def __delattr__(self, name: str) -> None: + raise FrozenInstanceError() + + +class _BcastWithNextOperand(Bcast, ABC): + """ + A :class:`Bcast` object that gets to see who the next operand will be, in + order to decide whether wrapping the child in :class:`Bcast` is still necessary. + This is much more flexible, but also considerably more expensive, than + :class:`_BcastWithoutNextOperand`. + """ + + _with_next_operand = True + + # purposefully undocumented + @abstractmethod + def _rewrap(self, other_operand: ArrayOrContainer) -> ArrayOrContainer: + ... + + +class _BcastWithoutNextOperand(Bcast, ABC): + """ + A :class:`Bcast` object that does not get to see who the next operand will be. + """ + _with_next_operand = False + + # purposefully undocumented + @abstractmethod + def _rewrap(self) -> ArrayOrContainer: + ... + + +class BcastNLevels(_BcastWithoutNextOperand): + """ + A broadcasting rule that lets *arg* broadcast against *nlevels* "levels" of + array containers. Use :func:`Bcast1`, :func:`Bcast2`, :func:`Bcast3` as + convenient aliases for the common cases. + + Usage example:: + + container + Bcast2(actx_array) + + .. note:: + + :mod:`numpy` object arrays do not count against the number of levels. + + .. automethod:: __init__ + """ + nlevels: int + + def __init__(self, nlevels: int, arg: ArrayOrContainer) -> None: + if nlevels < 1: + raise ValueError("nlevels is expected to be one or greater.") + + super().__init__(arg) + object.__setattr__(self, "nlevels", nlevels) + + def _rewrap(self) -> ArrayOrContainer: + if self.nlevels == 1: + return self.arg + else: + return BcastNLevels(self.nlevels-1, self.arg) + + +Bcast1Level = partial(BcastNLevels, 1) +Bcast2Levels = partial(BcastNLevels, 2) +Bcast3Levels = partial(BcastNLevels, 3) + + +class BcastUntilActxArray(_BcastWithNextOperand): + """ + A broadcast rule that broadcasts *arg* across array containers until + the 'opposite' operand is one of the :attr:`~arraycontext.ArrayContext.array_types` + of *actx*, or a :class:`~numbers.Number`. + + Suggested usage pattern:: + + bcast = functools.partial(BcastUntilActxArray, actx) + + container + bcast(actx_array) + + .. automethod:: __init__ + """ + actx: ArrayContext + + def __init__(self, + actx: ArrayContext, + arg: ArrayOrContainer) -> None: + super().__init__(arg) + object.__setattr__(self, "actx", actx) + + def _rewrap(self, other_operand: ArrayOrContainer) -> ArrayOrContainer: + if isinstance(other_operand, (*self.actx.array_types, Number)): + return self.arg + else: + return self + + def with_container_arithmetic( *, number_bcasts_across: bool | None = None, @@ -207,6 +348,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__` Each operator class also includes the "reverse" operators if applicable. + .. note:: + + For the generated binary arithmetic operators, if certain types + should be broadcast over the container (with the container as the + 'outer' structure) but are not handled in this way by their types, + you may wrap them in one of the :class:`Bcast` variants to achieve + the desired semantics. + .. note:: To generate the code implementing the operators, this function relies on @@ -239,6 +388,24 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): # # - Broadcast rules are hard to change once established, particularly # because one cannot grep for their use. + # + # Possible advantages of the "Bcast" broadcast-rule-as-object design: + # + # - If one rule does not fit the user's need, they can straightforwardly use + # another. + # + # - It's straightforward to find where certain broadcast rules are used. + # + # - The broadcast rule can contain more state. For example, it's now easy + # for the rule to know what array context should be used to determine + # actx array types. + # + # Possible downsides of the "Bcast" broadcast-rule-as-object design: + # + # - User code is a bit more wordy. + # + # - Rewrapping has the potential to be costly, especially in + # _with_next_operand mode. # {{{ handle inputs @@ -404,9 +571,8 @@ def wrap(cls: Any) -> Any: f"Broadcasting array context array types across {cls} " "has been explicitly " "enabled. As of 2025, this will stop working. " - "There is no replacement as of right now. " - "See the discussion in " - "https://github.com/inducer/arraycontext/pull/190. " + "Express these operations using arraycontext.Bcast variants " + "instead. " "To opt out now (and avoid this warning), " "pass _bcast_actx_array_type=False. ", DeprecationWarning, stacklevel=2) @@ -415,9 +581,8 @@ def wrap(cls: Any) -> Any: f"Broadcasting array context array types across {cls} " "has been implicitly " "enabled. As of 2025, this will no longer work. " - "There is no replacement as of right now. " - "See the discussion in " - "https://github.com/inducer/arraycontext/pull/190. " + "Express these operations using arraycontext.Bcast variants " + "instead. " "To opt out now (and avoid this warning), " "pass _bcast_actx_array_type=False.", DeprecationWarning, stacklevel=2) @@ -435,7 +600,7 @@ def wrap(cls: Any) -> Any: gen(f""" from numbers import Number import numpy as np - from arraycontext import ArrayContainer + from arraycontext import ArrayContainer, Bcast from warnings import warn def _raise_if_actx_none(actx): @@ -455,7 +620,8 @@ def is_numpy_array(arg): "behavior will change in 2025. If you would like the " "broadcasting behavior to stay the same, make sure " "to convert the passed numpy array to an " - "object array.", + "object array, or use arraycontext.Bcast to achieve " + "the desired broadcasting semantics.", DeprecationWarning, stacklevel=3) return True else: @@ -553,6 +719,33 @@ def {fname}(arg1): cls._serialize_init_arrays_code("arg2").items() }) + def get_operand(arg: Union[tuple[str, str], str]) -> str: + if isinstance(arg, tuple): + entry, _container = arg + return entry + else: + return arg + + bcast_init_args_arg1_is_outer_with_rewrap = \ + cls._deserialize_init_arrays_code("arg1", { + key_arg1: + _format_binary_op_str( + op_str, expr_arg1, + f"arg2._rewrap({get_operand(expr_arg1)})") + for key_arg1, expr_arg1 in + cls._serialize_init_arrays_code("arg1").items() + }) + bcast_init_args_arg2_is_outer_with_rewrap = \ + cls._deserialize_init_arrays_code("arg2", { + key_arg2: + _format_binary_op_str( + op_str, + f"arg1._rewrap({get_operand(expr_arg2)})", + expr_arg2) + for key_arg2, expr_arg2 in + cls._serialize_init_arrays_code("arg2").items() + }) + # {{{ "forward" binary operators gen(f"def {fname}(arg1, arg2):") @@ -605,14 +798,19 @@ def {fname}(arg1): warn("Broadcasting {cls} over array " f"context array type {{type(arg2)}} is deprecated " "and will no longer work in 2025. " - "There is no replacement as of right now. " - "See the discussion in " - "https://github.com/inducer/arraycontext/" - "pull/190. ", + "Use arraycontext.Bcast to achieve the desired " + "broadcasting semantics.", DeprecationWarning, stacklevel=2) return cls({bcast_init_args_arg1_is_outer}) + if isinstance(arg2, Bcast): + if arg2._with_next_operand: + return cls({bcast_init_args_arg1_is_outer_with_rewrap}) + else: + arg2 = arg2._rewrap() + return cls({bcast_init_args_arg1_is_outer}) + return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -656,14 +854,19 @@ def {fname}(arg2, arg1): f"context array type {{type(arg1)}} " "is deprecated " "and will no longer work in 2025." - "There is no replacement as of right now. " - "See the discussion in " - "https://github.com/inducer/arraycontext/" - "pull/190. ", + "Use arraycontext.Bcast to achieve the " + "desired broadcasting semantics.", DeprecationWarning, stacklevel=2) return cls({bcast_init_args_arg2_is_outer}) + if isinstance(arg1, Bcast): + if arg1._with_next_operand: + return cls({bcast_init_args_arg2_is_outer_with_rewrap}) + else: + arg1 = arg1._rewrap() + return cls({bcast_init_args_arg2_is_outer}) + return NotImplemented cls.__r{dunder_name}__ = {fname}""") diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index ab263304..13a19600 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -34,6 +34,9 @@ from pytools.tag import Tag from arraycontext import ( + Bcast1Level, + Bcast2Levels, + BcastUntilActxArray, EagerJAXArrayContext, NumpyArrayContext, PyOpenCLArrayContext, @@ -1194,7 +1197,7 @@ def test_container_equality(actx_factory): # }}} -# {{{ test_no_leaf_array_type_broadcasting +# {{{ test_leaf_array_type_broadcasting def test_no_leaf_array_type_broadcasting(actx_factory): from testlib import Foo @@ -1203,14 +1206,85 @@ def test_no_leaf_array_type_broadcasting(actx_factory): dof_ary = DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )) foo = Foo(dof_ary) + bar = foo + 4 + + bcast = partial(BcastUntilActxArray, actx) actx_ary = actx.from_numpy(4*np.ones((3, ))) with pytest.raises(TypeError): foo + actx_ary + baz = foo + Bcast2Levels(actx_ary) + qux = Bcast2Levels(actx_ary) + foo + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(baz.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(qux.u[0])) + + baz = foo + bcast(actx_ary) + qux = bcast(actx_ary) + foo + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(baz.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(qux.u[0])) + + mc = MyContainer( + name="hi", + mass=dof_ary, + momentum=make_obj_array([dof_ary, dof_ary]), + enthalpy=dof_ary) + with pytest.raises(TypeError): - foo + actx.from_numpy(np.array(4)) + mc_op = mc + actx_ary + + mom_op = mc.momentum + Bcast1Level(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mom_op[0][0])) + + mc_op = mc + Bcast2Levels(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0])) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0])) + mom_op = mc.momentum + bcast(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mom_op[0][0])) + + mc_op = mc + bcast(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0])) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0])) + + def _actx_allows_scalar_broadcast(actx): + if not isinstance(actx, PyOpenCLArrayContext): + return True + else: + import pyopencl as cl + + # See https://github.com/inducer/pyopencl/issues/498 + return cl.version.VERSION > (2021, 2, 5) + + if _actx_allows_scalar_broadcast(actx): + with pytest.raises(TypeError): + foo + actx.from_numpy(np.array(4)) + + quuz = Bcast2Levels(actx.from_numpy(np.array(4))) + foo + quux = foo + Bcast2Levels(actx.from_numpy(np.array(4))) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quux.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quuz.u[0])) + + quuz = bcast(actx.from_numpy(np.array(4))) + foo + quux = foo + bcast(actx.from_numpy(np.array(4))) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quux.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quuz.u[0])) # }}} @@ -1279,6 +1353,15 @@ def equal(a, b): b_bcast_dc_of_dofs.momentum), enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy)) + # Array context scalars + two = actx.from_numpy(np.array(2)) + assert equal( + outer(Bcast2Levels(two), b_bcast_dc_of_dofs), + Bcast2Levels(two)*b_bcast_dc_of_dofs) + assert equal( + outer(a_bcast_dc_of_dofs, Bcast2Levels(two)), + a_bcast_dc_of_dofs*Bcast2Levels(two)) + # }}}