Skip to content

Commit

Permalink
Introduce Bcast object-ified broacasting rules
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Dec 11, 2024
1 parent d8e8683 commit f37f298
Show file tree
Hide file tree
Showing 3 changed files with 319 additions and 21 deletions.
12 changes: 12 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@
serialize_container,
)
from .container.arithmetic import (
Bcast,
Bcast1Level,
Bcast2Levels,
Bcast3Levels,
BcastNLevels,
BcastUntilActxArray,
with_container_arithmetic,
)
from .container.dataclass import dataclass_array_container
Expand Down Expand Up @@ -115,6 +121,12 @@
"ArrayOrContainerOrScalarT",
"ArrayOrContainerT",
"ArrayT",
"Bcast",
"Bcast1Level",
"Bcast2Levels",
"Bcast3Levels",
"BcastNLevels",
"BcastUntilActxArray",
"CommonSubexpressionTag",
"EagerJAXArrayContext",
"ElementwiseMapKernelTag",
Expand Down
241 changes: 222 additions & 19 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,21 @@
.. currentmodule:: arraycontext
.. autofunction:: with_container_arithmetic
.. autoclass:: Bcast
.. autoclass:: BcastNLevels
.. autoclass:: BcastUntilActxArray
.. function:: Bcast1
Like :class:`BcastNLevels` with *nlevels* set to 1.
.. function:: Bcast2
Like :class:`BcastNLevels` with *nlevels* set to 2.
.. function:: Bcast3
Like :class:`BcastNLevels` with *nlevels* set to 3.
"""


Expand Down Expand Up @@ -34,12 +49,18 @@
"""

import enum
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, TypeVar
from dataclasses import FrozenInstanceError
from functools import partial
from numbers import Number
from typing import Any, ClassVar, TypeVar, Union
from warnings import warn

import numpy as np

from arraycontext.context import ArrayContext, ArrayOrContainer


# {{{ with_container_arithmetic

Expand Down Expand Up @@ -142,8 +163,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
warn(
"Broadcasting container against non-object numpy array. "
"This was never documented to work and will now stop working in "
"2025. Convert the array to an object array to preserve the "
"current semantics.", DeprecationWarning, stacklevel=3)
"2025. Convert the array to an object array or use "
"variants of arraycontext.Bcast to obtain the desired "
"broadcasting semantics.", DeprecationWarning, stacklevel=3)
return True
else:
return False
Expand All @@ -153,6 +175,125 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet
pass


class Bcast:
"""
A wrapper object to force arithmetic generated by :func:`with_container_arithmetic`
to broadcast *arg* across a container (with the container as the 'outer' structure).
Since array containers are often nested in complex ways, different subclasses
implement different rules on how broadcasting interacts with the hierarchy,
with :class:`BcastNLevels` and :class:`BcastUntilActxArray` representing
the most common.
"""
arg: ArrayOrContainer

# Accessing this attribute is cheaper than isinstance, so use that
# to distinguish _BcastWithNextOperand and _BcastWithoutNextOperand.
_with_next_operand: ClassVar[bool]

def __init__(self, arg: ArrayOrContainer) -> None:
object.__setattr__(self, "arg", arg)

def __setattr__(self, name: str, value: Any) -> None:
raise FrozenInstanceError()

def __delattr__(self, name: str) -> None:
raise FrozenInstanceError()


class _BcastWithNextOperand(Bcast, ABC):
"""
A :class:`Bcast` object that gets to see who the next operand will be, in
order to decide whether wrapping the child in :class:`Bcast` is still necessary.
This is much more flexible, but also considerably more expensive, than
:class:`_BcastWithoutNextOperand`.
"""

_with_next_operand = True

# purposefully undocumented
@abstractmethod
def _rewrap(self, other_operand: ArrayOrContainer) -> ArrayOrContainer:
...


class _BcastWithoutNextOperand(Bcast, ABC):
"""
A :class:`Bcast` object that does not get to see who the next operand will be.
"""
_with_next_operand = False

# purposefully undocumented
@abstractmethod
def _rewrap(self) -> ArrayOrContainer:
...


class BcastNLevels(_BcastWithoutNextOperand):
"""
A broadcasting rule that lets *arg* broadcast against *nlevels* "levels" of
array containers. Use :func:`Bcast1`, :func:`Bcast2`, :func:`Bcast3` as
convenient aliases for the common cases.
Usage example::
container + Bcast2(actx_array)
.. note::
:mod:`numpy` object arrays do not count against the number of levels.
.. automethod:: __init__
"""
nlevels: int

def __init__(self, nlevels: int, arg: ArrayOrContainer) -> None:
if nlevels < 1:
raise ValueError("nlevels is expected to be one or greater.")

super().__init__(arg)
object.__setattr__(self, "nlevels", nlevels)

def _rewrap(self) -> ArrayOrContainer:
if self.nlevels == 1:
return self.arg
else:
return BcastNLevels(self.nlevels-1, self.arg)


Bcast1Level = partial(BcastNLevels, 1)
Bcast2Levels = partial(BcastNLevels, 2)
Bcast3Levels = partial(BcastNLevels, 3)


class BcastUntilActxArray(_BcastWithNextOperand):
"""
A broadcast rule that broadcasts *arg* across array containers until
the 'opposite' operand is one of the :attr:`~arraycontext.ArrayContext.array_types`
of *actx*, or a :class:`~numbers.Number`.
Suggested usage pattern::
bcast = functools.partial(BcastUntilActxArray, actx)
container + bcast(actx_array)
.. automethod:: __init__
"""
actx: ArrayContext

def __init__(self,
actx: ArrayContext,
arg: ArrayOrContainer) -> None:
super().__init__(arg)
object.__setattr__(self, "actx", actx)

