From d9c94fecabff05b2214dc4e5f85ae0b35afdb900 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 27 Sep 2021 18:48:17 -0500 Subject: [PATCH 1/7] Deprecate with_container_arithmetic's bcast_numpy_array arg Passing both 'bcast_numpy_array' and '_bcast_actx_array_types' was ill-defined. For example, in the case of an ArrayContext whose thawed array type is np.ndarray the specification would contradict between broadcasting the argument numpy_array to return an object array *OR* peforming the operation with every leaf array. Consider the example below, ( - 'Foo: ArrayContainer' whose arithmetic routines are generated by `with_container_arithmetic(bcast_numpy=True, _bcast_actx_array_types=True)` - 'actx: ArrayContextT' for whom `np.ndarray` is a valid thawed array type. ) Foo(DOFArray(actx, [38*actx.ones(3, np.float64)])) + np.array([3, 4, 5]) could be either of: - array([Foo(DOFArray([array([41, 41, 41])])), Foo(DOFArray([array([42, 42, 42])])), Foo(DOFArray([array([43, 43, 43])]))]), OR, - Foo(DOFArray(actx, array([41, 42, 43]))) --- arraycontext/container/arithmetic.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 5e2ade2e..6d281a34 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -213,6 +213,15 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): if rel_comparison is None: raise TypeError("rel_comparison must be specified") + if bcast_numpy_array: + from warnings import warn + warn("'bcast_numpy_array=True' is deprecated and will be unsupported" + " from December 2021", DeprecationWarning, stacklevel=2) + + if _bcast_actx_array_type: + raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" + " cannot be both set.") + if rel_comparison and eq_comparison is None: eq_comparison = True From c32e4e11c3d0df9b1671d702c53329d2159e9e83 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 26 Sep 2021 02:38:28 -0500 Subject: [PATCH 2/7] Implements NumpyArrayContext --- arraycontext/__init__.py | 3 + arraycontext/impl/numpy/__init__.py | 124 ++++++++++++++++++++++ arraycontext/impl/numpy/fake_numpy.py | 142 ++++++++++++++++++++++++++ 3 files changed, 269 insertions(+) create mode 100644 arraycontext/impl/numpy/__init__.py create mode 100644 arraycontext/impl/numpy/fake_numpy.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 06e0b96c..cf3c961e 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -76,6 +76,7 @@ from .impl.pytato import (PytatoPyOpenCLArrayContext, PytatoJAXArrayContext) from .impl.jax import EagerJAXArrayContext +from .impl.numpy import NumpyArrayContext from .pytest import ( PytestArrayContextFactory, @@ -123,6 +124,8 @@ "PytatoJAXArrayContext", "EagerJAXArrayContext", + "NumpyArrayContext", + "make_loopy_program", "PytestArrayContextFactory", diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py new file mode 100644 index 00000000..76988856 --- /dev/null +++ b/arraycontext/impl/numpy/__init__.py @@ -0,0 +1,124 @@ +""" +.. currentmodule:: arraycontext + + +A mod :`numpy`-based array context. + +.. autoclass:: NumpyArrayContext +""" +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from arraycontext.context import ArrayContext +import numpy as np +import loopy as lp +from typing import Union, Sequence, Dict +from pytools.tag import Tag + + +class NumpyArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :mod:`numpy.ndarray` to represent arrays + + + .. automethod:: __init__ + """ + def __init__(self): + super().__init__() + self._loopy_transform_cache: \ + Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {} + + self.array_types = (np.ndarray,) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import NumpyFakeNumpyNamespace + return NumpyFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def clone(self): + return type(self)() + + def empty(self, shape, dtype): + return np.empty(shape, dtype=dtype) + + def zeros(self, shape, dtype): + return np.zeros(shape, dtype) + + def from_numpy(self, np_array: np.ndarray): + # Uh oh... + return np_array + + def to_numpy(self, array): + # Uh oh... + return array + + def call_loopy(self, t_unit, **kwargs): + t_unit = t_unit.copy(target=lp.ExecutableCTarget()) + try: + t_unit = self._loopy_transform_cache[t_unit] + except KeyError: + orig_t_unit = t_unit + t_unit = self.transform_loopy_program(t_unit) + self._loopy_transform_cache[orig_t_unit] = t_unit + del orig_t_unit + + _, result = t_unit(**kwargs) + + return result + + def freeze(self, array): + return array + + def thaw(self, array): + return array + + # }}} + + def transform_loopy_program(self, t_unit): + raise ValueError("NumpyArrayContext does not implement " + "transform_loopy_program. Sub-classes are supposed " + "to implement it.") + + def tag(self, tags: Union[Sequence[Tag], Tag], array): + # Numpy doesn't support tagging + return array + + def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): + return array + + def einsum(self, spec, *args, arg_names=None, tagged=()): + return np.einsum(spec, *args) + + @property + def permits_inplace_modification(self): + return True + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py new file mode 100644 index 00000000..a15c5e9b --- /dev/null +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -0,0 +1,142 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +from arraycontext.fake_numpy import ( + BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, + ) +from arraycontext.container import is_array_container +from arraycontext.container.traversal import ( + rec_map_array_container, + rec_multimap_array_container, + multimap_reduce_array_container, + rec_map_reduce_array_container, + rec_multimap_reduce_array_container, + ) +import numpy as np + + +class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", + "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", + "sqrt", "exp", "concatenate", "reshape", "transpose", + "ones_like", "maximum", "minimum", "where", "conj", "arctan2", + } + + +class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`NumpyArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return NumpyFakeNumpyLinalgNamespace(self._array_context) + + def __getattr__(self, name): + + if name in _NUMPY_UFUNCS: + from functools import partial + return partial(rec_multimap_array_container, + getattr(np, name)) + + return super().__getattr__(name) + + def sum(self, a, axis=None, dtype=None): + return rec_map_reduce_array_container(sum, partial(np.sum, + axis=axis, + dtype=dtype), + a) + + def min(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.minimum), partial(np.amin, axis=axis), a) + + def max(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.maximum), partial(np.amax, axis=axis), a) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: np.stack(arrays=args, axis=axis), + *arrays) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(np.broadcast_to, shape=shape), array) + + # {{{ relational operators + + def equal(self, x, y): + return rec_multimap_array_container(np.equal, x, y) + + def not_equal(self, x, y): + return rec_multimap_array_container(np.not_equal, x, y) + + def greater(self, x, y): + return rec_multimap_array_container(np.greater, x, y) + + def greater_equal(self, x, y): + return rec_multimap_array_container(np.greater_equal, x, y) + + def less(self, x, y): + return rec_multimap_array_container(np.less, x, y) + + def less_equal(self, x, y): + return rec_multimap_array_container(np.less_equal, x, y) + + # }}} + + def ravel(self, a, order="C"): + return rec_map_array_container(partial(np.ravel, order=order), a) + + def vdot(self, x, y, dtype=None): + if dtype is not None: + raise NotImplementedError("only 'dtype=None' supported.") + + return rec_multimap_reduce_array_container(sum, np.vdot, x, y) + + def any(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_or), + lambda subary: np.any(subary), a) + + def all(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_and), + lambda subary: np.all(subary), a) + + def array_equal(self, a, b): + if type(a) != type(b): + return False + elif not is_array_container(a): + if a.shape != b.shape: + return False + else: + return np.all(np.equal(a, b)) + else: + return multimap_reduce_array_container(partial(reduce, + np.logical_and), + self.array_equal, a, b) + +# vim: fdm=marker From 6e30532aa19814a05eb1b18911b130d19c84d5dc Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 26 Sep 2021 03:03:53 -0500 Subject: [PATCH 3/7] ArrayContainer fixes for numpy arrays as leaf classes --- arraycontext/container/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 71bccee2..1f96e777 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -216,7 +216,11 @@ def is_array_container(ary: Any) -> bool: "cheaper option, see is_array_container_type.", DeprecationWarning, stacklevel=2) return (serialize_container.dispatch(ary.__class__) - is not serialize_container.__wrapped__) # type:ignore[attr-defined] + is not serialize_container.__wrapped__ # type:ignore[attr-defined] + # numpy values with scalar elements aren't array containers + and not (isinstance(ary, np.ndarray) + and ary.dtype.kind != "O") + ) @singledispatch From ce8ab7c5b1a41e3867cf6df3f775a556ad98da43 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 27 Sep 2021 01:32:30 -0500 Subject: [PATCH 4/7] arithmetic fixes to account for np.ndarray being a leaf array --- arraycontext/container/arithmetic.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 6d281a34..1def8847 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -492,16 +492,17 @@ def {fname}(arg1): bcast_actx_ary_types = () gen(f""" - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg2, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_same_cls_init_args}) if {numpy_pred("arg2")}: result = np.empty_like(arg2, dtype=object) for i in np.ndindex(arg2.shape): result[i] = {op_str.format("arg1", "arg2[i]")} return result + + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg2, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_same_cls_init_args}) return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -538,16 +539,16 @@ def {fname}(arg1): def {fname}(arg2, arg1): # assert other.__cls__ is not cls - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg1, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_init_args}) if {numpy_pred("arg1")}: result = np.empty_like(arg1, dtype=object) for i in np.ndindex(arg1.shape): result[i] = {op_str.format("arg1[i]", "arg2")} return result + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg1, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_init_args}) return NotImplemented cls.__r{dunder_name}__ = {fname}""") From d22dddced58291377988c2e0ecde70e1c75cf4d5 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 26 Sep 2021 02:41:25 -0500 Subject: [PATCH 5/7] test NumpyArrayContext --- arraycontext/pytest.py | 22 ++++++++++++++++++++++ test/test_arraycontext.py | 3 +++ 2 files changed, 25 insertions(+) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 1eceb497..26535185 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -35,6 +35,7 @@ from typing import Any, Callable, Dict, Sequence, Type, Union from arraycontext.context import ArrayContext +from arraycontext import NumpyArrayContext # {{{ array context factories @@ -195,6 +196,26 @@ def __str__(self): return "" +# {{{ _PytestArrayContextFactory + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "" + +# }}} + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, @@ -203,6 +224,7 @@ def __str__(self): "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "numpy": _PytestNumpyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 842d108e..0975d5ce 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -45,6 +45,8 @@ _PytestPytatoPyOpenCLArrayContextFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory) + _PytestPytatoPyOpenCLArrayContextFactory, + _PytestNumpyArrayContextFactory) import logging @@ -93,6 +95,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, ]) From dba4f53af02896021a0a638f5c3ef853a67727b8 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 27 Sep 2021 01:35:55 -0500 Subject: [PATCH 6/7] test tweaks for NumpyArrayContext --- test/test_arraycontext.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 0975d5ce..3b95b38b 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -44,7 +44,7 @@ from arraycontext.pytest import (_PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, _PytestEagerJaxArrayContextFactory, - _PytestPytatoJaxArrayContextFactory) + _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, _PytestNumpyArrayContextFactory) @@ -1138,7 +1138,11 @@ def test_flatten_with_leaf_class(actx_factory): # {{{ test from_numpy and to_numpy def test_numpy_conversion(actx_factory): + from arraycontext import NumpyArrayContext + actx = actx_factory() + if isinstance(actx, NumpyArrayContext): + pytest.skip("Irrelevant tests for NumpyArrayContext") nelements = 42 ac = MyContainer( @@ -1317,6 +1321,8 @@ def test_container_equality(actx_factory): class Foo: u: DOFArray + __array_priority__ = 1 # disallow numpy arithmetic to take precedence + @property def array_context(self): return self.u.array_context From 958c769d0fac477240f8742e23d1c58638f0f720 Mon Sep 17 00:00:00 2001 From: Esteban Date: Tue, 20 Sep 2022 19:40:43 -0400 Subject: [PATCH 7/7] Added new Torch array context --- arraycontext/__init__.py | 3 +- arraycontext/impl/jax/fake_numpy.py | 4 +- arraycontext/impl/torch/__init__.py | 174 ++++++++++++++++++++++ arraycontext/impl/torch/fake_numpy.py | 200 ++++++++++++++++++++++++++ arraycontext/pytest.py | 21 +++ test/test_arraycontext.py | 14 +- 6 files changed, 410 insertions(+), 6 deletions(-) create mode 100644 arraycontext/impl/torch/__init__.py create mode 100644 arraycontext/impl/torch/fake_numpy.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index cf3c961e..d25c070b 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -77,6 +77,7 @@ PytatoJAXArrayContext) from .impl.jax import EagerJAXArrayContext from .impl.numpy import NumpyArrayContext +from .impl.torch import TorchArrayContext from .pytest import ( PytestArrayContextFactory, @@ -124,7 +125,7 @@ "PytatoJAXArrayContext", "EagerJAXArrayContext", - "NumpyArrayContext", + "NumpyArrayContext", "TorchArrayContext", "make_loopy_program", diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 37c99b4a..03bb3ade 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -68,7 +68,7 @@ def _full_like(subary): # }}} - # {{{ array manipulation routies + # {{{ array manipulation routines def reshape(self, a, newshape, order="C"): return rec_map_array_container( @@ -133,7 +133,7 @@ def all(self, a): def any(self, a): return rec_map_reduce_array_container( partial(reduce, jnp.logical_or), jnp.any, a) - + def array_equal(self, a, b): actx = self._array_context diff --git a/arraycontext/impl/torch/__init__.py b/arraycontext/impl/torch/__init__.py new file mode 100644 index 00000000..362523ba --- /dev/null +++ b/arraycontext/impl/torch/__init__.py @@ -0,0 +1,174 @@ +""" +.. currentmodule:: arraycontext +.. autoclass:: TorchArrayContext +""" + +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Callable, Optional, Tuple + +import numpy as np +from typing import Union, Sequence, Dict +from pytools.tag import Tag, ToTagSetConvertible +from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike +from arraycontext.container.traversal import (with_array_context, + rec_map_array_container) + + +class TorchArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :class:`torch.Tensor` instances for its base array class. + + + .. automethod:: __init__ + """ + def __init__(self) -> None: + super().__init__() + + from torch import Tensor + self.array_types = (Tensor, ) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import TorchFakeNumpyNamespace + return TorchFakeNumpyNamespace(self) + + def _rec_map_container( + self, func: Callable[[Array], Array], array: ArrayOrContainer, + allowed_types: Optional[Tuple[type, ...]] = None, *, + default_scalar: Optional[ScalarLike] = None, + strict: bool = False) -> ArrayOrContainer: + if allowed_types is None: + allowed_types = self.array_types + + def _wrapper(ary): + if isinstance(ary, allowed_types): + return func(ary) + elif np.isscalar(ary): + if default_scalar is None: + return ary + else: + return np.array(ary).dtype.type(default_scalar) + else: + raise TypeError( + f"{type(self).__name__}.{func.__name__[1:]} invoked with " + f"an unsupported array type: got '{type(ary).__name__}', " + f"but expected one of {allowed_types}") + + return rec_map_array_container(_wrapper, array) + + # {{{ ArrayContext interface + + def empty(self, shape, dtype): + import torch + return torch.empty(shape, dtype) + + def zeros(self, shape, dtype): + import torch + return torch.zeros(shape, dtype=dtype) + + def empty_like(self, ary): + def _empty_like(array): + return self.empty(array.shape, array.dtype) + + return self._rec_map_container(_empty_like, ary) + + def zeros_like(self, ary): + def _zeros_like(array): + return self.zeros(array.shape, array.dtype) + + return self._rec_map_container(_zeros_like, ary, default_scalar=0) + + def from_numpy(self, array: np.ndarray): + def _from_numpy(ary): + import torch + return torch.from_numpy(array) + + return with_array_context( + self._rec_map_container(_from_numpy, array, allowed_types=(np.ndarray,)), + actx=self) + + def to_numpy(self, array): + def _to_numpy(ary): + import torch + if array.dtype == torch.complex64 or array.dtype == torch.complex128: + return ary.resolve_conj().numpy() + else: + return ary.detach().numpy() + + return with_array_context( + self._rec_map_container(_to_numpy, array), + actx=None) + + def freeze(self, array): + def _freeze(ary): + return ary + return with_array_context(self._rec_map_container(_freeze, array), actx=None) + + def thaw(self, array): + return with_array_context(array, actx=self) + + def tag(self, tags: ToTagSetConvertible, array): + # Sorry, not capable. + return array + + def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): + # Sorry, not capable. + return array + + def call_loopy(self, t_unit, **kwargs): + raise NotImplementedError( + "Calling loopy on Torch arrays is not supported. Maybe rewrite" + " the loopy kernel as numpy-flavored array operations using" + " ArrayContext.np.") + + def einsum(self, spec, *args, arg_names=None, tagged=()): + import torch + if arg_names is not None: + from warnings import warn + warn("'arg_names' don't bear any significance in " + f"{type(self).__name__}.", stacklevel=2) + + return torch.einsum(spec, *args) + + def clone(self): + return type(self)() + + # }}} + + # {{{ properties + + @property + def permits_inplace_modification(self): + return False + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True + + # }}} diff --git a/arraycontext/impl/torch/fake_numpy.py b/arraycontext/impl/torch/fake_numpy.py new file mode 100644 index 00000000..9dd1139c --- /dev/null +++ b/arraycontext/impl/torch/fake_numpy.py @@ -0,0 +1,200 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +import numpy as np +import torch + +from arraycontext.fake_numpy import ( + BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, + ) +from arraycontext.container import is_array_container +from arraycontext.container.traversal import ( + rec_map_array_container, + rec_multimap_array_container, + multimap_reduce_array_container, + rec_map_reduce_array_container, + rec_multimap_reduce_array_container, + ) + + +class TorchFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +class TorchFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`~arraycontext.TorchArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return TorchFakeNumpyLinalgNamespace(self._array_context) + + def __getattr__(self, name): + return partial(rec_multimap_array_container, getattr(torch, name)) + + # NOTE: the order of these follows the order in numpy docs + # NOTE: when adding a function here, also add it to `array_context.rst` docs! + + # {{{ array creation routines + + def ones_like(self, ary): + return self.full_like(ary, 1) + + def full_like(self, ary, fill_value): + def _full_like(subary): + return torch.full_like(subary, fill_value) + + return self._array_context._rec_map_container( + _full_like, ary, default_scalar=fill_value) + + # }}} + + # {{{ array manipulation routines + + def reshape(self, a, newshape, order="C"): + """ + .. warning:: + + Since :func:`torch.reshape` does not support orders `A`` and + ``K``, in such cases we fallback to using ``order = C``. + """ + if order in "AK": + from warnings import warn + warn(f"reshape with order='{order}' nor supported by Torch," + " using order=C.") + + return rec_map_array_container( + lambda ary: torch.reshape(ary, newshape), a + ) + + def ravel(self, a, order="C"): + """ + .. warning:: + + Since :func:`torch.reshape` does not support orders `A`` and + ``K``, in such cases we fallback to using ``order = C``. + """ + if order in "AK": + from warnings import warn + warn(f"reshape with order='{order}' nor supported by Torch," + " using order=C.") + + return rec_map_array_container( + lambda ary: torch.ravel(ary), a + ) + + def transpose(self, a, dim0=0, dim1=1): + return rec_multimap_array_container(torch.transpose, a, dim0, dim1) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(torch.broadcast_to, shape=shape), array) + + def concatenate(self, arrays, axis=0): + return rec_multimap_array_container(torch.cat, arrays, axis) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: torch.stack(tensors=args, dim=axis), *arrays) + + # {{{ linear algebra + + # }}} + + # {{{ logic functions + + def all(self, a): + return rec_map_reduce_array_container( + partial(reduce, torch.logical_and), torch.all, a) + + def any(self, a): + return rec_map_reduce_array_container( + partial(reduce, torch.logical_or), torch.any, a) + + def array_equal(self, a, b): + actx = self._array_context + + true = actx.from_numpy(np.int8(True)) + false = actx.from_numpy(np.int8(False)) + + def rec_equal(x, y): + if type(x) != type(y): + return false + + try: + iterable = zip(serialize_container(x), serialize_container(y)) + except NotAnArrayContainerError: + if x.shape != y.shape: + return false + else: + return torch.all(torch.equal(x, y)) + else: + return reduce( + torch.logical_and, + [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable], + true) + + return rec_equal(a, b) + + def equal(self, a, b): + # ECG: Really? + return a == b + + # }}} + + # {{{ mathematical functions + + def sum(self, a, axis=0, dtype=None): + return rec_map_reduce_array_container( + sum, + partial(torch.sum, axis=axis, dtype=dtype), + a) + + def amin(self, a, axis=0): + return rec_map_reduce_array_container( + partial(reduce, torch.minimum), partial(torch.amin, axis=axis), a) + + min = amin + + def amax(self, a, axis=0): + return rec_map_reduce_array_container( + partial(reduce, torch.maximum), partial(torch.amax, axis=axis), a) + + max = amax + + # }}} + + # {{{ sorting, searching, and counting + + def where(self, criterion, then, else_): + def where_inner(inner_crit, inner_then, inner_else): + import torch + if isinstance(inner_crit, torch.BoolTensor): + return torch.where(inner_crit, inner_then, inner_else) + else: + return torch.where(inner_crit != 0, inner_then, inner_else) + + return rec_multimap_array_container(where_inner, criterion, then, else_) + + # }}} diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 26535185..b4200372 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -196,6 +196,26 @@ def __str__(self): return "" +class _PytestTorchArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + pass + + @classmethod + def is_available(cls) -> bool: + try: + import torch # noqa: F401 + return True + except ImportError: + return False + + def __call__(self): + from arraycontext import TorchArrayContext + return TorchArrayContext() + + def __str__(self): + return "" + + # {{{ _PytestArrayContextFactory class _NumpyArrayContextForTests(NumpyArrayContext): @@ -224,6 +244,7 @@ def __str__(self): "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "torch": _PytestTorchArrayContextFactory, "numpy": _PytestNumpyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 3b95b38b..e7632d6c 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -36,6 +36,7 @@ PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, EagerJAXArrayContext, + TorchArrayContext, ArrayContainer, to_numpy, tag_axes) from arraycontext import ( # noqa: F401 @@ -46,7 +47,9 @@ _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, - _PytestNumpyArrayContextFactory) + _PytestTorchArrayContextFactory, + #_PytestNumpyArrayContextFactory + ) import logging @@ -95,7 +98,8 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, - _PytestNumpyArrayContextFactory, + _PytestTorchArrayContextFactory, + #_PytestNumpyArrayContextFactory, ]) @@ -404,6 +408,9 @@ def evaluate(_np, *_args): if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: pytest.skip(f"'{sym_name}' not supported on scalars") + if isinstance(actx, TorchArrayContext): + pytest.skip(f"{sym_name}' not supported on scalars by TorchArrayContext") + args = [randn(0, dtype)[()] for i in range(n_args)] assert_close_to_numpy(actx, evaluate, args) @@ -422,7 +429,7 @@ def test_array_context_np_like(actx_factory, sym_name, n_args, dtype): assert_close_to_numpy( actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), args) - for c in (42.0,) + _get_test_containers(actx): + for c in (42.0,) + _get_test_containers(actx): result = getattr(actx.np, sym_name)(c) result = actx.thaw(actx.freeze(result)) @@ -430,6 +437,7 @@ def test_array_context_np_like(actx_factory, sym_name, n_args, dtype): if np.isscalar(result): assert result == 0.0 else: + print("ECG: Come back here!") assert actx.to_numpy(actx.np.all(actx.np.equal(result, 0.0))) elif sym_name == "ones_like": if np.isscalar(result):