Skip to content

Commit

Permalink
ruff: fix zip strict argument
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Nov 13, 2024
1 parent fa3faeb commit 81ea941
Show file tree
Hide file tree
Showing 10 changed files with 23 additions and 19 deletions.
3 changes: 2 additions & 1 deletion arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,8 @@ def {fname}(arg1):
_format_binary_op_str(op_str, expr_arg1, expr_arg2)
for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip(
cls._serialize_init_arrays_code("arg1").items(),
cls._serialize_init_arrays_code("arg2").items())
cls._serialize_init_arrays_code("arg2").items(),
strict=True)
})
bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", {
key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
Expand Down
8 changes: 4 additions & 4 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""

from dataclasses import Field, fields, is_dataclass
from typing import Tuple, Union, get_args, get_origin
from typing import Union, get_args, get_origin

from arraycontext.container import is_array_container_type

Expand Down Expand Up @@ -100,7 +100,7 @@ def is_array_field(f: Field) -> bool:
_BaseGenericAlias,
_SpecialForm,
)
if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)):
if isinstance(f.type, _BaseGenericAlias | _SpecialForm):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"Typing annotation not supported on field '{f.name}': "
Expand All @@ -125,8 +125,8 @@ def is_array_field(f: Field) -> bool:

def inject_dataclass_serialization(
cls: type,
array_fields: Tuple[Field, ...],
non_array_fields: Tuple[Field, ...]) -> type:
array_fields: tuple[Field, ...],
non_array_fields: tuple[Field, ...]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
Expand Down
5 changes: 3 additions & 2 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,11 @@ def rec(*_args: Any) -> Any:

for subarys in zip(
iterable_template,
*[serialize_container(_args[i]) for i in container_indices[1:]]
*[serialize_container(_args[i]) for i in container_indices[1:]],
strict=True
):
key = None
for i, (subkey, subary) in zip(container_indices, subarys):
for i, (subkey, subary) in zip(container_indices, subarys, strict=True):
if key is None:
key = subkey
else:
Expand Down
2 changes: 1 addition & 1 deletion arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def rec_equal(x, y):
[(true_ary if kx_i == ky_i else false_ary)
and rec_equal(x_i, y_i)
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y)],
in zip(serialized_x, serialized_y, strict=True)],
true_ary)

return rec_equal(a, b)
Expand Down
4 changes: 2 additions & 2 deletions arraycontext/impl/numpy/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
[(true_ary if kx_i == ky_i else false_ary)
and cast(np.ndarray, self.array_equal(x_i, y_i))
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y)],
true_ary)
in zip(serialized_x, serialized_y, strict=True)],
initial=true_ary)

def arange(self, *args, **kwargs):
return np.arange(*args, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array:
[(true_ary if kx_i == ky_i else false_ary)
and rec_equal(x_i, y_i)
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y)],
in zip(serialized_x, serialized_y, strict=True)],
true_ary)

return rec_equal(a, b)
Expand Down Expand Up @@ -346,7 +346,7 @@ def absolute(self, a):

def where(self, criterion, then, else_):
def where_inner(inner_crit, inner_then, inner_else):
if isinstance(inner_crit, (bool, np.bool_)):
if isinstance(inner_crit, bool | np.bool_):
return inner_then if inner_crit else inner_else
return cl_array.if_positive(inner_crit != 0, inner_then, inner_else,
queue=self._array_context.queue)
Expand Down
4 changes: 2 additions & 2 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def preprocess_arg(name, arg):

return pt.einsum(spec, *[
preprocess_arg(name, arg)
for name, arg in zip(arg_names, args)
for name, arg in zip(arg_names, args, strict=True)
]).tagged(_preprocess_array_tags(tagged))

def clone(self):
Expand Down Expand Up @@ -905,7 +905,7 @@ def preprocess_arg(name, arg):

return pt.einsum(spec, *[
preprocess_arg(name, arg)
for name, arg in zip(arg_names, args)
for name, arg in zip(arg_names, args, strict=True)
]).tagged(_preprocess_array_tags(tagged))

def clone(self):
Expand Down
2 changes: 1 addition & 1 deletion arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array:
[(true_ary if kx_i == ky_i else false_ary)
and rec_equal(x_i, y_i)
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y)],
in zip(serialized_x, serialized_y, strict=True)],
true_ary)

return cast(Array, rec_equal(a, b))
Expand Down
2 changes: 1 addition & 1 deletion arraycontext/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get(c_name, nargs, naxes):
from islpy import make_zero_and_vars
v = make_zero_and_vars(var_names, params=size_names)
domain = v[0].domain()
for vname, sname in zip(var_names, size_names):
for vname, sname in zip(var_names, size_names, strict=True):
domain = domain & v[0].le_set(v[vname]) & v[vname].lt_set(v[sname])

domain_bset, = domain.get_basic_sets()
Expand Down
8 changes: 5 additions & 3 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ def test_container_map_on_device_scalar(actx_factory):
rec_map_reduce_array_container,
)

for size, ary in zip(expected_sizes, arys[:-1]):
for size, ary in zip(expected_sizes, arys[:-1], strict=True):
result = map_array_container(lambda x: x, ary)
assert actx.to_numpy(actx.np.array_equal(result, ary))
result = rec_map_array_container(lambda x: x, ary)
Expand Down Expand Up @@ -827,7 +827,8 @@ def _check_allclose(f, arg1, arg2, atol=2.0e-14):
subarray for _, subarray in arg1_iterable]
arg2_subarrays = [
subarray for _, subarray in arg2_iterable]
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays,
strict=True):
_check_allclose(f, subarray1, subarray2)

def func(x):
Expand Down Expand Up @@ -880,7 +881,8 @@ def _check_allclose(f, arg1, arg2, atol=2.0e-14):
subarray for _, subarray in arg1_iterable]
arg2_subarrays = [
subarray for _, subarray in arg2_iterable]
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays,
strict=True):
_check_allclose(f, subarray1, subarray2)

def func_all_scalar(x, y):
Expand Down

0 comments on commit 81ea941

Please sign in to comment.