diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 72e39a14..22572dc8 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -539,7 +539,8 @@ def {fname}(arg1): _format_binary_op_str(op_str, expr_arg1, expr_arg2) for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip( cls._serialize_init_arrays_code("arg1").items(), - cls._serialize_init_arrays_code("arg2").items()) + cls._serialize_init_arrays_code("arg2").items(), + strict=True) }) bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", { key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 492f0c92..ec4c37f4 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -31,7 +31,7 @@ """ from dataclasses import Field, fields, is_dataclass -from typing import Tuple, Union, get_args, get_origin +from typing import Union, get_args, get_origin from arraycontext.container import is_array_container_type @@ -100,7 +100,7 @@ def is_array_field(f: Field) -> bool: _BaseGenericAlias, _SpecialForm, ) - if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)): + if isinstance(f.type, _BaseGenericAlias | _SpecialForm): # NOTE: anything except a Union is not allowed raise TypeError( f"Typing annotation not supported on field '{f.name}': " @@ -125,8 +125,8 @@ def is_array_field(f: Field) -> bool: def inject_dataclass_serialization( cls: type, - array_fields: Tuple[Field, ...], - non_array_fields: Tuple[Field, ...]) -> type: + array_fields: tuple[Field, ...], + non_array_fields: tuple[Field, ...]) -> type: """Implements :func:`~arraycontext.serialize_container` and :func:`~arraycontext.deserialize_container` for the given dataclass *cls*. diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 80f38afa..62f6354c 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -165,10 +165,11 @@ def rec(*_args: Any) -> Any: for subarys in zip( iterable_template, - *[serialize_container(_args[i]) for i in container_indices[1:]] + *[serialize_container(_args[i]) for i in container_indices[1:]], + strict=True ): key = None - for i, (subkey, subary) in zip(container_indices, subarys): + for i, (subkey, subary) in zip(container_indices, subarys, strict=True): if key is None: key = subkey else: diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index bc9481e3..094e8cf2 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -187,7 +187,7 @@ def rec_equal(x, y): [(true_ary if kx_i == ky_i else false_ary) and rec_equal(x_i, y_i) for (kx_i, x_i), (ky_i, y_i) - in zip(serialized_x, serialized_y)], + in zip(serialized_x, serialized_y, strict=True)], true_ary) return rec_equal(a, b) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index 8517ab69..f345edc9 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -149,8 +149,8 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: [(true_ary if kx_i == ky_i else false_ary) and cast(np.ndarray, self.array_equal(x_i, y_i)) for (kx_i, x_i), (ky_i, y_i) - in zip(serialized_x, serialized_y)], - true_ary) + in zip(serialized_x, serialized_y, strict=True)], + initial=true_ary) def arange(self, *args, **kwargs): return np.arange(*args, **kwargs) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index ac792452..ae340ca9 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -236,7 +236,7 @@ def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array: [(true_ary if kx_i == ky_i else false_ary) and rec_equal(x_i, y_i) for (kx_i, x_i), (ky_i, y_i) - in zip(serialized_x, serialized_y)], + in zip(serialized_x, serialized_y, strict=True)], true_ary) return rec_equal(a, b) @@ -346,7 +346,7 @@ def absolute(self, a): def where(self, criterion, then, else_): def where_inner(inner_crit, inner_then, inner_else): - if isinstance(inner_crit, (bool, np.bool_)): + if isinstance(inner_crit, bool | np.bool_): return inner_then if inner_crit else inner_else return cl_array.if_positive(inner_crit != 0, inner_then, inner_else, queue=self._array_context.queue) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index e3c830e7..1d36971c 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -676,7 +676,7 @@ def preprocess_arg(name, arg): return pt.einsum(spec, *[ preprocess_arg(name, arg) - for name, arg in zip(arg_names, args) + for name, arg in zip(arg_names, args, strict=True) ]).tagged(_preprocess_array_tags(tagged)) def clone(self): @@ -905,7 +905,7 @@ def preprocess_arg(name, arg): return pt.einsum(spec, *[ preprocess_arg(name, arg) - for name, arg in zip(arg_names, args) + for name, arg in zip(arg_names, args, strict=True) ]).tagged(_preprocess_array_tags(tagged)) def clone(self): diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index c6508e3a..0692eb7e 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -203,7 +203,7 @@ def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array: [(true_ary if kx_i == ky_i else false_ary) and rec_equal(x_i, y_i) for (kx_i, x_i), (ky_i, y_i) - in zip(serialized_x, serialized_y)], + in zip(serialized_x, serialized_y, strict=True)], true_ary) return cast(Array, rec_equal(a, b)) diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index a62023b5..da717846 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -89,7 +89,7 @@ def get(c_name, nargs, naxes): from islpy import make_zero_and_vars v = make_zero_and_vars(var_names, params=size_names) domain = v[0].domain() - for vname, sname in zip(var_names, size_names): + for vname, sname in zip(var_names, size_names, strict=True): domain = domain & v[0].le_set(v[vname]) & v[vname].lt_set(v[sname]) domain_bset, = domain.get_basic_sets() diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 5cffb20d..47d8e941 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -796,7 +796,7 @@ def test_container_map_on_device_scalar(actx_factory): rec_map_reduce_array_container, ) - for size, ary in zip(expected_sizes, arys[:-1]): + for size, ary in zip(expected_sizes, arys[:-1], strict=True): result = map_array_container(lambda x: x, ary) assert actx.to_numpy(actx.np.array_equal(result, ary)) result = rec_map_array_container(lambda x: x, ary) @@ -827,7 +827,8 @@ def _check_allclose(f, arg1, arg2, atol=2.0e-14): subarray for _, subarray in arg1_iterable] arg2_subarrays = [ subarray for _, subarray in arg2_iterable] - for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays): + for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays, + strict=True): _check_allclose(f, subarray1, subarray2) def func(x): @@ -880,7 +881,8 @@ def _check_allclose(f, arg1, arg2, atol=2.0e-14): subarray for _, subarray in arg1_iterable] arg2_subarrays = [ subarray for _, subarray in arg2_iterable] - for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays): + for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays, + strict=True): _check_allclose(f, subarray1, subarray2) def func_all_scalar(x, y):