From f7b4c79b2f8d820d9a859603ac959e0e49f6a393 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 5 Aug 2024 12:42:46 -0500 Subject: [PATCH] Implement actx.np.zeros --- arraycontext/context.py | 18 ++++++++---------- arraycontext/fake_numpy.py | 11 ++++++++++- arraycontext/impl/jax/__init__.py | 2 +- arraycontext/impl/jax/fake_numpy.py | 3 +++ arraycontext/impl/pyopencl/__init__.py | 2 +- arraycontext/impl/pyopencl/fake_numpy.py | 6 ++++++ arraycontext/impl/pytato/fake_numpy.py | 3 +++ test/test_arraycontext.py | 12 ++++++------ 8 files changed, 38 insertions(+), 19 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index 38c52ddb..8b42bca7 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -171,6 +171,7 @@ TypeVar, Union, ) +from warnings import warn import numpy as np @@ -249,10 +250,6 @@ class ArrayContext(ABC): .. versionadded:: 2020.2 - .. automethod:: empty - .. automethod:: zeros - .. automethod:: empty_like - .. automethod:: zeros_like .. automethod:: from_numpy .. automethod:: to_numpy .. automethod:: call_loopy @@ -293,9 +290,9 @@ class ArrayContext(ABC): def __init__(self) -> None: self.np = self._get_fake_numpy_namespace() + @abstractmethod def _get_fake_numpy_namespace(self) -> Any: - from .fake_numpy import BaseFakeNumpyNamespace - return BaseFakeNumpyNamespace(self) + ... def __hash__(self) -> int: raise TypeError(f"unhashable type: '{type(self).__name__}'") @@ -306,14 +303,16 @@ def empty(self, dtype: "np.dtype[Any]") -> Array: pass - @abstractmethod def zeros(self, shape: Union[int, Tuple[int, ...]], dtype: "np.dtype[Any]") -> Array: - pass + warn(f"{type(self).__name__}.zeros is deprecated and will stop " + "working in 2025. Use actx.np.zeros instead.", + DeprecationWarning, stacklevel=2) + + return self.np.zeros(shape, dtype) def empty_like(self, ary: Array) -> Array: - from warnings import warn warn(f"{type(self).__name__}.empty_like is deprecated and will stop " "working in 2023. Prefer actx.np.zeros_like instead.", DeprecationWarning, stacklevel=2) @@ -321,7 +320,6 @@ def empty_like(self, ary: Array) -> Array: return self.empty(shape=ary.shape, dtype=ary.dtype) def zeros_like(self, ary: Array) -> Array: - from warnings import warn warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " "working in 2023. Use actx.np.zeros_like instead.", DeprecationWarning, stacklevel=2) diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index e31bae7d..8473cc42 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -24,6 +24,7 @@ import operator +from abc import ABC, abstractmethod from typing import Any import numpy as np @@ -34,7 +35,7 @@ # {{{ BaseFakeNumpyNamespace -class BaseFakeNumpyNamespace: +class BaseFakeNumpyNamespace(ABC): def __init__(self, array_context): self._array_context = array_context self.linalg = self._get_fake_numpy_linalg_namespace() @@ -95,6 +96,14 @@ def _get_fake_numpy_linalg_namespace(self): # "interp", }) + @abstractmethod + def zeros(self, shape, dtype): + ... + + @abstractmethod + def zeros_like(self, ary): + ... + def conjugate(self, x): # NOTE: conjugate distributes over object arrays, but it looks for a # `conjugate` ufunc, while some implementations only have the shorter diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index c52b24a6..03045419 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -90,7 +90,7 @@ def _wrapper(ary): def empty(self, shape, dtype): from warnings import warn warn(f"{type(self).__name__}.empty is deprecated and will stop " - "working in 2023. Prefer actx.zeros instead.", + "working in 2023. Prefer actx.np.zeros instead.", DeprecationWarning, stacklevel=2) import jax.numpy as jnp diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index afe67283..3fc5f2e6 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -56,6 +56,9 @@ def __getattr__(self, name): # {{{ array creation routines + def zeros(self, shape, dtype): + return jnp.zeros(shape=shape, dtype=dtype) + def empty_like(self, ary): from warnings import warn warn(f"{type(self._array_context).__name__}.np.empty_like is " diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 990a4223..9be77a44 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -201,7 +201,7 @@ def _wrapper(ary): def empty(self, shape, dtype): from warnings import warn warn(f"{type(self).__name__}.empty is deprecated and will stop " - "working in 2023. Prefer actx.zeros instead.", + "working in 2023. Prefer actx.np.zeros instead.", DeprecationWarning, stacklevel=2) import arraycontext.impl.pyopencl.taggable_cl_array as tga diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 2583bfa2..59be99e8 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -39,6 +39,7 @@ rec_multimap_reduce_array_container, ) from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace +from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray from arraycontext.loopy import LoopyBasedFakeNumpyNamespace @@ -60,6 +61,11 @@ def _get_fake_numpy_linalg_namespace(self): # {{{ array creation routines + def zeros(self, shape, dtype) -> TaggableCLArray: + import arraycontext.impl.pyopencl.taggable_cl_array as tga + return tga.zeros(self._array_context.queue, shape, dtype, + allocator=self._array_context.allocator) + def empty_like(self, ary): from warnings import warn warn(f"{type(self._array_context).__name__}.np.empty_like is " diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 9c41b523..d3d018d6 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -84,6 +84,9 @@ def __getattr__(self, name): # {{{ array creation routines + def zeros(self, shape, dtype): + return pt.zeros(shape, dtype) + def zeros_like(self, ary): def _zeros_like(array): return self._array_context.zeros( diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index fb16b872..3f06156b 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1367,7 +1367,7 @@ def test_leaf_array_type_broadcasting(actx_factory): # test support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() - foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, ))) + foo = Foo(DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, ))) bar = foo + 4 baz = foo + actx.from_numpy(4*np.ones((3, ))) qux = actx.from_numpy(4*np.ones((3, ))) + foo @@ -1510,7 +1510,7 @@ def _twice(x): actx = actx_factory() ones = actx.thaw(actx.freeze( - actx.zeros(shape=(10, 4), dtype=np.float64) + 1 + actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1 )) np.testing.assert_allclose(actx.to_numpy(_twice(ones)), actx.to_numpy(actx.compile(_twice)(ones))) @@ -1573,7 +1573,7 @@ def test_taggable_cl_array_tags(actx_factory): def test_to_numpy_on_frozen_arrays(actx_factory): # See https://github.com/inducer/arraycontext/issues/159 actx = actx_factory() - u = actx.freeze(actx.zeros(10, dtype="float64")+1) + u = actx.freeze(actx.np.zeros(10, dtype="float64")+1) np.testing.assert_allclose(actx.to_numpy(u), 1) np.testing.assert_allclose(actx.to_numpy(u), 1) @@ -1592,7 +1592,7 @@ class ExampleTag(Tag): ary = tag_axes(actx, {0: ExampleTag()}, actx.tag( ExampleTag(), - actx.zeros((20, 20), dtype=np.float64))) + actx.np.zeros((20, 20), dtype=np.float64))) assert ary.tags_of_type(ExampleTag) assert ary.axes[0].tags_of_type(ExampleTag) @@ -1606,11 +1606,11 @@ def test_compile_anonymous_function(actx_factory): actx = actx_factory() f = actx.compile(lambda x: 2*x+40) np.testing.assert_allclose( - actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))), + actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))), 42) f = actx.compile(partial(lambda x: 2*x+40)) np.testing.assert_allclose( - actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))), + actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))), 42)