Skip to content

Commit

Permalink
feat: improve dataclass container
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Nov 15, 2024
1 parent dee0ca4 commit 22f9ead
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
8 changes: 6 additions & 2 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def dataclass_array_container(cls: type) -> type:
array containers, even if they wrap one.
"""

from types import GenericAlias, UnionType

assert is_dataclass(cls)

def is_array_field(f: Field) -> bool:
Expand All @@ -75,7 +77,8 @@ def is_array_field(f: Field) -> bool:
# This is not set in stone, but mostly driven by current usage!

origin = get_origin(f.type)
if origin is Union:
# NOTE: `UnionType` is returned when using `Type1 | Type2`
if origin in (Union, UnionType):
if all(is_array_type(arg) for arg in get_args(f.type)):
return True
else:
Expand All @@ -94,13 +97,14 @@ def is_array_field(f: Field) -> bool:
f"Field with 'init=False' not allowed: '{f.name}'")

# NOTE:
# * `GenericAlias` catches typed `list`, `tuple`, etc.
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
# * `_SpecialForm` catches `Any`, `Literal`, etc.
from typing import ( # type: ignore[attr-defined]
_BaseGenericAlias,
_SpecialForm,
)
if isinstance(f.type, _BaseGenericAlias | _SpecialForm):
if isinstance(f.type, GenericAlias | _BaseGenericAlias | _SpecialForm):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"Typing annotation not supported on field '{f.name}': "
Expand Down
55 changes: 47 additions & 8 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def test_pt_actx_key_stringification_uniqueness():

def test_dataclass_array_container() -> None:
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Tuple # noqa: UP035

from arraycontext import dataclass_array_container
from arraycontext import Array, dataclass_array_container

# {{{ string fields

Expand All @@ -60,7 +60,7 @@ class ArrayContainerWithStringTypes:
x: np.ndarray
y: "np.ndarray"

with pytest.raises(TypeError):
with pytest.raises(TypeError, match="String annotation on field 'y'"):
# NOTE: cannot have string annotations in container
dataclass_array_container(ArrayContainerWithStringTypes)

Expand All @@ -73,12 +73,32 @@ class ArrayContainerWithOptional:
x: np.ndarray
y: Optional[np.ndarray]

with pytest.raises(TypeError):
with pytest.raises(TypeError, match="Field 'y' union contains non-array"):
# NOTE: cannot have wrapped annotations (here by `Optional`)
dataclass_array_container(ArrayContainerWithOptional)

# }}}

# {{{ type annotations

@dataclass
class ArrayContainerWithTuple:
x: Array
y: Tuple[Array, Array]

with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"):
dataclass_array_container(ArrayContainerWithTuple)

@dataclass
class ArrayContainerWithTupleAlt:
x: Array
y: tuple[Array, Array]

with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"):
dataclass_array_container(ArrayContainerWithTupleAlt)

# }}}

# {{{ field(init=False)

@dataclass
Expand All @@ -87,16 +107,14 @@ class ArrayContainerWithInitFalse:
y: np.ndarray = field(default_factory=lambda: np.zeros(42),
init=False, repr=False)

with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Field with 'init=False' not allowed"):
# NOTE: init=False fields are not allowed
dataclass_array_container(ArrayContainerWithInitFalse)

# }}}

# {{{ device arrays

from arraycontext import Array

@dataclass
class ArrayContainerWithArray:
x: Array
Expand Down Expand Up @@ -126,6 +144,13 @@ class ArrayContainerWithUnion:

dataclass_array_container(ArrayContainerWithUnion)

@dataclass
class ArrayContainerWithUnionAlt:
x: np.ndarray
y: np.ndarray | Array

dataclass_array_container(ArrayContainerWithUnionAlt)

# }}}

# {{{ non-container union
Expand All @@ -135,12 +160,26 @@ class ArrayContainerWithWrongUnion:
x: np.ndarray
y: Union[np.ndarray, float]

with pytest.raises(TypeError):
with pytest.raises(TypeError, match="Field 'y' union contains non-array container"):
# NOTE: float is not an ArrayContainer, so y should fail
dataclass_array_container(ArrayContainerWithWrongUnion)

# }}}

# {{{ optional union

@dataclass
class ArrayContainerWithOptionalUnion:
x: np.ndarray
y: np.ndarray | None

with pytest.raises(TypeError, match="Field 'y' union contains non-array container"):
# NOTE: None is not an ArrayContainer, so y should fail
dataclass_array_container(ArrayContainerWithWrongUnion)

# }}}


# }}}


Expand Down

0 comments on commit 22f9ead

Please sign in to comment.