From 958c769d0fac477240f8742e23d1c58638f0f720 Mon Sep 17 00:00:00 2001 From: Esteban Date: Tue, 20 Sep 2022 19:40:43 -0400 Subject: [PATCH] Added new Torch array context --- arraycontext/__init__.py | 3 +- arraycontext/impl/jax/fake_numpy.py | 4 +- arraycontext/impl/torch/__init__.py | 174 ++++++++++++++++++++++ arraycontext/impl/torch/fake_numpy.py | 200 ++++++++++++++++++++++++++ arraycontext/pytest.py | 21 +++ test/test_arraycontext.py | 14 +- 6 files changed, 410 insertions(+), 6 deletions(-) create mode 100644 arraycontext/impl/torch/__init__.py create mode 100644 arraycontext/impl/torch/fake_numpy.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index cf3c961e..d25c070b 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -77,6 +77,7 @@ PytatoJAXArrayContext) from .impl.jax import EagerJAXArrayContext from .impl.numpy import NumpyArrayContext +from .impl.torch import TorchArrayContext from .pytest import ( PytestArrayContextFactory, @@ -124,7 +125,7 @@ "PytatoJAXArrayContext", "EagerJAXArrayContext", - "NumpyArrayContext", + "NumpyArrayContext", "TorchArrayContext", "make_loopy_program", diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 37c99b4a..03bb3ade 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -68,7 +68,7 @@ def _full_like(subary): # }}} - # {{{ array manipulation routies + # {{{ array manipulation routines def reshape(self, a, newshape, order="C"): return rec_map_array_container( @@ -133,7 +133,7 @@ def all(self, a): def any(self, a): return rec_map_reduce_array_container( partial(reduce, jnp.logical_or), jnp.any, a) - + def array_equal(self, a, b): actx = self._array_context diff --git a/arraycontext/impl/torch/__init__.py b/arraycontext/impl/torch/__init__.py new file mode 100644 index 00000000..362523ba --- /dev/null +++ b/arraycontext/impl/torch/__init__.py @@ -0,0 +1,174 @@ +""" +.. currentmodule:: arraycontext +.. autoclass:: TorchArrayContext +""" + +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Callable, Optional, Tuple + +import numpy as np +from typing import Union, Sequence, Dict +from pytools.tag import Tag, ToTagSetConvertible +from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike +from arraycontext.container.traversal import (with_array_context, + rec_map_array_container) + + +class TorchArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :class:`torch.Tensor` instances for its base array class. + + + .. automethod:: __init__ + """ + def __init__(self) -> None: + super().__init__() + + from torch import Tensor + self.array_types = (Tensor, ) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import TorchFakeNumpyNamespace + return TorchFakeNumpyNamespace(self) + + def _rec_map_container( + self, func: Callable[[Array], Array], array: ArrayOrContainer, + allowed_types: Optional[Tuple[type, ...]] = None, *, + default_scalar: Optional[ScalarLike] = None, + strict: bool = False) -> ArrayOrContainer: + if allowed_types is None: + allowed_types = self.array_types + + def _wrapper(ary): + if isinstance(ary, allowed_types): + return func(ary) + elif np.isscalar(ary): + if default_scalar is None: + return ary + else: + return np.array(ary).dtype.type(default_scalar) + else: + raise TypeError( + f"{type(self).__name__}.{func.__name__[1:]} invoked with " + f"an unsupported array type: got '{type(ary).__name__}', " + f"but expected one of {allowed_types}") + + return rec_map_array_container(_wrapper, array) + + # {{{ ArrayContext interface + + def empty(self, shape, dtype): + import torch + return torch.empty(shape, dtype) + + def zeros(self, shape, dtype): + import torch + return torch.zeros(shape, dtype=dtype) + + def empty_like(self, ary): + def _empty_like(array): + return self.empty(array.shape, array.dtype) + + return self._rec_map_container(_empty_like, ary) + + def zeros_like(self, ary): + def _zeros_like(array): + return self.zeros(array.shape, array.dtype) + + return self._rec_map_container(_zeros_like, ary, default_scalar=0) + + def from_numpy(self, array: np.ndarray): + def _from_numpy(ary): + import torch + return torch.from_numpy(array) + + return with_array_context( + self._rec_map_container(_from_numpy, array, allowed_types=(np.ndarray,)), + actx=self) + + def to_numpy(self, array): + def _to_numpy(ary): + import torch + if array.dtype == torch.complex64 or array.dtype == torch.complex128: + return ary.resolve_conj().numpy() + else: + return ary.detach().numpy() + + return with_array_context( + self._rec_map_container(_to_numpy, array), + actx=None) + + def freeze(self, array): + def _freeze(ary): + return ary + return with_array_context(self._rec_map_container(_freeze, array), actx=None) + + def thaw(self, array): + return with_array_context(array, actx=self) + + def tag(self, tags: ToTagSetConvertible, array): + # Sorry, not capable. + return array + + def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): + # Sorry, not capable. + return array + + def call_loopy(self, t_unit, **kwargs): + raise NotImplementedError( + "Calling loopy on Torch arrays is not supported. Maybe rewrite" + " the loopy kernel as numpy-flavored array operations using" + " ArrayContext.np.") + + def einsum(self, spec, *args, arg_names=None, tagged=()): + import torch + if arg_names is not None: + from warnings import warn + warn("'arg_names' don't bear any significance in " + f"{type(self).__name__}.", stacklevel=2) + + return torch.einsum(spec, *args) + + def clone(self): + return type(self)() + + # }}} + + # {{{ properties + + @property + def permits_inplace_modification(self): + return False + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True + + # }}} diff --git a/arraycontext/impl/torch/fake_numpy.py b/arraycontext/impl/torch/fake_numpy.py new file mode 100644 index 00000000..9dd1139c --- /dev/null +++ b/arraycontext/impl/torch/fake_numpy.py @@ -0,0 +1,200 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +import numpy as np +import torch + +from arraycontext.fake_numpy import ( + BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, + ) +from arraycontext.container import is_array_container +from arraycontext.container.traversal import ( + rec_map_array_container, + rec_multimap_array_container, + multimap_reduce_array_container, + rec_map_reduce_array_container, + rec_multimap_reduce_array_container, + ) + + +class TorchFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +class TorchFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`~arraycontext.TorchArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return TorchFakeNumpyLinalgNamespace(self._array_context) + + def __getattr__(self, name): + return partial(rec_multimap_array_container, getattr(torch, name)) + + # NOTE: the order of these follows the order in numpy docs + # NOTE: when adding a function here, also add it to `array_context.rst` docs! + + # {{{ array creation routines + + def ones_like(self, ary): + return self.full_like(ary, 1) + + def full_like(self, ary, fill_value): + def _full_like(subary): + return torch.full_like(subary, fill_value) + + return self._array_context._rec_map_container( + _full_like, ary, default_scalar=fill_value) + + # }}} + + # {{{ array manipulation routines + + def reshape(self, a, newshape, order="C"): + """ + .. warning:: + + Since :func:`torch.reshape` does not support orders `A`` and + ``K``, in such cases we fallback to using ``order = C``. + """ + if order in "AK": + from warnings import warn + warn(f"reshape with order='{order}' nor supported by Torch," + " using order=C.") + + return rec_map_array_container( + lambda ary: torch.reshape(ary, newshape), a + ) + + def ravel(self, a, order="C"): + """ + .. warning:: + + Since :func:`torch.reshape` does not support orders `A`` and + ``K``, in such cases we fallback to using ``order = C``. + """ + if order in "AK": + from warnings import warn + warn(f"reshape with order='{order}' nor supported by Torch," + " using order=C.") + + return rec_map_array_container( + lambda ary: torch.ravel(ary), a + ) + + def transpose(self, a, dim0=0, dim1=1): + return rec_multimap_array_container(torch.transpose, a, dim0, dim1) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(torch.broadcast_to, shape=shape), array) + + def concatenate(self, arrays, axis=0): + return rec_multimap_array_container(torch.cat, arrays, axis) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: torch.stack(tensors=args, dim=axis), *arrays) + + # {{{ linear algebra + + # }}} + + # {{{ logic functions + + def all(self, a): + return rec_map_reduce_array_container( + partial(reduce, torch.logical_and), torch.all, a) + + def any(self, a): + return rec_map_reduce_array_container( + partial(reduce, torch.logical_or), torch.any, a) + + def array_equal(self, a, b): + actx = self._array_context + + true = actx.from_numpy(np.int8(True)) + false = actx.from_numpy(np.int8(False)) + + def rec_equal(x, y): + if type(x) != type(y): + return false + + try: + iterable = zip(serialize_container(x), serialize_container(y)) + except NotAnArrayContainerError: + if x.shape != y.shape: + return false + else: + return torch.all(torch.equal(x, y)) + else: + return reduce( + torch.logical_and, + [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable], + true) + + return rec_equal(a, b) + + def equal(self, a, b): + # ECG: Really? + return a == b + + # }}} + + # {{{ mathematical functions + + def sum(self, a, axis=0, dtype=None): + return rec_map_reduce_array_container( + sum, + partial(torch.sum, axis=axis, dtype=dtype), + a) + + def amin(self, a, axis=0): + return rec_map_reduce_array_container( + partial(reduce, torch.minimum), partial(torch.amin, axis=axis), a) + + min = amin + + def amax(self, a, axis=0): + return rec_map_reduce_array_container( + partial(reduce, torch.maximum), partial(torch.amax, axis=axis), a) + + max = amax + + # }}} + + # {{{ sorting, searching, and counting + + def where(self, criterion, then, else_): + def where_inner(inner_crit, inner_then, inner_else): + import torch + if isinstance(inner_crit, torch.BoolTensor): + return torch.where(inner_crit, inner_then, inner_else) + else: + return torch.where(inner_crit != 0, inner_then, inner_else) + + return rec_multimap_array_container(where_inner, criterion, then, else_) + + # }}} diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 26535185..b4200372 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -196,6 +196,26 @@ def __str__(self): return "" +class _PytestTorchArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + pass + + @classmethod + def is_available(cls) -> bool: + try: + import torch # noqa: F401 + return True + except ImportError: + return False + + def __call__(self): + from arraycontext import TorchArrayContext + return TorchArrayContext() + + def __str__(self): + return "" + + # {{{ _PytestArrayContextFactory class _NumpyArrayContextForTests(NumpyArrayContext): @@ -224,6 +244,7 @@ def __str__(self): "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "torch": _PytestTorchArrayContextFactory, "numpy": _PytestNumpyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 3b95b38b..e7632d6c 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -36,6 +36,7 @@ PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, EagerJAXArrayContext, + TorchArrayContext, ArrayContainer, to_numpy, tag_axes) from arraycontext import ( # noqa: F401 @@ -46,7 +47,9 @@ _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, - _PytestNumpyArrayContextFactory) + _PytestTorchArrayContextFactory, + #_PytestNumpyArrayContextFactory + ) import logging @@ -95,7 +98,8 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, - _PytestNumpyArrayContextFactory, + _PytestTorchArrayContextFactory, + #_PytestNumpyArrayContextFactory, ]) @@ -404,6 +408,9 @@ def evaluate(_np, *_args): if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: pytest.skip(f"'{sym_name}' not supported on scalars") + if isinstance(actx, TorchArrayContext): + pytest.skip(f"{sym_name}' not supported on scalars by TorchArrayContext") + args = [randn(0, dtype)[()] for i in range(n_args)] assert_close_to_numpy(actx, evaluate, args) @@ -422,7 +429,7 @@ def test_array_context_np_like(actx_factory, sym_name, n_args, dtype): assert_close_to_numpy( actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), args) - for c in (42.0,) + _get_test_containers(actx): + for c in (42.0,) + _get_test_containers(actx): result = getattr(actx.np, sym_name)(c) result = actx.thaw(actx.freeze(result)) @@ -430,6 +437,7 @@ def test_array_context_np_like(actx_factory, sym_name, n_args, dtype): if np.isscalar(result): assert result == 0.0 else: + print("ECG: Come back here!") assert actx.to_numpy(actx.np.all(actx.np.equal(result, 0.0))) elif sym_name == "ones_like": if np.isscalar(result):