diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 655a3e64..1ea85654 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -79,15 +79,12 @@ THE SOFTWARE. """ +from collections.abc import Hashable, Sequence from functools import singledispatch from typing import ( TYPE_CHECKING, Any, - Hashable, - Optional, Protocol, - Sequence, - Tuple, TypeAlias, TypeVar, ) @@ -162,7 +159,7 @@ class NotAnArrayContainerError(TypeError): SerializationKey: TypeAlias = Hashable -SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]] +SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]] @singledispatch @@ -249,7 +246,7 @@ def is_array_container(ary: Any) -> bool: @singledispatch -def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]: +def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None: """Retrieves the :class:`ArrayContext` from the container, if any. This function is not recursive, so it will only search at the root level @@ -303,7 +300,7 @@ def _deserialize_ndarray_container( # type: ignore[misc] # {{{ get_container_context_recursively def get_container_context_recursively_opt( - ary: ArrayContainer) -> Optional[ArrayContext]: + ary: ArrayContainer) -> ArrayContext | None: """Walks the :class:`ArrayContainer` hierarchy to find an :class:`ArrayContext` associated with it. @@ -337,7 +334,7 @@ def get_container_context_recursively_opt( return actx -def get_container_context_recursively(ary: ArrayContainer) -> Optional[ArrayContext]: +def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None: """Walks the :class:`ArrayContainer` hierarchy to find an :class:`ArrayContext` associated with it. diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 44de96c5..9e47ae28 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -34,7 +34,8 @@ """ import enum -from typing import Any, Callable, Optional, Tuple, TypeVar, Union +from collections.abc import Callable +from typing import Any, TypeVar from warnings import warn import numpy as np @@ -90,7 +91,7 @@ class _OpClass(enum.Enum): ] -def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: +def _format_unary_op_str(op_str: str, arg1: tuple[str, ...] | str) -> str: if isinstance(arg1, tuple): arg1_entry, arg1_container = arg1 return (f"{op_str.format(arg1_entry)} " @@ -100,14 +101,10 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: def _format_binary_op_str(op_str: str, - arg1: Union[Tuple[str, str], str], - arg2: Union[Tuple[str, str], str]) -> str: + arg1: tuple[str, str] | str, + arg2: tuple[str, str] | str) -> str: if isinstance(arg1, tuple) and isinstance(arg2, tuple): - import sys - if sys.version_info >= (3, 10): - strict_arg = ", strict=__debug__" - else: - strict_arg = "" + strict_arg = ", strict=__debug__" arg1_entry, arg1_container = arg1 arg2_entry, arg2_container = arg2 @@ -160,23 +157,23 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet def with_container_arithmetic( *, - number_bcasts_across: Optional[bool] = None, - bcasts_across_obj_array: Optional[bool] = None, - container_types_bcast_across: Optional[Tuple[type, ...]] = None, + number_bcasts_across: bool | None = None, + bcasts_across_obj_array: bool | None = None, + container_types_bcast_across: tuple[type, ...] | None = None, arithmetic: bool = True, matmul: bool = False, bitwise: bool = False, shift: bool = False, - _cls_has_array_context_attr: Optional[bool] = None, - eq_comparison: Optional[bool] = None, - rel_comparison: Optional[bool] = None, + _cls_has_array_context_attr: bool | None = None, + eq_comparison: bool | None = None, + rel_comparison: bool | None = None, # deprecated: - bcast_number: Optional[bool] = None, - bcast_obj_array: Optional[bool] = None, + bcast_number: bool | None = None, + bcast_obj_array: bool | None = None, bcast_numpy_array: bool = False, - _bcast_actx_array_type: Optional[bool] = None, - bcast_container_types: Optional[Tuple[type, ...]] = None, + _bcast_actx_array_type: bool | None = None, + bcast_container_types: tuple[type, ...] | None = None, ) -> Callable[[type], type]: """A class decorator that implements built-in operators for array containers by propagating the operations to the elements of the container. @@ -482,7 +479,7 @@ def same_key(k1: T, k2: T) -> T: assert k1 == k2 return k1 - def tup_str(t: Tuple[str, ...]) -> str: + def tup_str(t: tuple[str, ...]) -> str: if not t: return "()" else: diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 492f0c92..ec4c37f4 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -31,7 +31,7 @@ """ from dataclasses import Field, fields, is_dataclass -from typing import Tuple, Union, get_args, get_origin +from typing import Union, get_args, get_origin from arraycontext.container import is_array_container_type @@ -100,7 +100,7 @@ def is_array_field(f: Field) -> bool: _BaseGenericAlias, _SpecialForm, ) - if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)): + if isinstance(f.type, _BaseGenericAlias | _SpecialForm): # NOTE: anything except a Union is not allowed raise TypeError( f"Typing annotation not supported on field '{f.name}': " @@ -125,8 +125,8 @@ def is_array_field(f: Field) -> bool: def inject_dataclass_serialization( cls: type, - array_fields: Tuple[Field, ...], - non_array_fields: Tuple[Field, ...]) -> type: + array_fields: tuple[Field, ...], + non_array_fields: tuple[Field, ...]) -> type: """Implements :func:`~arraycontext.serialize_container` and :func:`~arraycontext.deserialize_container` for the given dataclass *cls*. diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index f7a99216..ebea5fbf 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -70,8 +70,9 @@ THE SOFTWARE. """ +from collections.abc import Callable, Iterable from functools import partial, singledispatch, update_wrapper -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, cast from warnings import warn import numpy as np @@ -100,7 +101,7 @@ def _map_array_container_impl( f: Callable[[ArrayOrContainer], ArrayOrContainer], ary: ArrayOrContainer, *, - leaf_cls: Optional[type] = None, + leaf_cls: type | None = None, recursive: bool = False) -> ArrayOrContainer: """Helper for :func:`rec_map_array_container`. @@ -129,9 +130,10 @@ def rec(_ary: ArrayOrContainer) -> ArrayOrContainer: def _multimap_array_container_impl( f: Callable[..., Any], *args: Any, - reduce_func: Optional[Callable[ - [ArrayContainer, Iterable[Tuple[Any, Any]]], Any]] = None, - leaf_cls: Optional[type] = None, + reduce_func: + Callable[[ArrayContainer, Iterable[tuple[Any, Any]]], Any] | None + = None, + leaf_cls: type | None = None, recursive: bool = False) -> ArrayOrContainer: """Helper for :func:`rec_multimap_array_container`. @@ -184,7 +186,7 @@ def rec(*_args: Any) -> Any: # {{{ find all containers in the argument list - container_indices: List[int] = [] + container_indices: list[int] = [] for i, arg in enumerate(args): if type(arg) is leaf_cls: @@ -245,7 +247,7 @@ def stringify_array_container_tree(ary: ArrayOrContainer) -> str: :returns: a string for an ASCII tree representation of the array container, similar to `asciitree `__. """ - def rec(lines: List[str], ary_: ArrayOrContainerT, level: int) -> None: + def rec(lines: list[str], ary_: ArrayOrContainerT, level: int) -> None: try: iterable = serialize_container(ary_) except NotAnArrayContainerError: @@ -308,7 +310,7 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: def rec_map_array_container( f: Callable[[Any], Any], ary: ArrayOrContainer, - leaf_class: Optional[type] = None) -> ArrayOrContainer: + leaf_class: type | None = None) -> ArrayOrContainer: r"""Applies *f* recursively to an :class:`ArrayContainer`. For a non-recursive version see :func:`map_array_container`. @@ -320,12 +322,13 @@ def rec_map_array_container( def mapped_over_array_containers( - f: Optional[Callable[[ArrayOrContainer], ArrayOrContainer]] = None, - leaf_class: Optional[type] = None) -> Union[ - Callable[[ArrayOrContainer], ArrayOrContainer], - Callable[ - [Callable[[Any], Any]], - Callable[[ArrayOrContainer], ArrayOrContainer]]]: + f: Callable[[ArrayOrContainer], ArrayOrContainer] | None = None, + leaf_class: type | None = None + ) -> ( + Callable[[ArrayOrContainer], ArrayOrContainer] + | Callable[[Callable[[Any], Any]], + Callable[[ArrayOrContainer], ArrayOrContainer]] + ): """Decorator around :func:`rec_map_array_container`.""" def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[ [ArrayOrContainer], ArrayOrContainer]: @@ -341,7 +344,7 @@ def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[ def rec_multimap_array_container( f: Callable[..., Any], *args: Any, - leaf_class: Optional[type] = None) -> Any: + leaf_class: type | None = None) -> Any: r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s. For a non-recursive version see :func:`multimap_array_container`. @@ -354,10 +357,9 @@ def rec_multimap_array_container( def multimapped_over_array_containers( - f: Optional[Callable[..., Any]] = None, - leaf_class: Optional[type] = None) -> Union[ - Callable[..., Any], - Callable[[Callable[..., Any]], Callable[..., Any]]]: + f: Callable[..., Any] | None = None, + leaf_class: type | None = None + ) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: """Decorator around :func:`rec_multimap_array_container`.""" def decorator(g: Callable[..., Any]) -> Callable[..., Any]: # can't use functools.partial, because its result is insufficiently @@ -404,7 +406,7 @@ def keyed_map_array_container( def rec_keyed_map_array_container( - f: Callable[[Tuple[SerializationKey, ...], ArrayT], ArrayT], + f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT], ary: ArrayOrContainer) -> ArrayOrContainer: """ Works similarly to :func:`rec_map_array_container`, except that *f* also @@ -413,7 +415,7 @@ def rec_keyed_map_array_container( the current array. """ - def rec(keys: Tuple[SerializationKey, ...], + def rec(keys: tuple[SerializationKey, ...], _ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) @@ -470,7 +472,7 @@ def multimap_reduce_array_container( # NOTE: this wrapper matches the signature of `deserialize_container` # to make plugging into `_multimap_array_container_impl` easier def _reduce_wrapper( - ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]] + ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]] ) -> Array: return reduce_func([subary for _, subary in iterable]) @@ -483,7 +485,7 @@ def rec_map_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[[Any], Any], ary: ArrayOrContainer, - leaf_class: Optional[type] = None) -> ArrayOrContainer: + leaf_class: type | None = None) -> ArrayOrContainer: """Perform a map-reduce over array containers recursively. :param reduce_func: callable used to reduce over the components of *ary* @@ -541,7 +543,7 @@ def rec_multimap_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[..., Any], *args: Any, - leaf_class: Optional[type] = None) -> ArrayOrContainer: + leaf_class: type | None = None) -> ArrayOrContainer: r"""Perform a map-reduce over multiple array containers recursively. :param reduce_func: callable used to reduce over the components of any @@ -560,7 +562,7 @@ def rec_multimap_reduce_array_container( # NOTE: this wrapper matches the signature of `deserialize_container` # to make plugging into `_multimap_array_container_impl` easier def _reduce_wrapper( - ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]]) -> Any: + ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]]) -> Any: return reduce_func([subary for _, subary in iterable]) return _multimap_array_container_impl( @@ -574,7 +576,7 @@ def _reduce_wrapper( def freeze( ary: ArrayOrContainerT, - actx: Optional[ArrayContext] = None) -> ArrayOrContainerT: + actx: ArrayContext | None = None) -> ArrayOrContainerT: r"""Freezes recursively by going through all components of the :class:`ArrayContainer` *ary*. @@ -651,7 +653,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: @singledispatch def with_array_context(ary: ArrayOrContainerT, - actx: Optional[ArrayContext]) -> ArrayOrContainerT: + actx: ArrayContext | None) -> ArrayOrContainerT: """ Recursively associates *actx* to all the components of *ary*. @@ -675,7 +677,7 @@ def with_array_context(ary: ArrayOrContainerT, def flatten( ary: ArrayOrContainer, actx: ArrayContext, *, - leaf_class: Optional[type] = None, + leaf_class: type | None = None, ) -> Any: """Convert all arrays in the :class:`~arraycontext.ArrayContainer` into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`. @@ -697,7 +699,7 @@ def flatten( """ common_dtype = None - def _flatten(subary: ArrayOrContainer) -> List[Array]: + def _flatten(subary: ArrayOrContainer) -> list[Array]: nonlocal common_dtype try: @@ -875,7 +877,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer: def flat_size_and_dtype( - ary: ArrayOrContainer) -> Tuple[int, Optional[np.dtype[Any]]]: + ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]: """ :returns: a tuple ``(size, dtype)`` that would be the length and :class:`numpy.dtype` of the one-dimensional array returned by @@ -911,7 +913,7 @@ def _flat_size(subary: ArrayOrContainer) -> int: # {{{ numpy conversion def from_numpy( - ary: Union[np.ndarray, ScalarLike], + ary: np.ndarray | ScalarLike, actx: ArrayContext) -> ArrayOrContainerOrScalar: """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` to the base array type of :class:`~arraycontext.ArrayContext`. diff --git a/arraycontext/context.py b/arraycontext/context.py index ee989ef5..28e96802 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -159,15 +159,11 @@ """ from abc import ABC, abstractmethod +from collections.abc import Callable, Mapping from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Mapping, - Optional, Protocol, - Tuple, TypeVar, Union, ) @@ -187,7 +183,7 @@ # {{{ typing -ScalarLike = Union[int, float, complex, np.generic] +ScalarLike = int | float | complex | np.generic SelfType = TypeVar("SelfType") @@ -206,7 +202,7 @@ class Array(Protocol): """ @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: ... @property @@ -291,7 +287,7 @@ class ArrayContext(ABC): .. automethod:: compile """ - array_types: Tuple[type, ...] = () + array_types: tuple[type, ...] = () def __init__(self) -> None: self.np = self._get_fake_numpy_namespace() @@ -304,7 +300,7 @@ def __hash__(self) -> int: raise TypeError(f"unhashable type: '{type(self).__name__}'") def zeros(self, - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], dtype: "np.dtype[Any]") -> Array: warn(f"{type(self).__name__}.zeros is deprecated and will stop " "working in 2025. Use actx.np.zeros instead.", @@ -340,7 +336,7 @@ def to_numpy(self, @abstractmethod def call_loopy(self, t_unit: "loopy.TranslationUnit", - **kwargs: Any) -> Dict[str, Array]: + **kwargs: Any) -> dict[str, Array]: """Execute the :mod:`loopy` program *program* on the arguments *kwargs*. @@ -423,7 +419,7 @@ def tag_axis(self, @memoize_method def _get_einsum_prg(self, - spec: str, arg_names: Tuple[str, ...], + spec: str, arg_names: tuple[str, ...], tagged: ToTagSetConvertible) -> "loopy.TranslationUnit": import loopy as lp from loopy.version import MOST_RECENT_LANGUAGE_VERSION @@ -454,7 +450,7 @@ def _get_einsum_prg(self, # [1] https://github.com/inducer/meshmode/issues/177 def einsum(self, spec: str, *args: Array, - arg_names: Optional[Tuple[str, ...]] = None, + arg_names: tuple[str, ...] | None = None, tagged: ToTagSetConvertible = ()) -> Array: """Computes the result of Einstein summation following the convention in :func:`numpy.einsum`. diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 26cb9db5..0b6cd727 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -27,7 +27,7 @@ THE SOFTWARE. """ -from typing import Callable, Optional, Tuple +from collections.abc import Callable import numpy as np @@ -63,8 +63,8 @@ def _get_fake_numpy_namespace(self): def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, - allowed_types: Optional[Tuple[type, ...]] = None, *, - default_scalar: Optional[ScalarLike] = None, + allowed_types: tuple[type, ...] | None = None, *, + default_scalar: ScalarLike | None = None, strict: bool = False) -> ArrayOrContainer: if allowed_types is None: allowed_types = self.array_types diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 60c001a3..84d5f483 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -31,7 +31,8 @@ THE SOFTWARE. """ -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple +from collections.abc import Callable +from typing import TYPE_CHECKING from warnings import warn import numpy as np @@ -82,9 +83,9 @@ class PyOpenCLArrayContext(ArrayContext): def __init__(self, queue: pyopencl.CommandQueue, - allocator: Optional[pyopencl.tools.AllocatorBase] = None, - wait_event_queue_length: Optional[int] = None, - force_device_scalars: Optional[bool] = None) -> None: + allocator: pyopencl.tools.AllocatorBase | None = None, + wait_event_queue_length: int | None = None, + force_device_scalars: bool | None = None) -> None: r""" :arg wait_event_queue_length: The length of a queue of :class:`~pyopencl.Event` objects that are maintained by the @@ -132,7 +133,7 @@ def __init__(self, self._passed_force_device_scalars = force_device_scalars is not None self._wait_event_queue_length = wait_event_queue_length - self._kernel_name_to_wait_event_queue: Dict[str, List[cl.Event]] = {} + self._kernel_name_to_wait_event_queue: dict[str, list[cl.Event]] = {} if queue.device.type & cl.device_type.GPU: if allocator is None: @@ -150,7 +151,7 @@ def __init__(self, stacklevel=2) self._loopy_transform_cache: \ - Dict[lp.TranslationUnit, lp.TranslationUnit] = {} + dict[lp.TranslationUnit, lp.TranslationUnit] = {} # TODO: Ideally this should only be `(TaggableCLArray,)`, but # that would break the logic in the downstream users. @@ -162,8 +163,8 @@ def _get_fake_numpy_namespace(self): def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, - allowed_types: Optional[Tuple[type, ...]] = None, *, - default_scalar: Optional[ScalarLike] = None, + allowed_types: tuple[type, ...] | None = None, *, + default_scalar: ScalarLike | None = None, strict: bool = False) -> ArrayOrContainer: import arraycontext.impl.pyopencl.taggable_cl_array as tga diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 5c721537..ae340ca9 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -346,7 +346,7 @@ def absolute(self, a): def where(self, criterion, then, else_): def where_inner(inner_crit, inner_then, inner_else): - if isinstance(inner_crit, (bool, np.bool_)): + if isinstance(inner_crit, bool | np.bool_): return inner_then if inner_crit else inner_else return cl_array.if_positive(inner_crit != 0, inner_then, inner_else, queue=self._array_context.queue) diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index a0f3ef47..7de76113 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -6,7 +6,7 @@ """ from dataclasses import dataclass -from typing import Any, Dict, FrozenSet, Optional, Tuple +from typing import Any import numpy as np @@ -23,19 +23,19 @@ class Axis(Taggable): Records the tags corresponding to a dimension of :class:`TaggableCLArray`. """ - tags: FrozenSet[Tag] + tags: frozenset[Tag] - def _with_new_tags(self, tags: FrozenSet[Tag]) -> "Axis": + def _with_new_tags(self, tags: frozenset[Tag]) -> "Axis": from dataclasses import replace return replace(self, tags=tags) @memoize -def _construct_untagged_axes(ndim: int) -> Tuple[Axis, ...]: +def _construct_untagged_axes(ndim: int) -> tuple[Axis, ...]: return tuple(Axis(frozenset()) for _ in range(ndim)) -def _unwrap_cl_array(ary: cla.Array) -> Dict[str, Any]: +def _unwrap_cl_array(ary: cla.Array) -> dict[str, Any]: return { "shape": ary.shape, "dtype": ary.dtype, @@ -109,7 +109,7 @@ def copy(self, queue=cla._copy_queue): return type(self)(None, tags=self.tags, axes=self.axes, **_unwrap_cl_array(ary)) - def _with_new_tags(self, tags: FrozenSet[Tag]) -> "TaggableCLArray": + def _with_new_tags(self, tags: frozenset[Tag]) -> "TaggableCLArray": return type(self)(None, tags=tags, axes=self.axes, **_unwrap_cl_array(self)) @@ -127,8 +127,8 @@ def with_tagged_axis(self, iaxis: int, def to_tagged_cl_array(ary: cla.Array, - axes: Optional[Tuple[Axis, ...]] = None, - tags: FrozenSet[Tag] = frozenset()) -> TaggableCLArray: + axes: tuple[Axis, ...] | None = None, + tags: frozenset[Tag] = frozenset()) -> TaggableCLArray: """ Returns a :class:`TaggableCLArray` that is constructed from the data in *ary* along with the metadata from *axes* and *tags*. If *ary* is already a @@ -167,8 +167,8 @@ def to_tagged_cl_array(ary: cla.Array, # {{{ creation def empty(queue, shape, dtype=float, *, - axes: Optional[Tuple[Axis, ...]] = None, - tags: FrozenSet[Tag] = frozenset(), + axes: tuple[Axis, ...] | None = None, + tags: frozenset[Tag] = frozenset(), order: str = "C", allocator=None) -> TaggableCLArray: if dtype is not None: @@ -181,8 +181,8 @@ def empty(queue, shape, dtype=float, *, def zeros(queue, shape, dtype=float, *, - axes: Optional[Tuple[Axis, ...]] = None, - tags: FrozenSet[Tag] = frozenset(), + axes: tuple[Axis, ...] | None = None, + tags: frozenset[Tag] = frozenset(), order: str = "C", allocator=None) -> TaggableCLArray: result = empty( @@ -194,8 +194,8 @@ def zeros(queue, shape, dtype=float, *, def to_device(queue, ary, *, - axes: Optional[Tuple[Axis, ...]] = None, - tags: FrozenSet[Tag] = frozenset(), + axes: tuple[Axis, ...] | None = None, + tags: frozenset[Tag] = frozenset(), allocator=None): return to_tagged_cl_array( cla.to_device(queue, ary, allocator=allocator), diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index a3518e53..4c84b6d1 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -47,16 +47,10 @@ import abc import sys +from collections.abc import Callable from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - FrozenSet, - Optional, - Tuple, - Type, - Union, ) import numpy as np @@ -91,7 +85,7 @@ # {{{ tag conversion -def _preprocess_array_tags(tags: ToTagSetConvertible) -> FrozenSet[Tag]: +def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]: tags = normalize_tags(tags) name_hints = [tag for tag in tags if isinstance(tag, NameHint)] @@ -135,7 +129,7 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC): def __init__( self, *, - compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None + compile_trace_callback: Callable[[Any, str, Any], None] | None = None ) -> None: """ :arg compile_trace_callback: A function of three arguments @@ -148,10 +142,10 @@ def __init__( super().__init__() import pytato as pt - self._freeze_prg_cache: Dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {} - self._dag_transform_cache: Dict[ + self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {} + self._dag_transform_cache: dict[ pt.DictOfNamedArrays, - Tuple[pt.DictOfNamedArrays, str]] = {} + tuple[pt.DictOfNamedArrays, str]] = {} if compile_trace_callback is None: def _compile_trace_callback(what, stage, ir): @@ -166,7 +160,7 @@ def _get_fake_numpy_namespace(self): return PytatoFakeNumpyNamespace(self) @abc.abstractproperty - def _frozen_array_types(self) -> Tuple[Type, ...]: + def _frozen_array_types(self) -> tuple[type, ...]: """ Returns valid frozen array types for the array context. """ @@ -256,11 +250,11 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): """ def __init__( self, queue: cl.CommandQueue, allocator=None, *, - use_memory_pool: Optional[bool] = None, - compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None, + use_memory_pool: bool | None = None, + compile_trace_callback: Callable[[Any, str, Any], None] | None = None, # do not use: only for testing - _force_svm_arg_limit: Optional[int] = None, + _force_svm_arg_limit: int | None = None, ) -> None: """ :arg compile_trace_callback: A function of three arguments @@ -322,14 +316,14 @@ def __init__( self._force_svm_arg_limit = _force_svm_arg_limit @property - def _frozen_array_types(self) -> Tuple[Type, ...]: + def _frozen_array_types(self) -> tuple[type, ...]: import pyopencl.array as cla return (cla.Array,) def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, - allowed_types: Optional[Tuple[type, ...]] = None, *, - default_scalar: Optional[ScalarLike] = None, + allowed_types: tuple[type, ...] | None = None, *, + default_scalar: ScalarLike | None = None, strict: bool = False) -> ArrayOrContainer: import pytato as pt @@ -452,13 +446,13 @@ def freeze(self, array): get_cl_axes_from_pt_axes, ) - array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray, pt.Array]] = {} - key_to_frozen_subary: Dict[str, TaggableCLArray] = {} - key_to_pt_arrays: Dict[str, pt.Array] = {} + array_as_dict: dict[str, cla.Array | TaggableCLArray | pt.Array] = {} + key_to_frozen_subary: dict[str, TaggableCLArray] = {} + key_to_pt_arrays: dict[str, pt.Array] = {} def _record_leaf_ary_in_dict( - key: Tuple[Any, ...], - ary: Union[cla.Array, TaggableCLArray, pt.Array]) -> None: + key: tuple[Any, ...], + ary: cla.Array | TaggableCLArray | pt.Array) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary @@ -498,7 +492,7 @@ def _record_leaf_ary_in_dict( # }}} - def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: + def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] @@ -706,7 +700,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): """ def __init__(self, - *, compile_trace_callback: Optional[Callable[[Any, str, Any], None]] + *, compile_trace_callback: Callable[[Any, str, Any], None] | None = None) -> None: """ :arg compile_trace_callback: A function of three arguments @@ -722,14 +716,14 @@ def __init__(self, self.array_types = (pt.Array, jnp.ndarray) @property - def _frozen_array_types(self) -> Tuple[Type, ...]: + def _frozen_array_types(self) -> tuple[type, ...]: import jax.numpy as jnp return (jnp.ndarray, ) def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, - allowed_types: Optional[Tuple[type, ...]] = None, *, - default_scalar: Optional[ScalarLike] = None, + allowed_types: tuple[type, ...] | None = None, *, + default_scalar: ScalarLike | None = None, strict: bool = False) -> ArrayOrContainer: if allowed_types is None: allowed_types = self.array_types @@ -783,12 +777,12 @@ def freeze(self, array): from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.compile import _ary_container_key_stringifier - array_as_dict: Dict[str, Union[jnp.ndarray, pt.Array]] = {} - key_to_frozen_subary: Dict[str, jnp.ndarray] = {} - key_to_pt_arrays: Dict[str, pt.Array] = {} + array_as_dict: dict[str, jnp.ndarray | pt.Array] = {} + key_to_frozen_subary: dict[str, jnp.ndarray] = {} + key_to_pt_arrays: dict[str, pt.Array] = {} - def _record_leaf_ary_in_dict(key: Tuple[Any, ...], - ary: Union[jnp.ndarray, pt.Array]) -> None: + def _record_leaf_ary_in_dict(key: tuple[Any, ...], + ary: jnp.ndarray | pt.Array) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary @@ -812,7 +806,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], # }}} - def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray: + def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 54d2cbb8..49fd3fc6 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -32,8 +32,9 @@ import abc import itertools import logging +from collections.abc import Callable, Mapping from dataclasses import dataclass, field -from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Type +from typing import Any import numpy as np from immutabledict import immutabledict @@ -106,7 +107,7 @@ class LeafArrayDescriptor(AbstractInputDescriptor): # {{{ utilities -def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: +def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str: """ Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an array-container's component's key. Goals of this routine: @@ -116,7 +117,7 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: * (informal) Shorter identifiers are preferred """ def _rec_str(key: Any) -> str: - if isinstance(key, (str, int)): + if isinstance(key, str | int): return str(key) elif isinstance(key, tuple): # t in '_actx_t': stands for tuple @@ -128,11 +129,11 @@ def _rec_str(key: Any) -> str: return "_".join(_rec_str(key) for key in keys) -def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], +def _get_arg_id_to_arg_and_arg_id_to_descr(args: tuple[Any, ...], kwargs: Mapping[str, Any] ) -> \ - Tuple[Mapping[Tuple[Any, ...], Any], - Mapping[Tuple[Any, ...], AbstractInputDescriptor]]: + tuple[Mapping[tuple[Any, ...], Any], + Mapping[tuple[Any, ...], AbstractInputDescriptor]]: """ Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts mappings from argument id to argument values and from argument id to @@ -140,8 +141,8 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], :attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's representation. """ - arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {} - arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {} + arg_id_to_arg: dict[tuple[Any, ...], Any] = {} + arg_id_to_descr: dict[tuple[Any, ...], AbstractInputDescriptor] = {} for kw, arg in itertools.chain(enumerate(args), kwargs.items()): @@ -259,7 +260,7 @@ class BaseLazilyCompilingFunctionCaller: actx: _BasePytatoArrayContext f: Callable[..., Any] - program_cache: Dict[Mapping[Tuple[Any, ...], AbstractInputDescriptor], + program_cache: dict[Mapping[tuple[Any, ...], AbstractInputDescriptor], "CompiledFunction"] = field(default_factory=lambda: {}) # {{{ abstract interface @@ -269,11 +270,11 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): @property def compiled_function_returning_array_container_class( - self) -> Type["CompiledFunction"]: + self) -> type["CompiledFunction"]: raise NotImplementedError @property - def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: raise NotImplementedError # }}} @@ -382,11 +383,11 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property def compiled_function_returning_array_container_class( - self) -> Type["CompiledFunction"]: + self) -> type["CompiledFunction"]: return CompiledPyOpenCLFunctionReturningArrayContainer @property - def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: return CompiledPyOpenCLFunctionReturningArray def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): @@ -481,11 +482,11 @@ def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays): class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property def compiled_function_returning_array_container_class( - self) -> Type["CompiledFunction"]: + self) -> type["CompiledFunction"]: return CompiledJAXFunctionReturningArrayContainer @property - def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: return CompiledJAXFunctionReturningArray def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): @@ -627,10 +628,10 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): """ actx: PytatoPyOpenCLArrayContext pytato_program: pt.target.BoundProgram - input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - name_in_program_to_tags: Mapping[str, FrozenSet[Tag]] - name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]] + input_id_to_name_in_program: Mapping[tuple[Any, ...], str] + output_id_to_name_in_program: Mapping[tuple[Any, ...], str] + name_in_program_to_tags: Mapping[str, frozenset[Tag]] + name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]] output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: @@ -670,9 +671,9 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): """ actx: PytatoPyOpenCLArrayContext pytato_program: pt.target.BoundProgram - input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - output_tags: FrozenSet[Tag] - output_axes: Tuple[pt.Axis, ...] + input_id_to_name_in_program: Mapping[tuple[Any, ...], str] + output_tags: frozenset[Tag] + output_axes: tuple[pt.Axis, ...] output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: @@ -719,10 +720,10 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): """ actx: PytatoJAXArrayContext pytato_program: pt.target.BoundProgram - input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - name_in_program_to_tags: Mapping[str, FrozenSet[Tag]] - name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]] + input_id_to_name_in_program: Mapping[tuple[Any, ...], str] + output_id_to_name_in_program: Mapping[tuple[Any, ...], str] + name_in_program_to_tags: Mapping[str, frozenset[Tag]] + name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]] output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: @@ -750,9 +751,9 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): """ actx: PytatoJAXArrayContext pytato_program: pt.target.BoundProgram - input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - output_tags: FrozenSet[Tag] - output_axes: Tuple[pt.Axis, ...] + input_id_to_name_in_program: Mapping[tuple[Any, ...], str] + output_tags: frozenset[Tag] + output_axes: tuple[pt.Axis, ...] output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index a5582d18..0b4383c0 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -23,7 +23,8 @@ """ -from typing import TYPE_CHECKING, Any, Dict, Mapping, Set, Tuple +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any from pytato.array import ( AbstractResultWithNamedArrays, @@ -54,9 +55,9 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): """ def __init__(self) -> None: super().__init__() - self.bound_arguments: Dict[str, Any] = {} + self.bound_arguments: dict[str, Any] = {} self.vng = UniqueNameGenerator() - self.seen_inputs: Set[str] = set() + self.seen_inputs: set[str] = set() def map_data_wrapper(self, expr: DataWrapper) -> Array: if expr.name is not None: @@ -87,7 +88,7 @@ def map_placeholder(self, expr: Placeholder) -> Array: def _normalize_pt_expr( expr: DictOfNamedArrays - ) -> Tuple[AbstractResultWithNamedArrays, Mapping[str, Any]]: + ) -> tuple[AbstractResultWithNamedArrays, Mapping[str, Any]]: """ Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a normalized form of *expr*, with all instances of @@ -102,11 +103,11 @@ def _normalize_pt_expr( return normalized_expr, normalize_mapper.bound_arguments -def get_pt_axes_from_cl_axes(axes: Tuple[ClAxis, ...]) -> Tuple[PtAxis, ...]: +def get_pt_axes_from_cl_axes(axes: tuple[ClAxis, ...]) -> tuple[PtAxis, ...]: return tuple(PtAxis(axis.tags) for axis in axes) -def get_cl_axes_from_pt_axes(axes: Tuple[PtAxis, ...]) -> Tuple[ClAxis, ...]: +def get_cl_axes_from_pt_axes(axes: tuple[PtAxis, ...]) -> tuple[ClAxis, ...]: return tuple(ClAxis(axis.tags) for axis in axes) diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index cdd4f565..da717846 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -27,7 +27,8 @@ THE SOFTWARE. """ -from typing import ClassVar, Mapping +from collections.abc import Mapping +from typing import ClassVar import numpy as np diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index c778154d..44dc862e 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -31,7 +31,8 @@ THE SOFTWARE. """ -from typing import Any, Callable, Dict, Sequence, Type, Union +from collections.abc import Callable, Sequence +from typing import Any from arraycontext import NumpyArrayContext from arraycontext.context import ArrayContext @@ -245,7 +246,7 @@ def __str__(self): _ARRAY_CONTEXT_FACTORY_REGISTRY: \ - Dict[str, Type[PytestArrayContextFactory]] = { + dict[str, type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, @@ -256,7 +257,7 @@ def __str__(self): def register_pytest_array_context_factory( name: str, - factory: Type[PytestArrayContextFactory]) -> None: + factory: type[PytestArrayContextFactory]) -> None: if name in _ARRAY_CONTEXT_FACTORY_REGISTRY: raise ValueError(f"factory '{name}' already exists") @@ -268,7 +269,7 @@ def register_pytest_array_context_factory( # {{{ pytest integration def pytest_generate_tests_for_array_contexts( - factories: Sequence[Union[str, Type[PytestArrayContextFactory]]], *, + factories: Sequence[str | type[PytestArrayContextFactory]], *, factory_arg_name: str = "actx_factory", ) -> Callable[[Any], None]: """Parametrize tests for pytest to use an :class:`~arraycontext.ArrayContext`. diff --git a/arraycontext/version.py b/arraycontext/version.py index 31baea05..05fe8763 100644 --- a/arraycontext/version.py +++ b/arraycontext/version.py @@ -1,8 +1,7 @@ from importlib import metadata -from typing import Tuple -def _parse_version(version: str) -> Tuple[Tuple[int, ...], str]: +def _parse_version(version: str) -> tuple[tuple[int, ...], str]: import re m = re.match("^([0-9.]+)([a-z0-9]*?)$", VERSION_TEXT) diff --git a/pyproject.toml b/pyproject.toml index 7244418a..48991910 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,8 +97,6 @@ extend-ignore = [ "E221", # multiple spaces before operator "E226", # missing whitespace around arithmetic operator "E402", # module-level import not at top of file - "UP006", # updated annotations due to __future__ import - "UP007", # updated annotations due to __future__ import ] [tool.ruff.lint.flake8-quotes] diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index dfef7339..311ba1d5 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -23,7 +23,6 @@ import logging from dataclasses import dataclass from functools import partial -from typing import Union import numpy as np import pytest @@ -216,9 +215,9 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: @dataclass(frozen=True) class MyContainer: name: str - mass: Union[DOFArray, np.ndarray] + mass: DOFArray | np.ndarray momentum: np.ndarray - enthalpy: Union[DOFArray, np.ndarray] + enthalpy: DOFArray | np.ndarray __array_ufunc__ = None @@ -241,9 +240,9 @@ def array_context(self): @dataclass(frozen=True) class MyContainerDOFBcast: name: str - mass: Union[DOFArray, np.ndarray] + mass: DOFArray | np.ndarray momentum: np.ndarray - enthalpy: Union[DOFArray, np.ndarray] + enthalpy: DOFArray | np.ndarray @property def array_context(self): @@ -255,7 +254,7 @@ def array_context(self): def _get_test_containers(actx, ambient_dim=2, shapes=50_000): from numbers import Number - if isinstance(shapes, (Number, tuple)): + if isinstance(shapes, Number | tuple): shapes = [shapes] x = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.float64)) @@ -1074,7 +1073,7 @@ def test_flatten_array_container(actx_factory, shapes): # {{{ complex to real - if isinstance(shapes, (int, tuple)): + if isinstance(shapes, int | tuple): shapes = [shapes] ary = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.float64)) @@ -1558,7 +1557,7 @@ def test_to_numpy_on_frozen_arrays(actx_factory): def test_tagging(actx_factory): actx = actx_factory() - if isinstance(actx, (NumpyArrayContext, EagerJAXArrayContext)): + if isinstance(actx, NumpyArrayContext | EagerJAXArrayContext): pytest.skip(f"{type(actx)} has no tagging support") from pytools.tag import Tag diff --git a/test/test_utils.py b/test/test_utils.py index 04817d6a..7ee5ad30 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -49,7 +49,6 @@ def test_pt_actx_key_stringification_uniqueness(): def test_dataclass_array_container() -> None: from dataclasses import dataclass, field - from typing import Optional from arraycontext import dataclass_array_container @@ -71,7 +70,7 @@ class ArrayContainerWithStringTypes: @dataclass class ArrayContainerWithOptional: x: np.ndarray - y: Optional[np.ndarray] + y: np.ndarray | None with pytest.raises(TypeError): # NOTE: cannot have wrapped annotations (here by `Optional`) @@ -113,7 +112,6 @@ class ArrayContainerWithArray: def test_dataclass_container_unions() -> None: from dataclasses import dataclass - from typing import Union from arraycontext import Array, dataclass_array_container @@ -122,7 +120,7 @@ def test_dataclass_container_unions() -> None: @dataclass class ArrayContainerWithUnion: x: np.ndarray - y: Union[np.ndarray, Array] + y: np.ndarray | Array dataclass_array_container(ArrayContainerWithUnion) @@ -133,7 +131,7 @@ class ArrayContainerWithUnion: @dataclass class ArrayContainerWithWrongUnion: x: np.ndarray - y: Union[np.ndarray, float] + y: np.ndarray | float with pytest.raises(TypeError): # NOTE: float is not an ArrayContainer, so y should fail