diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6e3eab36..cd011830 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -48,11 +48,6 @@ jobs: curl -L -O https://tiker.net/ci-support-v0 . ./ci-support-v0 - # NOTE: jax>=0.4.31 requires python 3.10 and uses pattern matching - # which conflicts with our mypy.python_version = '3.8' setting - CONDA_ENVIRONMENT=.test-conda-env-py3.yml - sed -i "s/jax/jax<0.4.31/" "$CONDA_ENVIRONMENT" - build_py_project_in_conda_env python -m pip install mypy pytest ./run-mypy.sh diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index fa09197f..f7cf75ff 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -107,9 +107,7 @@ Pylint: Mypy: script: | - # NOTE: jax>=0.4.31 requires python 3.10 and uses pattern matching - # which conflicts with our mypy.python_version = '3.8' setting - EXTRA_INSTALL="mypy pytest jax[cpu]<0.4.31" + EXTRA_INSTALL="mypy pytest" curl -L -O https://tiker.net/ci-support-v0 . ./ci-support-v0 diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index bb18e986..6c4fb671 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -79,23 +79,14 @@ THE SOFTWARE. """ +from collections.abc import Hashable, Sequence from functools import singledispatch -from typing import ( - TYPE_CHECKING, - Any, - Hashable, - Optional, - Protocol, - Sequence, - Tuple, - TypeVar, -) +from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar # For use in singledispatch type annotations, because sphinx can't figure out # what 'np' is. import numpy import numpy as np -from typing_extensions import TypeAlias from arraycontext.context import ArrayContext @@ -162,7 +153,7 @@ class NotAnArrayContainerError(TypeError): SerializationKey: TypeAlias = Hashable -SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]] +SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]] @singledispatch @@ -249,7 +240,7 @@ def is_array_container(ary: Any) -> bool: @singledispatch -def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]: +def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None: """Retrieves the :class:`ArrayContext` from the container, if any. This function is not recursive, so it will only search at the root level @@ -303,7 +294,7 @@ def _deserialize_ndarray_container( # type: ignore[misc] # {{{ get_container_context_recursively def get_container_context_recursively_opt( - ary: ArrayContainer) -> Optional[ArrayContext]: + ary: ArrayContainer) -> ArrayContext | None: """Walks the :class:`ArrayContainer` hierarchy to find an :class:`ArrayContext` associated with it. @@ -337,7 +328,7 @@ def get_container_context_recursively_opt( return actx -def get_container_context_recursively(ary: ArrayContainer) -> Optional[ArrayContext]: +def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None: """Walks the :class:`ArrayContainer` hierarchy to find an :class:`ArrayContext` associated with it. diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 9366b260..22572dc8 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -34,7 +34,8 @@ """ import enum -from typing import Any, Callable, Optional, Tuple, TypeVar, Union +from collections.abc import Callable +from typing import Any, TypeVar from warnings import warn import numpy as np @@ -90,7 +91,7 @@ class _OpClass(enum.Enum): ] -def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: +def _format_unary_op_str(op_str: str, arg1: tuple[str, ...] | str) -> str: if isinstance(arg1, tuple): arg1_entry, arg1_container = arg1 return (f"{op_str.format(arg1_entry)} " @@ -100,20 +101,14 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: def _format_binary_op_str(op_str: str, - arg1: Union[Tuple[str, str], str], - arg2: Union[Tuple[str, str], str]) -> str: + arg1: tuple[str, str] | str, + arg2: tuple[str, str] | str) -> str: if isinstance(arg1, tuple) and isinstance(arg2, tuple): - import sys - if sys.version_info >= (3, 10): - strict_arg = ", strict=__debug__" - else: - strict_arg = "" - arg1_entry, arg1_container = arg1 arg2_entry, arg2_container = arg2 return (f"{op_str.format(arg1_entry, arg2_entry)} " f"for {arg1_entry}, {arg2_entry} " - f"in zip({arg1_container}, {arg2_container}{strict_arg})") + f"in zip({arg1_container}, {arg2_container}, strict=__debug__)") elif isinstance(arg1, tuple): arg1_entry, arg1_container = arg1 @@ -160,23 +155,23 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet def with_container_arithmetic( *, - number_bcasts_across: Optional[bool] = None, - bcasts_across_obj_array: Optional[bool] = None, - container_types_bcast_across: Optional[Tuple[type, ...]] = None, + number_bcasts_across: bool | None = None, + bcasts_across_obj_array: bool | None = None, + container_types_bcast_across: tuple[type, ...] | None = None, arithmetic: bool = True, matmul: bool = False, bitwise: bool = False, shift: bool = False, - _cls_has_array_context_attr: Optional[bool] = None, - eq_comparison: Optional[bool] = None, - rel_comparison: Optional[bool] = None, + _cls_has_array_context_attr: bool | None = None, + eq_comparison: bool | None = None, + rel_comparison: bool | None = None, # deprecated: - bcast_number: Optional[bool] = None, - bcast_obj_array: Optional[bool] = None, + bcast_number: bool | None = None, + bcast_obj_array: bool | None = None, bcast_numpy_array: bool = False, - _bcast_actx_array_type: Optional[bool] = None, - bcast_container_types: Optional[Tuple[type, ...]] = None, + _bcast_actx_array_type: bool | None = None, + bcast_container_types: tuple[type, ...] | None = None, ) -> Callable[[type], type]: """A class decorator that implements built-in operators for array containers by propagating the operations to the elements of the container. @@ -482,7 +477,7 @@ def same_key(k1: T, k2: T) -> T: assert k1 == k2 return k1 - def tup_str(t: Tuple[str, ...]) -> str: + def tup_str(t: tuple[str, ...]) -> str: if not t: return "()" else: @@ -544,7 +539,8 @@ def {fname}(arg1): _format_binary_op_str(op_str, expr_arg1, expr_arg2) for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip( cls._serialize_init_arrays_code("arg1").items(), - cls._serialize_init_arrays_code("arg2").items()) + cls._serialize_init_arrays_code("arg2").items(), + strict=True) }) bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", { key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 492f0c92..ec4c37f4 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -31,7 +31,7 @@ """ from dataclasses import Field, fields, is_dataclass -from typing import Tuple, Union, get_args, get_origin +from typing import Union, get_args, get_origin from arraycontext.container import is_array_container_type @@ -100,7 +100,7 @@ def is_array_field(f: Field) -> bool: _BaseGenericAlias, _SpecialForm, ) - if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)): + if isinstance(f.type, _BaseGenericAlias | _SpecialForm): # NOTE: anything except a Union is not allowed raise TypeError( f"Typing annotation not supported on field '{f.name}': " @@ -125,8 +125,8 @@ def is_array_field(f: Field) -> bool: def inject_dataclass_serialization( cls: type, - array_fields: Tuple[Field, ...], - non_array_fields: Tuple[Field, ...]) -> type: + array_fields: tuple[Field, ...], + non_array_fields: tuple[Field, ...]) -> type: """Implements :func:`~arraycontext.serialize_container` and :func:`~arraycontext.deserialize_container` for the given dataclass *cls*. diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 5f94ad63..62f6354c 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -70,8 +70,9 @@ THE SOFTWARE. """ +from collections.abc import Callable, Iterable from functools import partial, singledispatch, update_wrapper -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, cast from warnings import warn import numpy as np @@ -100,7 +101,7 @@ def _map_array_container_impl( f: Callable[[ArrayOrContainer], ArrayOrContainer], ary: ArrayOrContainer, *, - leaf_cls: Optional[type] = None, + leaf_cls: type | None = None, recursive: bool = False) -> ArrayOrContainer: """Helper for :func:`rec_map_array_container`. @@ -129,9 +130,9 @@ def rec(_ary: ArrayOrContainer) -> ArrayOrContainer: def _multimap_array_container_impl( f: Callable[..., Any], *args: Any, - reduce_func: Optional[Callable[ - [ArrayContainer, Iterable[Tuple[Any, Any]]], Any]] = None, - leaf_cls: Optional[type] = None, + reduce_func: ( + Callable[[ArrayContainer, Iterable[tuple[Any, Any]]], Any] | None) = None, + leaf_cls: type | None = None, recursive: bool = False) -> ArrayOrContainer: """Helper for :func:`rec_multimap_array_container`. @@ -164,10 +165,11 @@ def rec(*_args: Any) -> Any: for subarys in zip( iterable_template, - *[serialize_container(_args[i]) for i in container_indices[1:]] + *[serialize_container(_args[i]) for i in container_indices[1:]], + strict=True ): key = None - for i, (subkey, subary) in zip(container_indices, subarys): + for i, (subkey, subary) in zip(container_indices, subarys, strict=True): if key is None: key = subkey else: @@ -183,7 +185,7 @@ def rec(*_args: Any) -> Any: # {{{ find all containers in the argument list - container_indices: List[int] = [] + container_indices: list[int] = [] for i, arg in enumerate(args): if type(arg) is leaf_cls: @@ -244,7 +246,7 @@ def stringify_array_container_tree(ary: ArrayOrContainer) -> str: :returns: a string for an ASCII tree representation of the array container, similar to `asciitree `__. """ - def rec(lines: List[str], ary_: ArrayOrContainerT, level: int) -> None: + def rec(lines: list[str], ary_: ArrayOrContainerT, level: int) -> None: try: iterable = serialize_container(ary_) except NotAnArrayContainerError: @@ -307,7 +309,7 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: def rec_map_array_container( f: Callable[[Any], Any], ary: ArrayOrContainer, - leaf_class: Optional[type] = None) -> ArrayOrContainer: + leaf_class: type | None = None) -> ArrayOrContainer: r"""Applies *f* recursively to an :class:`ArrayContainer`. For a non-recursive version see :func:`map_array_container`. @@ -319,12 +321,12 @@ def rec_map_array_container( def mapped_over_array_containers( - f: Optional[Callable[[ArrayOrContainer], ArrayOrContainer]] = None, - leaf_class: Optional[type] = None) -> Union[ - Callable[[ArrayOrContainer], ArrayOrContainer], - Callable[ + f: Callable[[ArrayOrContainer], ArrayOrContainer] | None = None, + leaf_class: type | None = None) -> ( + Callable[[ArrayOrContainer], ArrayOrContainer] + | Callable[ [Callable[[Any], Any]], - Callable[[ArrayOrContainer], ArrayOrContainer]]]: + Callable[[ArrayOrContainer], ArrayOrContainer]]): """Decorator around :func:`rec_map_array_container`.""" def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[ [ArrayOrContainer], ArrayOrContainer]: @@ -340,7 +342,7 @@ def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[ def rec_multimap_array_container( f: Callable[..., Any], *args: Any, - leaf_class: Optional[type] = None) -> Any: + leaf_class: type | None = None) -> Any: r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s. For a non-recursive version see :func:`multimap_array_container`. @@ -353,10 +355,10 @@ def rec_multimap_array_container( def multimapped_over_array_containers( - f: Optional[Callable[..., Any]] = None, - leaf_class: Optional[type] = None) -> Union[ - Callable[..., Any], - Callable[[Callable[..., Any]], Callable[..., Any]]]: + f: Callable[..., Any] | None = None, + leaf_class: type | None = None) -> ( + Callable[..., Any] + | Callable[[Callable[..., Any]], Callable[..., Any]]): """Decorator around :func:`rec_multimap_array_container`.""" def decorator(g: Callable[..., Any]) -> Callable[..., Any]: # can't use functools.partial, because its result is insufficiently @@ -403,7 +405,7 @@ def keyed_map_array_container( def rec_keyed_map_array_container( - f: Callable[[Tuple[SerializationKey, ...], ArrayT], ArrayT], + f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT], ary: ArrayOrContainer) -> ArrayOrContainer: """ Works similarly to :func:`rec_map_array_container`, except that *f* also @@ -412,7 +414,7 @@ def rec_keyed_map_array_container( the current array. """ - def rec(keys: Tuple[SerializationKey, ...], + def rec(keys: tuple[SerializationKey, ...], _ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) @@ -469,7 +471,7 @@ def multimap_reduce_array_container( # NOTE: this wrapper matches the signature of `deserialize_container` # to make plugging into `_multimap_array_container_impl` easier def _reduce_wrapper( - ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]] + ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]] ) -> Array: return reduce_func([subary for _, subary in iterable]) @@ -482,7 +484,7 @@ def rec_map_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[[Any], Any], ary: ArrayOrContainer, - leaf_class: Optional[type] = None) -> ArrayOrContainer: + leaf_class: type | None = None) -> ArrayOrContainer: """Perform a map-reduce over array containers recursively. :param reduce_func: callable used to reduce over the components of *ary* @@ -540,7 +542,7 @@ def rec_multimap_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[..., Any], *args: Any, - leaf_class: Optional[type] = None) -> ArrayOrContainer: + leaf_class: type | None = None) -> ArrayOrContainer: r"""Perform a map-reduce over multiple array containers recursively. :param reduce_func: callable used to reduce over the components of any @@ -559,7 +561,7 @@ def rec_multimap_reduce_array_container( # NOTE: this wrapper matches the signature of `deserialize_container` # to make plugging into `_multimap_array_container_impl` easier def _reduce_wrapper( - ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]]) -> Any: + ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]]) -> Any: return reduce_func([subary for _, subary in iterable]) return _multimap_array_container_impl( @@ -573,7 +575,7 @@ def _reduce_wrapper( def freeze( ary: ArrayOrContainerT, - actx: Optional[ArrayContext] = None) -> ArrayOrContainerT: + actx: ArrayContext | None = None) -> ArrayOrContainerT: r"""Freezes recursively by going through all components of the :class:`ArrayContainer` *ary*. @@ -650,7 +652,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: @singledispatch def with_array_context(ary: ArrayOrContainerT, - actx: Optional[ArrayContext]) -> ArrayOrContainerT: + actx: ArrayContext | None) -> ArrayOrContainerT: """ Recursively associates *actx* to all the components of *ary*. @@ -674,7 +676,7 @@ def with_array_context(ary: ArrayOrContainerT, def flatten( ary: ArrayOrContainer, actx: ArrayContext, *, - leaf_class: Optional[type] = None, + leaf_class: type | None = None, ) -> Any: """Convert all arrays in the :class:`~arraycontext.ArrayContainer` into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`. @@ -696,7 +698,7 @@ def flatten( """ common_dtype = None - def _flatten(subary: ArrayOrContainer) -> List[Array]: + def _flatten(subary: ArrayOrContainer) -> list[Array]: nonlocal common_dtype try: @@ -874,7 +876,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer: def flat_size_and_dtype( - ary: ArrayOrContainer) -> Tuple[int, Optional[np.dtype[Any]]]: + ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]: """ :returns: a tuple ``(size, dtype)`` that would be the length and :class:`numpy.dtype` of the one-dimensional array returned by @@ -910,7 +912,7 @@ def _flat_size(subary: ArrayOrContainer) -> int: # {{{ numpy conversion def from_numpy( - ary: Union[np.ndarray, ScalarLike], + ary: np.ndarray | ScalarLike, actx: ArrayContext) -> ArrayOrContainerOrScalar: """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` to the base array type of :class:`~arraycontext.ArrayContext`. diff --git a/arraycontext/context.py b/arraycontext/context.py index ee989ef5..398f8aa3 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -159,18 +159,8 @@ """ from abc import ABC, abstractmethod -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Mapping, - Optional, - Protocol, - Tuple, - TypeVar, - Union, -) +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, Union from warnings import warn import numpy as np @@ -187,7 +177,7 @@ # {{{ typing -ScalarLike = Union[int, float, complex, np.generic] +ScalarLike = int | float | complex | np.generic SelfType = TypeVar("SelfType") @@ -206,7 +196,7 @@ class Array(Protocol): """ @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: ... @property @@ -291,7 +281,7 @@ class ArrayContext(ABC): .. automethod:: compile """ - array_types: Tuple[type, ...] = () + array_types: tuple[type, ...] = () def __init__(self) -> None: self.np = self._get_fake_numpy_namespace() @@ -304,7 +294,7 @@ def __hash__(self) -> int: raise TypeError(f"unhashable type: '{type(self).__name__}'") def zeros(self, - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], dtype: "np.dtype[Any]") -> Array: warn(f"{type(self).__name__}.zeros is deprecated and will stop " "working in 2025. Use actx.np.zeros instead.", @@ -340,7 +330,7 @@ def to_numpy(self, @abstractmethod def call_loopy(self, t_unit: "loopy.TranslationUnit", - **kwargs: Any) -> Dict[str, Array]: + **kwargs: Any) -> dict[str, Array]: """Execute the :mod:`loopy` program *program* on the arguments *kwargs*. @@ -423,7 +413,7 @@ def tag_axis(self, @memoize_method def _get_einsum_prg(self, - spec: str, arg_names: Tuple[str, ...], + spec: str, arg_names: tuple[str, ...], tagged: ToTagSetConvertible) -> "loopy.TranslationUnit": import loopy as lp from loopy.version import MOST_RECENT_LANGUAGE_VERSION @@ -454,7 +444,7 @@ def _get_einsum_prg(self, # [1] https://github.com/inducer/meshmode/issues/177 def einsum(self, spec: str, *args: Array, - arg_names: Optional[Tuple[str, ...]] = None, + arg_names: tuple[str, ...] | None = None, tagged: ToTagSetConvertible = ()) -> Array: """Computes the result of Einstein summation following the convention in :func:`numpy.einsum`. diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 26cb9db5..0b6cd727 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -27,7 +27,7 @@ THE SOFTWARE. """ -from typing import Callable, Optional, Tuple +from collections.abc import Callable import numpy as np @@ -63,8 +63,8 @@ def _get_fake_numpy_namespace(self): def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, - allowed_types: Optional[Tuple[type, ...]] = None, *, - default_scalar: Optional[ScalarLike] = None, + allowed_types: tuple[type, ...] | None = None, *, + default_scalar: ScalarLike | None = None, strict: bool = False) -> ArrayOrContainer: if allowed_types is None: allowed_types = self.array_types diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index bc9481e3..094e8cf2 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -187,7 +187,7 @@ def rec_equal(x, y): [(true_ary if kx_i == ky_i else false_ary) and rec_equal(x_i, y_i) for (kx_i, x_i), (ky_i, y_i) - in zip(serialized_x, serialized_y)], + in zip(serialized_x, serialized_y, strict=True)], true_ary) return rec_equal(a, b) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index b305717e..f345edc9 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -21,7 +21,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from functools import partial, reduce +from typing import cast import numpy as np @@ -143,13 +145,12 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: else: if len(serialized_x) != len(serialized_y): return false_ary - return reduce( - np.logical_and, + return np.logical_and.reduce( [(true_ary if kx_i == ky_i else false_ary) - and self.array_equal(x_i, y_i) + and cast(np.ndarray, self.array_equal(x_i, y_i)) for (kx_i, x_i), (ky_i, y_i) - in zip(serialized_x, serialized_y)], - true_ary) + in zip(serialized_x, serialized_y, strict=True)], + initial=true_ary) def arange(self, *args, **kwargs): return np.arange(*args, **kwargs) diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 60c001a3..84d5f483 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -31,7 +31,8 @@ THE SOFTWARE. """ -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple +from collections.abc import Callable +from typing import TYPE_CHECKING from warnings import warn import numpy as np @@ -82,9 +83,9 @@ class PyOpenCLArrayContext(ArrayContext): def __init__(self, queue: pyopencl.CommandQueue, - allocator: Optional[pyopencl.tools.AllocatorBase] = None, - wait_event_queue_length: Optional[int] = None, - force_device_scalars: Optional[bool] = None) -> None: + allocator: pyopencl.tools.AllocatorBase | None = None, + wait_event_queue_length: int | None = None, + force_device_scalars: bool | None = None) -> None: r""" :arg wait_event_queue_length: The length of a queue of :class:`~pyopencl.Event` objects that are maintained by the @@ -132,7 +133,7 @@ def __init__(self, self._passed_force_device_scalars = force_device_scalars is not None self._wait_event_queue_length = wait_event_queue_length - self._kernel_name_to_wait_event_queue: Dict[str, List[cl.Event]] = {} + self._kernel_name_to_wait_event_queue: dict[str, list[cl.Event]] = {} if queue.device.type & cl.device_type.GPU: if allocator is None: @@ -150,7 +151,7 @@ def __init__(self, stacklevel=2) self._loopy_transform_cache: \ - Dict[lp.TranslationUnit, lp.TranslationUnit] = {} + dict[lp.TranslationUnit, lp.TranslationUnit] = {} # TODO: Ideally this should only be `(TaggableCLArray,)`, but # that would break the logic in the downstream users. @@ -162,8 +163,8 @@ def _get_fake_numpy_namespace(self): def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, - allowed_types: Optional[Tuple[type, ...]] = None, *, - default_scalar: Optional[ScalarLike] = None, + allowed_types: tuple[type, ...] | None = None, *, + default_scalar: ScalarLike | None = None, strict: bool = False) -> ArrayOrContainer: import arraycontext.impl.pyopencl.taggable_cl_array as tga diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index ac792452..ae340ca9 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -236,7 +236,7 @@ def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array: [(true_ary if kx_i == ky_i else false_ary) and rec_equal(x_i, y_i) for (kx_i, x_i), (ky_i, y_i) - in zip(serialized_x, serialized_y)], + in zip(serialized_x, serialized_y, strict=True)], true_ary) return rec_equal(a, b) @@ -346,7 +346,7 @@ def absolute(self, a): def where(self, criterion, then, else_): def where_inner(inner_crit, inner_then, inner_else): - if isinstance(inner_crit, (bool, np.bool_)): + if isinstance(inner_crit, bool | np.bool_): return inner_then if inner_crit else inner_else return cl_array.if_positive(inner_crit != 0, inner_then, inner_else, queue=self._array_context.queue) diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index a0f3ef47..7de76113 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -6,7 +6,7 @@ """ from dataclasses import dataclass -from typing import Any, Dict, FrozenSet, Optional, Tuple +from typing import Any import numpy as np @@ -23,19 +23,19 @@ class Axis(Taggable): Records the tags corresponding to a dimension of :class:`TaggableCLArray`. """ - tags: FrozenSet[Tag] + tags: frozenset[Tag] - def _with_new_tags(self, tags: FrozenSet[Tag]) -> "Axis": + def _with_new_tags(self, tags: frozenset[Tag]) -> "Axis": from dataclasses import replace return replace(self, tags=tags) @memoize -def _construct_untagged_axes(ndim: int) -> Tuple[Axis, ...]: +def _construct_untagged_axes(ndim: int) -> tuple[Axis, ...]: return tuple(Axis(frozenset()) for _ in range(ndim)) -def _unwrap_cl_array(ary: cla.Array) -> Dict[str, Any]: +def _unwrap_cl_array(ary: cla.Array) -> dict[str, Any]: return { "shape": ary.shape, "dtype": ary.dtype, @@ -109,7 +109,7 @@ def copy(self, queue=cla._copy_queue): return type(self)(None, tags=self.tags, axes=self.axes, **_unwrap_cl_array(ary)) - def _with_new_tags(self, tags: FrozenSet[Tag]) -> "TaggableCLArray": + def _with_new_tags(self, tags: frozenset[Tag]) -> "TaggableCLArray": return type(self)(None, tags=tags, axes=self.axes, **_unwrap_cl_array(self)) @@ -127,8 +127,8 @@ def with_tagged_axis(self, iaxis: int, def to_tagged_cl_array(ary: cla.Array, - axes: Optional[Tuple[Axis, ...]] = None, - tags: FrozenSet[Tag] = frozenset()) -> TaggableCLArray: + axes: tuple[Axis, ...] | None = None, + tags: frozenset[Tag] = frozenset()) -> TaggableCLArray: """ Returns a :class:`TaggableCLArray` that is constructed from the data in *ary* along with the metadata from *axes* and *tags*. If *ary* is already a @@ -167,8 +167,8 @@ def to_tagged_cl_array(ary: cla.Array, # {{{ creation def empty(queue, shape, dtype=float, *, - axes: Optional[Tuple[Axis, ...]] = None, - tags: FrozenSet[Tag] = frozenset(), + axes: tuple[Axis, ...] | None = None, + tags: frozenset[Tag] = frozenset(), order: str = "C", allocator=None) -> TaggableCLArray: if dtype is not None: @@ -181,8 +181,8 @@ def empty(queue, shape, dtype=float, *, def zeros(queue, shape, dtype=float, *, - axes: Optional[Tuple[Axis, ...]] = None, - tags: FrozenSet[Tag] = frozenset(), + axes: tuple[Axis, ...] | None = None, + tags: frozenset[Tag] = frozenset(), order: str = "C", allocator=None) -> TaggableCLArray: result = empty( @@ -194,8 +194,8 @@ def zeros(queue, shape, dtype=float, *, def to_device(queue, ary, *, - axes: Optional[Tuple[Axis, ...]] = None, - tags: FrozenSet[Tag] = frozenset(), + axes: tuple[Axis, ...] | None = None, + tags: frozenset[Tag] = frozenset(), allocator=None): return to_tagged_cl_array( cla.to_device(queue, ary, allocator=allocator), diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 099738a9..1d36971c 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -47,17 +47,8 @@ import abc import sys -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - FrozenSet, - Optional, - Tuple, - Type, - Union, -) +from collections.abc import Callable +from typing import TYPE_CHECKING, Any import numpy as np @@ -91,7 +82,7 @@ # {{{ tag conversion -def _preprocess_array_tags(tags: ToTagSetConvertible) -> FrozenSet[Tag]: +def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]: tags = normalize_tags(tags) name_hints = [tag for tag in tags if isinstance(tag, NameHint)] @@ -135,7 +126,7 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC): def __init__( self, *, - compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None + compile_trace_callback: Callable[[Any, str, Any], None] | None = None ) -> None: """ :arg compile_trace_callback: A function of three arguments @@ -148,10 +139,10 @@ def __init__( super().__init__() import pytato as pt - self._freeze_prg_cache: Dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {} - self._dag_transform_cache: Dict[ + self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {} + self._dag_transform_cache: dict[ pt.DictOfNamedArrays, - Tuple[pt.DictOfNamedArrays, str]] = {} + tuple[pt.DictOfNamedArrays, str]] = {} if compile_trace_callback is None: def _compile_trace_callback(what, stage, ir): @@ -166,7 +157,7 @@ def _get_fake_numpy_namespace(self): return PytatoFakeNumpyNamespace(self) @abc.abstractproperty - def _frozen_array_types(self) -> Tuple[Type, ...]: + def _frozen_array_types(self) -> tuple[type, ...]: """ Returns valid frozen array types for the array context. """ @@ -256,11 +247,11 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): """ def __init__( self, queue: cl.CommandQueue, allocator=None, *, - use_memory_pool: Optional[bool] = None, - compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None, + use_memory_pool: bool | None = None, + compile_trace_callback: Callable[[Any, str, Any], None] | None = None, # do not use: only for testing - _force_svm_arg_limit: Optional[int] = None, + _force_svm_arg_limit: int | None = None, ) -> None: """ :arg compile_trace_callback: A function of three arguments @@ -322,14 +313,14 @@ def __init__( self._force_svm_arg_limit = _force_svm_arg_limit @property - def _frozen_array_types(self) -> Tuple[Type, ...]: + def _frozen_array_types(self) -> tuple[type, ...]: import pyopencl.array as cla return (cla.Array,) def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, - allowed_types: Optional[Tuple[type, ...]] = None, *, - default_scalar: Optional[ScalarLike] = None, + allowed_types: tuple[type, ...] | None = None, *, + default_scalar: ScalarLike | None = None, strict: bool = False) -> ArrayOrContainer: import pytato as pt @@ -452,13 +443,13 @@ def freeze(self, array): get_cl_axes_from_pt_axes, ) - array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray, pt.Array]] = {} - key_to_frozen_subary: Dict[str, TaggableCLArray] = {} - key_to_pt_arrays: Dict[str, pt.Array] = {} + array_as_dict: dict[str, cla.Array | TaggableCLArray | pt.Array] = {} + key_to_frozen_subary: dict[str, TaggableCLArray] = {} + key_to_pt_arrays: dict[str, pt.Array] = {} def _record_leaf_ary_in_dict( - key: Tuple[Any, ...], - ary: Union[cla.Array, TaggableCLArray, pt.Array]) -> None: + key: tuple[Any, ...], + ary: cla.Array | TaggableCLArray | pt.Array) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary @@ -498,7 +489,7 @@ def _record_leaf_ary_in_dict( # }}} - def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: + def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] @@ -685,7 +676,7 @@ def preprocess_arg(name, arg): return pt.einsum(spec, *[ preprocess_arg(name, arg) - for name, arg in zip(arg_names, args) + for name, arg in zip(arg_names, args, strict=True) ]).tagged(_preprocess_array_tags(tagged)) def clone(self): @@ -706,8 +697,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): """ def __init__(self, - *, compile_trace_callback: Optional[Callable[[Any, str, Any], None]] - = None) -> None: + *, + compile_trace_callback: Callable[[Any, str, Any], None] | None = None, + ) -> None: """ :arg compile_trace_callback: A function of three arguments *(what, stage, ir)*, where *what* identifies the object @@ -722,14 +714,14 @@ def __init__(self, self.array_types = (pt.Array, jnp.ndarray) @property - def _frozen_array_types(self) -> Tuple[Type, ...]: + def _frozen_array_types(self) -> tuple[type, ...]: import jax.numpy as jnp return (jnp.ndarray, ) def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, - allowed_types: Optional[Tuple[type, ...]] = None, *, - default_scalar: Optional[ScalarLike] = None, + allowed_types: tuple[type, ...] | None = None, *, + default_scalar: ScalarLike | None = None, strict: bool = False) -> ArrayOrContainer: if allowed_types is None: allowed_types = self.array_types @@ -783,12 +775,12 @@ def freeze(self, array): from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.compile import _ary_container_key_stringifier - array_as_dict: Dict[str, Union[jnp.ndarray, pt.Array]] = {} - key_to_frozen_subary: Dict[str, jnp.ndarray] = {} - key_to_pt_arrays: Dict[str, pt.Array] = {} + array_as_dict: dict[str, jnp.ndarray | pt.Array] = {} + key_to_frozen_subary: dict[str, jnp.ndarray] = {} + key_to_pt_arrays: dict[str, pt.Array] = {} - def _record_leaf_ary_in_dict(key: Tuple[Any, ...], - ary: Union[jnp.ndarray, pt.Array]) -> None: + def _record_leaf_ary_in_dict(key: tuple[Any, ...], + ary: jnp.ndarray | pt.Array) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary @@ -812,7 +804,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], # }}} - def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray: + def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] @@ -913,7 +905,7 @@ def preprocess_arg(name, arg): return pt.einsum(spec, *[ preprocess_arg(name, arg) - for name, arg in zip(arg_names, args) + for name, arg in zip(arg_names, args, strict=True) ]).tagged(_preprocess_array_tags(tagged)) def clone(self): diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 54d2cbb8..952761bf 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -32,8 +32,9 @@ import abc import itertools import logging +from collections.abc import Callable, Hashable, Mapping from dataclasses import dataclass, field -from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Type +from typing import Any import numpy as np from immutabledict import immutabledict @@ -106,7 +107,7 @@ class LeafArrayDescriptor(AbstractInputDescriptor): # {{{ utilities -def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: +def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str: """ Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an array-container's component's key. Goals of this routine: @@ -116,7 +117,7 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: * (informal) Shorter identifiers are preferred """ def _rec_str(key: Any) -> str: - if isinstance(key, (str, int)): + if isinstance(key, str | int): return str(key) elif isinstance(key, tuple): # t in '_actx_t': stands for tuple @@ -128,11 +129,11 @@ def _rec_str(key: Any) -> str: return "_".join(_rec_str(key) for key in keys) -def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], +def _get_arg_id_to_arg_and_arg_id_to_descr(args: tuple[Any, ...], kwargs: Mapping[str, Any] ) -> \ - Tuple[Mapping[Tuple[Any, ...], Any], - Mapping[Tuple[Any, ...], AbstractInputDescriptor]]: + tuple[Mapping[tuple[Hashable, ...], Any], + Mapping[tuple[Hashable, ...], AbstractInputDescriptor]]: """ Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts mappings from argument id to argument values and from argument id to @@ -140,8 +141,8 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], :attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's representation. """ - arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {} - arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {} + arg_id_to_arg: dict[tuple[Hashable, ...], Any] = {} + arg_id_to_descr: dict[tuple[Hashable, ...], AbstractInputDescriptor] = {} for kw, arg in itertools.chain(enumerate(args), kwargs.items()): @@ -259,7 +260,7 @@ class BaseLazilyCompilingFunctionCaller: actx: _BasePytatoArrayContext f: Callable[..., Any] - program_cache: Dict[Mapping[Tuple[Any, ...], AbstractInputDescriptor], + program_cache: dict[Mapping[tuple[Hashable, ...], AbstractInputDescriptor], "CompiledFunction"] = field(default_factory=lambda: {}) # {{{ abstract interface @@ -269,11 +270,11 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): @property def compiled_function_returning_array_container_class( - self) -> Type["CompiledFunction"]: + self) -> type["CompiledFunction"]: raise NotImplementedError @property - def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: raise NotImplementedError # }}} @@ -382,11 +383,11 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property def compiled_function_returning_array_container_class( - self) -> Type["CompiledFunction"]: + self) -> type["CompiledFunction"]: return CompiledPyOpenCLFunctionReturningArrayContainer @property - def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: return CompiledPyOpenCLFunctionReturningArray def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): @@ -481,11 +482,11 @@ def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays): class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property def compiled_function_returning_array_container_class( - self) -> Type["CompiledFunction"]: + self) -> type["CompiledFunction"]: return CompiledJAXFunctionReturningArrayContainer @property - def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: return CompiledJAXFunctionReturningArray def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): @@ -627,10 +628,10 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): """ actx: PytatoPyOpenCLArrayContext pytato_program: pt.target.BoundProgram - input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - name_in_program_to_tags: Mapping[str, FrozenSet[Tag]] - name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]] + input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] + output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] + name_in_program_to_tags: Mapping[str, frozenset[Tag]] + name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]] output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: @@ -670,9 +671,9 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): """ actx: PytatoPyOpenCLArrayContext pytato_program: pt.target.BoundProgram - input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - output_tags: FrozenSet[Tag] - output_axes: Tuple[pt.Axis, ...] + input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] + output_tags: frozenset[Tag] + output_axes: tuple[pt.Axis, ...] output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: @@ -719,10 +720,10 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): """ actx: PytatoJAXArrayContext pytato_program: pt.target.BoundProgram - input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - name_in_program_to_tags: Mapping[str, FrozenSet[Tag]] - name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]] + input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] + output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] + name_in_program_to_tags: Mapping[str, frozenset[Tag]] + name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]] output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: @@ -750,9 +751,9 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): """ actx: PytatoJAXArrayContext pytato_program: pt.target.BoundProgram - input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] - output_tags: FrozenSet[Tag] - output_axes: Tuple[pt.Axis, ...] + input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] + output_tags: frozenset[Tag] + output_axes: tuple[pt.Axis, ...] output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index c6508e3a..0692eb7e 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -203,7 +203,7 @@ def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array: [(true_ary if kx_i == ky_i else false_ary) and rec_equal(x_i, y_i) for (kx_i, x_i), (ky_i, y_i) - in zip(serialized_x, serialized_y)], + in zip(serialized_x, serialized_y, strict=True)], true_ary) return cast(Array, rec_equal(a, b)) diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index a5582d18..d0c80a33 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -22,8 +22,8 @@ THE SOFTWARE. """ - -from typing import TYPE_CHECKING, Any, Dict, Mapping, Set, Tuple +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, cast from pytato.array import ( AbstractResultWithNamedArrays, @@ -54,9 +54,9 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): """ def __init__(self) -> None: super().__init__() - self.bound_arguments: Dict[str, Any] = {} + self.bound_arguments: dict[str, Any] = {} self.vng = UniqueNameGenerator() - self.seen_inputs: Set[str] = set() + self.seen_inputs: set[str] = set() def map_data_wrapper(self, expr: DataWrapper) -> Array: if expr.name is not None: @@ -71,7 +71,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: self.bound_arguments[name] = expr.data return make_placeholder( name=name, - shape=tuple(self.rec(s) if isinstance(s, Array) else s + shape=tuple(cast(Array, self.rec(s)) if isinstance(s, Array) else s for s in expr.shape), dtype=expr.dtype, axes=expr.axes, @@ -87,7 +87,7 @@ def map_placeholder(self, expr: Placeholder) -> Array: def _normalize_pt_expr( expr: DictOfNamedArrays - ) -> Tuple[AbstractResultWithNamedArrays, Mapping[str, Any]]: + ) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]: """ Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a normalized form of *expr*, with all instances of @@ -99,14 +99,15 @@ def _normalize_pt_expr( """ normalize_mapper = _DatawrapperToBoundPlaceholderMapper() normalized_expr = normalize_mapper(expr) + assert isinstance(normalized_expr, AbstractResultWithNamedArrays) return normalized_expr, normalize_mapper.bound_arguments -def get_pt_axes_from_cl_axes(axes: Tuple[ClAxis, ...]) -> Tuple[PtAxis, ...]: +def get_pt_axes_from_cl_axes(axes: tuple[ClAxis, ...]) -> tuple[PtAxis, ...]: return tuple(PtAxis(axis.tags) for axis in axes) -def get_cl_axes_from_pt_axes(axes: Tuple[PtAxis, ...]) -> Tuple[ClAxis, ...]: +def get_cl_axes_from_pt_axes(axes: tuple[PtAxis, ...]) -> tuple[ClAxis, ...]: return tuple(ClAxis(axis.tags) for axis in axes) diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index 1bee3eb0..da717846 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -27,7 +27,8 @@ THE SOFTWARE. """ -from typing import ClassVar, Mapping +from collections.abc import Mapping +from typing import ClassVar import numpy as np @@ -88,7 +89,7 @@ def get(c_name, nargs, naxes): from islpy import make_zero_and_vars v = make_zero_and_vars(var_names, params=size_names) domain = v[0].domain() - for vname, sname in zip(var_names, size_names): + for vname, sname in zip(var_names, size_names, strict=True): domain = domain & v[0].le_set(v[vname]) & v[vname].lt_set(v[sname]) domain_bset, = domain.get_basic_sets() diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index c778154d..f1f62a71 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -31,7 +31,8 @@ THE SOFTWARE. """ -from typing import Any, Callable, Dict, Sequence, Type, Union +from collections.abc import Callable, Sequence +from typing import Any from arraycontext import NumpyArrayContext from arraycontext.context import ArrayContext @@ -244,19 +245,18 @@ def __str__(self): # }}} -_ARRAY_CONTEXT_FACTORY_REGISTRY: \ - Dict[str, Type[PytestArrayContextFactory]] = { - "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, - "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, - "pytato:jax": _PytestPytatoJaxArrayContextFactory, - "eagerjax": _PytestEagerJaxArrayContextFactory, - "numpy": _PytestNumpyArrayContextFactory, - } +_ARRAY_CONTEXT_FACTORY_REGISTRY: dict[str, type[PytestArrayContextFactory]] = { + "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, + "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, + "pytato:jax": _PytestPytatoJaxArrayContextFactory, + "eagerjax": _PytestEagerJaxArrayContextFactory, + "numpy": _PytestNumpyArrayContextFactory, + } def register_pytest_array_context_factory( name: str, - factory: Type[PytestArrayContextFactory]) -> None: + factory: type[PytestArrayContextFactory]) -> None: if name in _ARRAY_CONTEXT_FACTORY_REGISTRY: raise ValueError(f"factory '{name}' already exists") @@ -268,7 +268,7 @@ def register_pytest_array_context_factory( # {{{ pytest integration def pytest_generate_tests_for_array_contexts( - factories: Sequence[Union[str, Type[PytestArrayContextFactory]]], *, + factories: Sequence[str | type[PytestArrayContextFactory]], *, factory_arg_name: str = "actx_factory", ) -> Callable[[Any], None]: """Parametrize tests for pytest to use an :class:`~arraycontext.ArrayContext`. diff --git a/arraycontext/version.py b/arraycontext/version.py index 31baea05..05fe8763 100644 --- a/arraycontext/version.py +++ b/arraycontext/version.py @@ -1,8 +1,7 @@ from importlib import metadata -from typing import Tuple -def _parse_version(version: str) -> Tuple[Tuple[int, ...], str]: +def _parse_version(version: str) -> tuple[tuple[int, ...], str]: import re m = re.match("^([0-9.]+)([a-z0-9]*?)$", VERSION_TEXT) diff --git a/pyproject.toml b/pyproject.toml index d971ae20..0755deb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,14 +6,14 @@ requires = [ [project] name = "arraycontext" -version = "2021.1" +version = "2024.0" description = "Choose your favorite numpy-workalike" readme = "README.rst" license = { text = "MIT" } authors = [ { name = "Andreas Kloeckner", email = "inform@tiker.net" }, ] -requires-python = ">=3.8" +requires-python = ">=3.10" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -33,9 +33,6 @@ dependencies = [ "immutabledict>=4.1", "numpy", "pytools>=2024.1.3", - - # for TypeAlias - "typing-extensions>=4; python_version<'3.10'", ] [project.optional-dependencies] @@ -122,8 +119,7 @@ known-local-folder = [ lines-after-imports = 2 [tool.mypy] -# TODO: unpin jax version on CI when this gets bumped to 3.10 -python_version = "3.8" +python_version = "3.10" warn_unused_ignores = true # TODO: enable this # check_untyped_defs = true diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 7bea0dc4..47d8e941 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -255,7 +255,7 @@ def array_context(self): def _get_test_containers(actx, ambient_dim=2, shapes=50_000): from numbers import Number - if isinstance(shapes, (Number, tuple)): + if isinstance(shapes, Number | tuple): shapes = [shapes] x = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.float64)) @@ -796,7 +796,7 @@ def test_container_map_on_device_scalar(actx_factory): rec_map_reduce_array_container, ) - for size, ary in zip(expected_sizes, arys[:-1]): + for size, ary in zip(expected_sizes, arys[:-1], strict=True): result = map_array_container(lambda x: x, ary) assert actx.to_numpy(actx.np.array_equal(result, ary)) result = rec_map_array_container(lambda x: x, ary) @@ -827,7 +827,8 @@ def _check_allclose(f, arg1, arg2, atol=2.0e-14): subarray for _, subarray in arg1_iterable] arg2_subarrays = [ subarray for _, subarray in arg2_iterable] - for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays): + for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays, + strict=True): _check_allclose(f, subarray1, subarray2) def func(x): @@ -880,7 +881,8 @@ def _check_allclose(f, arg1, arg2, atol=2.0e-14): subarray for _, subarray in arg1_iterable] arg2_subarrays = [ subarray for _, subarray in arg2_iterable] - for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays): + for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays, + strict=True): _check_allclose(f, subarray1, subarray2) def func_all_scalar(x, y): @@ -1072,7 +1074,7 @@ def test_flatten_array_container(actx_factory, shapes): # {{{ complex to real - if isinstance(shapes, (int, tuple)): + if isinstance(shapes, int | tuple): shapes = [shapes] ary = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.float64)) @@ -1556,7 +1558,7 @@ def test_to_numpy_on_frozen_arrays(actx_factory): def test_tagging(actx_factory): actx = actx_factory() - if isinstance(actx, (NumpyArrayContext, EagerJAXArrayContext)): + if isinstance(actx, NumpyArrayContext | EagerJAXArrayContext): pytest.skip(f"{type(actx)} has no tagging support") from pytools.tag import Tag