Skip to content

Commit

Permalink
ruff: fix type import errors
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Nov 13, 2024
1 parent fcd23e9 commit fa3faeb
Show file tree
Hide file tree
Showing 15 changed files with 179 additions and 207 deletions.
21 changes: 6 additions & 15 deletions arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,9 @@
THE SOFTWARE.
"""

from collections.abc import Hashable, Sequence
from functools import singledispatch
from typing import (
TYPE_CHECKING,
Any,
Hashable,
Optional,
Protocol,
Sequence,
Tuple,
TypeAlias,
TypeVar,
)
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar

# For use in singledispatch type annotations, because sphinx can't figure out
# what 'np' is.
Expand Down Expand Up @@ -162,7 +153,7 @@ class NotAnArrayContainerError(TypeError):


SerializationKey: TypeAlias = Hashable
SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]]
SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]]


@singledispatch
Expand Down Expand Up @@ -249,7 +240,7 @@ def is_array_container(ary: Any) -> bool:


@singledispatch
def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]:
def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None:
"""Retrieves the :class:`ArrayContext` from the container, if any.
This function is not recursive, so it will only search at the root level
Expand Down Expand Up @@ -303,7 +294,7 @@ def _deserialize_ndarray_container( # type: ignore[misc]
# {{{ get_container_context_recursively

def get_container_context_recursively_opt(
ary: ArrayContainer) -> Optional[ArrayContext]:
ary: ArrayContainer) -> ArrayContext | None:
"""Walks the :class:`ArrayContainer` hierarchy to find an
:class:`ArrayContext` associated with it.
Expand Down Expand Up @@ -337,7 +328,7 @@ def get_container_context_recursively_opt(
return actx


def get_container_context_recursively(ary: ArrayContainer) -> Optional[ArrayContext]:
def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None:
"""Walks the :class:`ArrayContainer` hierarchy to find an
:class:`ArrayContext` associated with it.
Expand Down
39 changes: 17 additions & 22 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"""

import enum
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
from collections.abc import Callable
from typing import Any, TypeVar
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -90,7 +91,7 @@ class _OpClass(enum.Enum):
]


def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
def _format_unary_op_str(op_str: str, arg1: tuple[str, ...] | str) -> str:
if isinstance(arg1, tuple):
arg1_entry, arg1_container = arg1
return (f"{op_str.format(arg1_entry)} "
Expand All @@ -100,20 +101,14 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:


def _format_binary_op_str(op_str: str,
arg1: Union[Tuple[str, str], str],
arg2: Union[Tuple[str, str], str]) -> str:
arg1: tuple[str, str] | str,
arg2: tuple[str, str] | str) -> str:
if isinstance(arg1, tuple) and isinstance(arg2, tuple):
import sys
if sys.version_info >= (3, 10):
strict_arg = ", strict=__debug__"
else:
strict_arg = ""

arg1_entry, arg1_container = arg1
arg2_entry, arg2_container = arg2
return (f"{op_str.format(arg1_entry, arg2_entry)} "
f"for {arg1_entry}, {arg2_entry} "
f"in zip({arg1_container}, {arg2_container}{strict_arg})")
f"in zip({arg1_container}, {arg2_container}, strict=__debug__)")

elif isinstance(arg1, tuple):
arg1_entry, arg1_container = arg1
Expand Down Expand Up @@ -160,23 +155,23 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet

def with_container_arithmetic(
*,
number_bcasts_across: Optional[bool] = None,
bcasts_across_obj_array: Optional[bool] = None,
container_types_bcast_across: Optional[Tuple[type, ...]] = None,
number_bcasts_across: bool | None = None,
bcasts_across_obj_array: bool | None = None,
container_types_bcast_across: tuple[type, ...] | None = None,
arithmetic: bool = True,
matmul: bool = False,
bitwise: bool = False,
shift: bool = False,
_cls_has_array_context_attr: Optional[bool] = None,
eq_comparison: Optional[bool] = None,
rel_comparison: Optional[bool] = None,
_cls_has_array_context_attr: bool | None = None,
eq_comparison: bool | None = None,
rel_comparison: bool | None = None,

# deprecated:
bcast_number: Optional[bool] = None,
bcast_obj_array: Optional[bool] = None,
bcast_number: bool | None = None,
bcast_obj_array: bool | None = None,
bcast_numpy_array: bool = False,
_bcast_actx_array_type: Optional[bool] = None,
bcast_container_types: Optional[Tuple[type, ...]] = None,
_bcast_actx_array_type: bool | None = None,
bcast_container_types: tuple[type, ...] | None = None,
) -> Callable[[type], type]:
"""A class decorator that implements built-in operators for array containers
by propagating the operations to the elements of the container.
Expand Down Expand Up @@ -482,7 +477,7 @@ def same_key(k1: T, k2: T) -> T:
assert k1 == k2
return k1

