Skip to content

Commit

Permalink
{to,from}_numpy: Use overloads for more precise type info
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 27, 2024
1 parent 7dd66db commit 32f010c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 10 deletions.
18 changes: 17 additions & 1 deletion arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@

from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -320,6 +320,14 @@ def zeros(self,

return self.np.zeros(shape, dtype)

@overload
def from_numpy(self, array: np.ndarray) -> Array:
...

@overload
def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...

@abstractmethod
def from_numpy(self,
array: NumpyOrContainerOrScalar
Expand All @@ -333,6 +341,14 @@ def from_numpy(self,
intact.
"""

@overload
def to_numpy(self, array: np.ndarray) -> Array:
...

@overload
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...

@abstractmethod
def to_numpy(self,
array: ArrayOrContainerOrScalar
Expand Down
27 changes: 22 additions & 5 deletions arraycontext/impl/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations


__doc__ = """
"""
.. currentmodule:: arraycontext
A :mod:`numpy`-based array context.
.. autoclass:: NumpyArrayContext
"""

from __future__ import annotations


__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
Expand All @@ -33,7 +33,7 @@
THE SOFTWARE.
"""

from typing import Any
from typing import Any, overload

import numpy as np

Expand All @@ -46,6 +46,7 @@
ArrayContext,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ContainerOrScalarT,
NumpyOrContainerOrScalar,
UntransformedCodeWarning,
)
Expand Down Expand Up @@ -84,11 +85,27 @@ def _get_fake_numpy_namespace(self):
def clone(self):
return type(self)()

@overload
def from_numpy(self, array: np.ndarray) -> Array:
...

@overload
def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...

def from_numpy(self,
array: NumpyOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
return array

@overload
def to_numpy(self, array: Array) -> np.ndarray:
...

@overload
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...

def to_numpy(self,
array: ArrayOrContainerOrScalar
) -> NumpyOrContainerOrScalar:
Expand Down
5 changes: 1 addition & 4 deletions arraycontext/impl/pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:

# https://github.com/pylint-dev/pylint/issues/3893
# pylint: disable=unexpected-keyword-arg
# type-ignore: discussed at
# https://github.com/inducer/arraycontext/pull/289#discussion_r1855523967
# possibly related: https://github.com/python/mypy/issues/17375
return DataWrapper( # type: ignore[call-arg]
return DataWrapper(
data=new_dw.data,
shape=expr.shape,
axes=expr.axes,
Expand Down

0 comments on commit 32f010c

Please sign in to comment.