def _rewrap(self, other_operand: ArrayOrContainer) -> ArrayOrContainer:
if isinstance(other_operand, (*self.actx.array_types, Number)):
return self.arg
else:
return self


def with_container_arithmetic(
*,
number_bcasts_across: bool | None = None,
Expand Down Expand Up @@ -207,6 +348,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
Each operator class also includes the "reverse" operators if applicable.
.. note::
For the generated binary arithmetic operators, if certain types
should be broadcast over the container (with the container as the
'outer' structure) but are not handled in this way by their types,
you may wrap them in one of the :class:`Bcast` variants to achieve
the desired semantics.
.. note::
To generate the code implementing the operators, this function relies on
Expand Down Expand Up @@ -239,6 +388,24 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
#
# - Broadcast rules are hard to change once established, particularly
# because one cannot grep for their use.
#
# Possible advantages of the "Bcast" broadcast-rule-as-object design:
#
# - If one rule does not fit the user's need, they can straightforwardly use
# another.
#
# - It's straightforward to find where certain broadcast rules are used.
#
# - The broadcast rule can contain more state. For example, it's now easy
# for the rule to know what array context should be used to determine
# actx array types.
#
# Possible downsides of the "Bcast" broadcast-rule-as-object design:
#
# - User code is a bit more wordy.
#
# - Rewrapping has the potential to be costly, especially in
# _with_next_operand mode.

# {{{ handle inputs

Expand Down Expand Up @@ -404,9 +571,8 @@ def wrap(cls: Any) -> Any:
f"Broadcasting array context array types across {cls} "
"has been explicitly "
"enabled. As of 2025, this will stop working. "
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/pull/190. "
"Express these operations using arraycontext.Bcast variants "
"instead. "
"To opt out now (and avoid this warning), "
"pass _bcast_actx_array_type=False. ",
DeprecationWarning, stacklevel=2)
Expand All @@ -415,9 +581,8 @@ def wrap(cls: Any) -> Any:
f"Broadcasting array context array types across {cls} "
"has been implicitly "
"enabled. As of 2025, this will no longer work. "
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/pull/190. "
"Express these operations using arraycontext.Bcast variants "
"instead. "
"To opt out now (and avoid this warning), "
"pass _bcast_actx_array_type=False.",
DeprecationWarning, stacklevel=2)
Expand All @@ -435,7 +600,7 @@ def wrap(cls: Any) -> Any:
gen(f"""
from numbers import Number
import numpy as np
from arraycontext import ArrayContainer
from arraycontext import ArrayContainer, Bcast
from warnings import warn
def _raise_if_actx_none(actx):
Expand All @@ -455,7 +620,8 @@ def is_numpy_array(arg):
"behavior will change in 2025. If you would like the "
"broadcasting behavior to stay the same, make sure "
"to convert the passed numpy array to an "
"object array.",
"object array, or use arraycontext.Bcast to achieve "
"the desired broadcasting semantics.",
DeprecationWarning, stacklevel=3)
return True
else:
Expand Down Expand Up @@ -553,6 +719,33 @@ def {fname}(arg1):
cls._serialize_init_arrays_code("arg2").items()
})

def get_operand(arg: Union[tuple[str, str], str]) -> str:
if isinstance(arg, tuple):
entry, _container = arg
return entry
else:
return arg

bcast_init_args_arg1_is_outer_with_rewrap = \
cls._deserialize_init_arrays_code("arg1", {
key_arg1:
_format_binary_op_str(
op_str, expr_arg1,
f"arg2._rewrap({get_operand(expr_arg1)})")
for key_arg1, expr_arg1 in
cls._serialize_init_arrays_code("arg1").items()
})
bcast_init_args_arg2_is_outer_with_rewrap = \
cls._deserialize_init_arrays_code("arg2", {
key_arg2:
_format_binary_op_str(
op_str,
f"arg1._rewrap({get_operand(expr_arg2)})",
expr_arg2)
for key_arg2, expr_arg2 in
cls._serialize_init_arrays_code("arg2").items()
})

# {{{ "forward" binary operators

gen(f"def {fname}(arg1, arg2):")
Expand Down Expand Up @@ -605,14 +798,19 @@ def {fname}(arg1):
warn("Broadcasting {cls} over array "
f"context array type {{type(arg2)}} is deprecated "
"and will no longer work in 2025. "
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/"
"pull/190. ",
"Use arraycontext.Bcast to achieve the desired "
"broadcasting semantics.",
DeprecationWarning, stacklevel=2)
return cls({bcast_init_args_arg1_is_outer})
if isinstance(arg2, Bcast):
if arg2._with_next_operand:
return cls({bcast_init_args_arg1_is_outer_with_rewrap})
else:
arg2 = arg2._rewrap()
return cls({bcast_init_args_arg1_is_outer})
return NotImplemented
""")
gen(f"cls.__{dunder_name}__ = {fname}")
Expand Down Expand Up @@ -656,14 +854,19 @@ def {fname}(arg2, arg1):
f"context array type {{type(arg1)}} "
"is deprecated "
"and will no longer work in 2025."
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/"
"pull/190. ",
"Use arraycontext.Bcast to achieve the "
"desired broadcasting semantics.",
DeprecationWarning, stacklevel=2)
return cls({bcast_init_args_arg2_is_outer})
if isinstance(arg1, Bcast):
if arg1._with_next_operand:
return cls({bcast_init_args_arg2_is_outer_with_rewrap})
else:
arg1 = arg1._rewrap()
return cls({bcast_init_args_arg2_is_outer})
return NotImplemented
cls.__r{dunder_name}__ = {fname}""")
Expand Down
Loading

0 comments on commit f37f298

Please sign in to comment.