def tup_str(t: Tuple[str, ...]) -> str:
def tup_str(t: tuple[str, ...]) -> str:
if not t:
return "()"
else:
Expand Down
61 changes: 31 additions & 30 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@
THE SOFTWARE.
"""

from collections.abc import Callable, Iterable
from functools import partial, singledispatch, update_wrapper
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast
from typing import Any, cast
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -100,7 +101,7 @@
def _map_array_container_impl(
f: Callable[[ArrayOrContainer], ArrayOrContainer],
ary: ArrayOrContainer, *,
leaf_cls: Optional[type] = None,
leaf_cls: type | None = None,
recursive: bool = False) -> ArrayOrContainer:
"""Helper for :func:`rec_map_array_container`.
Expand Down Expand Up @@ -129,9 +130,9 @@ def rec(_ary: ArrayOrContainer) -> ArrayOrContainer:
def _multimap_array_container_impl(
f: Callable[..., Any],
*args: Any,
reduce_func: Optional[Callable[
[ArrayContainer, Iterable[Tuple[Any, Any]]], Any]] = None,
leaf_cls: Optional[type] = None,
reduce_func: (
Callable[[ArrayContainer, Iterable[tuple[Any, Any]]], Any] | None) = None,
leaf_cls: type | None = None,
recursive: bool = False) -> ArrayOrContainer:
"""Helper for :func:`rec_multimap_array_container`.
Expand Down Expand Up @@ -183,7 +184,7 @@ def rec(*_args: Any) -> Any:

# {{{ find all containers in the argument list

container_indices: List[int] = []
container_indices: list[int] = []

for i, arg in enumerate(args):
if type(arg) is leaf_cls:
Expand Down Expand Up @@ -244,7 +245,7 @@ def stringify_array_container_tree(ary: ArrayOrContainer) -> str:
:returns: a string for an ASCII tree representation of the array container,
similar to `asciitree <https://github.com/mbr/asciitree>`__.
"""
def rec(lines: List[str], ary_: ArrayOrContainerT, level: int) -> None:
def rec(lines: list[str], ary_: ArrayOrContainerT, level: int) -> None:
try:
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
Expand Down Expand Up @@ -307,7 +308,7 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
def rec_map_array_container(
f: Callable[[Any], Any],
ary: ArrayOrContainer,
leaf_class: Optional[type] = None) -> ArrayOrContainer:
leaf_class: type | None = None) -> ArrayOrContainer:
r"""Applies *f* recursively to an :class:`ArrayContainer`.
For a non-recursive version see :func:`map_array_container`.
Expand All @@ -319,12 +320,12 @@ def rec_map_array_container(


def mapped_over_array_containers(
f: Optional[Callable[[ArrayOrContainer], ArrayOrContainer]] = None,
leaf_class: Optional[type] = None) -> Union[
Callable[[ArrayOrContainer], ArrayOrContainer],
Callable[
f: Callable[[ArrayOrContainer], ArrayOrContainer] | None = None,
leaf_class: type | None = None) -> (
Callable[[ArrayOrContainer], ArrayOrContainer]
| Callable[
[Callable[[Any], Any]],
Callable[[ArrayOrContainer], ArrayOrContainer]]]:
Callable[[ArrayOrContainer], ArrayOrContainer]]):
"""Decorator around :func:`rec_map_array_container`."""
def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[
[ArrayOrContainer], ArrayOrContainer]:
Expand All @@ -340,7 +341,7 @@ def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[
def rec_multimap_array_container(
f: Callable[..., Any],
*args: Any,
leaf_class: Optional[type] = None) -> Any:
leaf_class: type | None = None) -> Any:
r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
For a non-recursive version see :func:`multimap_array_container`.
Expand All @@ -353,10 +354,10 @@ def rec_multimap_array_container(


def multimapped_over_array_containers(
f: Optional[Callable[..., Any]] = None,
leaf_class: Optional[type] = None) -> Union[
Callable[..., Any],
Callable[[Callable[..., Any]], Callable[..., Any]]]:
f: Callable[..., Any] | None = None,
leaf_class: type | None = None) -> (
Callable[..., Any]
| Callable[[Callable[..., Any]], Callable[..., Any]]):
"""Decorator around :func:`rec_multimap_array_container`."""
def decorator(g: Callable[..., Any]) -> Callable[..., Any]:
# can't use functools.partial, because its result is insufficiently
Expand Down Expand Up @@ -403,7 +404,7 @@ def keyed_map_array_container(


def rec_keyed_map_array_container(
f: Callable[[Tuple[SerializationKey, ...], ArrayT], ArrayT],
f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
ary: ArrayOrContainer) -> ArrayOrContainer:
"""
Works similarly to :func:`rec_map_array_container`, except that *f* also
Expand All @@ -412,7 +413,7 @@ def rec_keyed_map_array_container(
the current array.
"""

def rec(keys: Tuple[SerializationKey, ...],
def rec(keys: tuple[SerializationKey, ...],
_ary: ArrayOrContainerT) -> ArrayOrContainerT:
try:
iterable = serialize_container(_ary)
Expand Down Expand Up @@ -469,7 +470,7 @@ def multimap_reduce_array_container(
# NOTE: this wrapper matches the signature of `deserialize_container`
# to make plugging into `_multimap_array_container_impl` easier
def _reduce_wrapper(
ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]]
ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]]
) -> Array:
return reduce_func([subary for _, subary in iterable])

Expand All @@ -482,7 +483,7 @@ def rec_map_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[[Any], Any],
ary: ArrayOrContainer,
leaf_class: Optional[type] = None) -> ArrayOrContainer:
leaf_class: type | None = None) -> ArrayOrContainer:
"""Perform a map-reduce over array containers recursively.
:param reduce_func: callable used to reduce over the components of *ary*
Expand Down Expand Up @@ -540,7 +541,7 @@ def rec_multimap_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[..., Any],
*args: Any,
leaf_class: Optional[type] = None) -> ArrayOrContainer:
leaf_class: type | None = None) -> ArrayOrContainer:
r"""Perform a map-reduce over multiple array containers recursively.
:param reduce_func: callable used to reduce over the components of any
Expand All @@ -559,7 +560,7 @@ def rec_multimap_reduce_array_container(
# NOTE: this wrapper matches the signature of `deserialize_container`
# to make plugging into `_multimap_array_container_impl` easier
def _reduce_wrapper(
ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]]) -> Any:
ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]]) -> Any:
return reduce_func([subary for _, subary in iterable])

return _multimap_array_container_impl(
Expand All @@ -573,7 +574,7 @@ def _reduce_wrapper(

def freeze(
ary: ArrayOrContainerT,
actx: Optional[ArrayContext] = None) -> ArrayOrContainerT:
actx: ArrayContext | None = None) -> ArrayOrContainerT:
r"""Freezes recursively by going through all components of the
:class:`ArrayContainer` *ary*.
Expand Down Expand Up @@ -650,7 +651,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:

@singledispatch
def with_array_context(ary: ArrayOrContainerT,
actx: Optional[ArrayContext]) -> ArrayOrContainerT:
actx: ArrayContext | None) -> ArrayOrContainerT:
"""
Recursively associates *actx* to all the components of *ary*.
Expand All @@ -674,7 +675,7 @@ def with_array_context(ary: ArrayOrContainerT,

def flatten(
ary: ArrayOrContainer, actx: ArrayContext, *,
leaf_class: Optional[type] = None,
leaf_class: type | None = None,
) -> Any:
"""Convert all arrays in the :class:`~arraycontext.ArrayContainer`
into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`.
Expand All @@ -696,7 +697,7 @@ def flatten(
"""
common_dtype = None

def _flatten(subary: ArrayOrContainer) -> List[Array]:
def _flatten(subary: ArrayOrContainer) -> list[Array]:
nonlocal common_dtype

try:
Expand Down Expand Up @@ -874,7 +875,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:


def flat_size_and_dtype(
ary: ArrayOrContainer) -> Tuple[int, Optional[np.dtype[Any]]]:
ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]:
"""
:returns: a tuple ``(size, dtype)`` that would be the length and
:class:`numpy.dtype` of the one-dimensional array returned by
Expand Down Expand Up @@ -910,7 +911,7 @@ def _flat_size(subary: ArrayOrContainer) -> int:
# {{{ numpy conversion

def from_numpy(
ary: Union[np.ndarray, ScalarLike],
ary: np.ndarray | ScalarLike,
actx: ArrayContext) -> ArrayOrContainerOrScalar:
"""Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer`
to the base array type of :class:`~arraycontext.ArrayContext`.
Expand Down
Loading

0 comments on commit fa3faeb

Please sign in to comment.