diff --git a/arraycontext/context.py b/arraycontext/context.py index 0d0595c3..e8c40896 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -156,7 +156,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Mapping -from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union +from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload from warnings import warn import numpy as np @@ -320,6 +320,14 @@ def zeros(self, return self.np.zeros(shape, dtype) + @overload + def from_numpy(self, array: np.ndarray) -> Array: + ... + + @overload + def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + @abstractmethod def from_numpy(self, array: NumpyOrContainerOrScalar @@ -333,6 +341,14 @@ def from_numpy(self, intact. """ + @overload + def to_numpy(self, array: np.ndarray) -> Array: + ... + + @overload + def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + @abstractmethod def to_numpy(self, array: ArrayOrContainerOrScalar diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index c2f884a6..f9d6c541 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -1,7 +1,4 @@ -from __future__ import annotations - - -__doc__ = """ +""" .. currentmodule:: arraycontext A :mod:`numpy`-based array context. @@ -9,6 +6,9 @@ .. autoclass:: NumpyArrayContext """ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees """ @@ -33,7 +33,7 @@ THE SOFTWARE. """ -from typing import Any +from typing import Any, overload import numpy as np @@ -46,6 +46,7 @@ ArrayContext, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, + ContainerOrScalarT, NumpyOrContainerOrScalar, UntransformedCodeWarning, ) @@ -84,11 +85,27 @@ def _get_fake_numpy_namespace(self): def clone(self): return type(self)() + @overload + def from_numpy(self, array: np.ndarray) -> Array: + ... + + @overload + def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + def from_numpy(self, array: NumpyOrContainerOrScalar ) -> ArrayOrContainerOrScalar: return array + @overload + def to_numpy(self, array: Array) -> np.ndarray: + ... + + @overload + def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + def to_numpy(self, array: ArrayOrContainerOrScalar ) -> NumpyOrContainerOrScalar: diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index c031e29b..6441527a 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -163,10 +163,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: # https://github.com/pylint-dev/pylint/issues/3893 # pylint: disable=unexpected-keyword-arg - # type-ignore: discussed at - # https://github.com/inducer/arraycontext/pull/289#discussion_r1855523967 - # possibly related: https://github.com/python/mypy/issues/17375 - return DataWrapper( # type: ignore[call-arg] + return DataWrapper( data=new_dw.data, shape=expr.shape, axes=expr.axes,