diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index affa1da3..e50207af 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -23,7 +23,7 @@ """ from collections.abc import Mapping -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from pytato.array import ( AbstractResultWithNamedArrays, @@ -71,7 +71,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: self.bound_arguments[name] = expr.data return make_placeholder( name=name, - shape=tuple(self.rec(s) if isinstance(s, Array) else s + shape=tuple(cast(Array, self.rec(s)) if isinstance(s, Array) else s for s in expr.shape), dtype=expr.dtype, axes=expr.axes, @@ -87,7 +87,7 @@ def map_placeholder(self, expr: Placeholder) -> Array: def _normalize_pt_expr( expr: DictOfNamedArrays - ) -> tuple[AbstractResultWithNamedArrays, Mapping[str, Any]]: + ) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]: """ Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a normalized form of *expr*, with all instances of