From 22f9ead841b2358f8859293e7595cb451f715c8e Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Fri, 15 Nov 2024 12:13:11 +0200 Subject: [PATCH] feat: improve dataclass container --- arraycontext/container/dataclass.py | 8 +++-- test/test_utils.py | 55 ++++++++++++++++++++++++----- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index ec4c37f4..ae9ab486 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -59,6 +59,8 @@ def dataclass_array_container(cls: type) -> type: array containers, even if they wrap one. """ + from types import GenericAlias, UnionType + assert is_dataclass(cls) def is_array_field(f: Field) -> bool: @@ -75,7 +77,8 @@ def is_array_field(f: Field) -> bool: # This is not set in stone, but mostly driven by current usage! origin = get_origin(f.type) - if origin is Union: + # NOTE: `UnionType` is returned when using `Type1 | Type2` + if origin in (Union, UnionType): if all(is_array_type(arg) for arg in get_args(f.type)): return True else: @@ -94,13 +97,14 @@ def is_array_field(f: Field) -> bool: f"Field with 'init=False' not allowed: '{f.name}'") # NOTE: + # * `GenericAlias` catches typed `list`, `tuple`, etc. # * `_BaseGenericAlias` catches `List`, `Tuple`, etc. # * `_SpecialForm` catches `Any`, `Literal`, etc. from typing import ( # type: ignore[attr-defined] _BaseGenericAlias, _SpecialForm, ) - if isinstance(f.type, _BaseGenericAlias | _SpecialForm): + if isinstance(f.type, GenericAlias | _BaseGenericAlias | _SpecialForm): # NOTE: anything except a Union is not allowed raise TypeError( f"Typing annotation not supported on field '{f.name}': " diff --git a/test/test_utils.py b/test/test_utils.py index 04817d6a..db9ed825 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -49,9 +49,9 @@ def test_pt_actx_key_stringification_uniqueness(): def test_dataclass_array_container() -> None: from dataclasses import dataclass, field - from typing import Optional + from typing import Optional, Tuple # noqa: UP035 - from arraycontext import dataclass_array_container + from arraycontext import Array, dataclass_array_container # {{{ string fields @@ -60,7 +60,7 @@ class ArrayContainerWithStringTypes: x: np.ndarray y: "np.ndarray" - with pytest.raises(TypeError): + with pytest.raises(TypeError, match="String annotation on field 'y'"): # NOTE: cannot have string annotations in container dataclass_array_container(ArrayContainerWithStringTypes) @@ -73,12 +73,32 @@ class ArrayContainerWithOptional: x: np.ndarray y: Optional[np.ndarray] - with pytest.raises(TypeError): + with pytest.raises(TypeError, match="Field 'y' union contains non-array"): # NOTE: cannot have wrapped annotations (here by `Optional`) dataclass_array_container(ArrayContainerWithOptional) # }}} + # {{{ type annotations + + @dataclass + class ArrayContainerWithTuple: + x: Array + y: Tuple[Array, Array] + + with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"): + dataclass_array_container(ArrayContainerWithTuple) + + @dataclass + class ArrayContainerWithTupleAlt: + x: Array + y: tuple[Array, Array] + + with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"): + dataclass_array_container(ArrayContainerWithTupleAlt) + + # }}} + # {{{ field(init=False) @dataclass @@ -87,7 +107,7 @@ class ArrayContainerWithInitFalse: y: np.ndarray = field(default_factory=lambda: np.zeros(42), init=False, repr=False) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Field with 'init=False' not allowed"): # NOTE: init=False fields are not allowed dataclass_array_container(ArrayContainerWithInitFalse) @@ -95,8 +115,6 @@ class ArrayContainerWithInitFalse: # {{{ device arrays - from arraycontext import Array - @dataclass class ArrayContainerWithArray: x: Array @@ -126,6 +144,13 @@ class ArrayContainerWithUnion: dataclass_array_container(ArrayContainerWithUnion) + @dataclass + class ArrayContainerWithUnionAlt: + x: np.ndarray + y: np.ndarray | Array + + dataclass_array_container(ArrayContainerWithUnionAlt) + # }}} # {{{ non-container union @@ -135,12 +160,26 @@ class ArrayContainerWithWrongUnion: x: np.ndarray y: Union[np.ndarray, float] - with pytest.raises(TypeError): + with pytest.raises(TypeError, match="Field 'y' union contains non-array container"): # NOTE: float is not an ArrayContainer, so y should fail dataclass_array_container(ArrayContainerWithWrongUnion) # }}} + # {{{ optional union + + @dataclass + class ArrayContainerWithOptionalUnion: + x: np.ndarray + y: np.ndarray | None + + with pytest.raises(TypeError, match="Field 'y' union contains non-array container"): + # NOTE: None is not an ArrayContainer, so y should fail + dataclass_array_container(ArrayContainerWithWrongUnion) + + # }}} + + # }}}