Skip to content

Commit

Permalink
Implement actx.np.zeros
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Aug 5, 2024
1 parent 0e3f321 commit 3f8d0f6
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 19 deletions.
18 changes: 8 additions & 10 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
TypeVar,
Union,
)
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -249,10 +250,6 @@ class ArrayContext(ABC):
.. versionadded:: 2020.2
.. automethod:: empty
.. automethod:: zeros
.. automethod:: empty_like
.. automethod:: zeros_like
.. automethod:: from_numpy
.. automethod:: to_numpy
.. automethod:: call_loopy
Expand Down Expand Up @@ -293,9 +290,9 @@ class ArrayContext(ABC):
def __init__(self) -> None:
self.np = self._get_fake_numpy_namespace()

@abstractmethod
def _get_fake_numpy_namespace(self) -> Any:
from .fake_numpy import BaseFakeNumpyNamespace
return BaseFakeNumpyNamespace(self)
...

def __hash__(self) -> int:
raise TypeError(f"unhashable type: '{type(self).__name__}'")
Expand All @@ -306,22 +303,23 @@ def empty(self,
dtype: "np.dtype[Any]") -> Array:
pass

@abstractmethod
def zeros(self,
shape: Union[int, Tuple[int, ...]],
dtype: "np.dtype[Any]") -> Array:
pass
warn(f"{type(self).__name__}.zeros is deprecated and will stop "
"working in 2025. Use actx.np.zeros instead.",
DeprecationWarning, stacklevel=2)

return self.np.zeros(shape, dtype)

def empty_like(self, ary: Array) -> Array:
from warnings import warn
warn(f"{type(self).__name__}.empty_like is deprecated and will stop "
"working in 2023. Prefer actx.np.zeros_like instead.",
DeprecationWarning, stacklevel=2)

return self.empty(shape=ary.shape, dtype=ary.dtype)

def zeros_like(self, ary: Array) -> Array:
from warnings import warn
warn(f"{type(self).__name__}.zeros_like is deprecated and will stop "
"working in 2023. Use actx.np.zeros_like instead.",
DeprecationWarning, stacklevel=2)
Expand Down
11 changes: 10 additions & 1 deletion arraycontext/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@


import operator
from abc import ABC, abstractmethod
from typing import Any

import numpy as np
Expand All @@ -34,7 +35,7 @@

# {{{ BaseFakeNumpyNamespace

class BaseFakeNumpyNamespace:
class BaseFakeNumpyNamespace(ABC):
def __init__(self, array_context):
self._array_context = array_context
self.linalg = self._get_fake_numpy_linalg_namespace()
Expand Down Expand Up @@ -95,6 +96,14 @@ def _get_fake_numpy_linalg_namespace(self):
# "interp",
})

@abstractmethod
def zeros(self, shape, dtype):
...

@abstractmethod
def zeros_like(self, ary):
...

def conjugate(self, x):
# NOTE: conjugate distributes over object arrays, but it looks for a
# `conjugate` ufunc, while some implementations only have the shorter
Expand Down
2 changes: 1 addition & 1 deletion arraycontext/impl/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _wrapper(ary):
def empty(self, shape, dtype):
from warnings import warn
warn(f"{type(self).__name__}.empty is deprecated and will stop "
"working in 2023. Prefer actx.zeros instead.",
"working in 2023. Prefer actx.np.zeros instead.",
DeprecationWarning, stacklevel=2)

import jax.numpy as jnp
Expand Down
3 changes: 3 additions & 0 deletions arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __getattr__(self, name):

# {{{ array creation routines

def zeros(self, shape, dtype):
return jnp.zeros(shape=shape, dtype=dtype)

def empty_like(self, ary):
from warnings import warn
warn(f"{type(self._array_context).__name__}.np.empty_like is "
Expand Down
2 changes: 1 addition & 1 deletion arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _wrapper(ary):
def empty(self, shape, dtype):
from warnings import warn
warn(f"{type(self).__name__}.empty is deprecated and will stop "
"working in 2023. Prefer actx.zeros instead.",
"working in 2023. Prefer actx.np.zeros instead.",
DeprecationWarning, stacklevel=2)

import arraycontext.impl.pyopencl.taggable_cl_array as tga
Expand Down
6 changes: 6 additions & 0 deletions arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
rec_multimap_reduce_array_container,
)
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace


Expand All @@ -60,6 +61,11 @@ def _get_fake_numpy_linalg_namespace(self):

# {{{ array creation routines

def zeros(self, shape, dtype) -> TaggableCLArray:
import arraycontext.impl.pyopencl.taggable_cl_array as tga
return tga.zeros(self._array_context.queue, shape, dtype,
allocator=self._array_context.allocator)

def empty_like(self, ary):
from warnings import warn
warn(f"{type(self._array_context).__name__}.np.empty_like is "
Expand Down
3 changes: 3 additions & 0 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def __getattr__(self, name):

# {{{ array creation routines

def zeros(self, shape, dtype):
return pt.zeros(shape, dtype)

def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.zeros(
Expand Down
12 changes: 6 additions & 6 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ def test_leaf_array_type_broadcasting(actx_factory):
# test support for https://github.com/inducer/arraycontext/issues/49
actx = actx_factory()

foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, )))
foo = Foo(DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )))
bar = foo + 4
baz = foo + actx.from_numpy(4*np.ones((3, )))
qux = actx.from_numpy(4*np.ones((3, ))) + foo
Expand Down Expand Up @@ -1510,7 +1510,7 @@ def _twice(x):

actx = actx_factory()
ones = actx.thaw(actx.freeze(
actx.zeros(shape=(10, 4), dtype=np.float64) + 1
actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1
))
np.testing.assert_allclose(actx.to_numpy(_twice(ones)),
actx.to_numpy(actx.compile(_twice)(ones)))
Expand Down Expand Up @@ -1573,7 +1573,7 @@ def test_taggable_cl_array_tags(actx_factory):
def test_to_numpy_on_frozen_arrays(actx_factory):
# See https://github.com/inducer/arraycontext/issues/159
actx = actx_factory()
u = actx.freeze(actx.zeros(10, dtype="float64")+1)
u = actx.freeze(actx.np.zeros(10, dtype="float64")+1)
np.testing.assert_allclose(actx.to_numpy(u), 1)
np.testing.assert_allclose(actx.to_numpy(u), 1)

Expand All @@ -1592,7 +1592,7 @@ class ExampleTag(Tag):
ary = tag_axes(actx, {0: ExampleTag()},
actx.tag(
ExampleTag(),
actx.zeros((20, 20), dtype=np.float64)))
actx.np.zeros((20, 20), dtype=np.float64)))

assert ary.tags_of_type(ExampleTag)
assert ary.axes[0].tags_of_type(ExampleTag)
Expand All @@ -1606,11 +1606,11 @@ def test_compile_anonymous_function(actx_factory):
actx = actx_factory()
f = actx.compile(lambda x: 2*x+40)
np.testing.assert_allclose(
actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))),
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
42)
f = actx.compile(partial(lambda x: 2*x+40))
np.testing.assert_allclose(
actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))),
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
42)


Expand Down

0 comments on commit 3f8d0f6

Please sign in to comment.