Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better types 2024 #290

Merged
merged 9 commits into from
Nov 29, 2024
Merged
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ on:
schedule:
- cron: '17 3 * * 0'

concurrency:
group: ${{ github.head_ref || github.ref_name }}
cancel-in-progress: true

jobs:
typos:
name: Typos
Expand Down
11 changes: 11 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
An array context is an abstraction that helps you dispatch between multiple
implementations of :mod:`numpy`-like :math:`n`-dimensional arrays.
"""
from __future__ import annotations


__copyright__ = """
Expand Down Expand Up @@ -29,6 +30,7 @@
"""

from .container import (
ArithArrayContainer,
ArrayContainer,
ArrayContainerT,
NotAnArrayContainerError,
Expand Down Expand Up @@ -72,6 +74,10 @@
from .context import (
Array,
ArrayContext,
ArrayOrArithContainer,
ArrayOrArithContainerOrScalar,
ArrayOrArithContainerOrScalarT,
ArrayOrArithContainerT,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
Expand All @@ -95,10 +101,15 @@


__all__ = (
"ArithArrayContainer",
"Array",
"ArrayContainer",
"ArrayContainerT",
"ArrayContext",
"ArrayOrArithContainer",
"ArrayOrArithContainerOrScalar",
"ArrayOrArithContainerOrScalarT",
"ArrayOrArithContainerT",
"ArrayOrContainer",
"ArrayOrContainerOrScalar",
"ArrayOrContainerOrScalarT",
Expand Down
31 changes: 28 additions & 3 deletions arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
.. currentmodule:: arraycontext

.. autoclass:: ArrayContainer
.. autoclass:: ArithArrayContainer
.. class:: ArrayContainerT

A type variable with a lower bound of :class:`ArrayContainer`.
Expand Down Expand Up @@ -81,14 +82,15 @@

from collections.abc import Hashable, Sequence
from functools import singledispatch
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar
from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar

# For use in singledispatch type annotations, because sphinx can't figure out
# what 'np' is.
import numpy
import numpy as np
from typing_extensions import Self

from arraycontext.context import ArrayContext
from arraycontext.context import ArrayContext, ArrayOrScalar


if TYPE_CHECKING:
Expand Down Expand Up @@ -145,6 +147,29 @@ class ArrayContainer(Protocol):
# that are container-typed.


class ArithArrayContainer(ArrayContainer, Protocol):
alexfikl marked this conversation as resolved.
Show resolved Hide resolved
"""
A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
"""

# This is loose and permissive, assuming that any array can be added
# to any container. The alternative would be to plaster type-ignores
# on all those uses. Achieving typing precision on what broadcasting is
# allowable seems like a huge endeavor and is likely not feasible without
# a mypy plugin. Maybe some day? -AK, November 2024

def __neg__(self) -> Self: ...
def __abs__(self) -> Self: ...
def __add__(self, other: ArrayOrScalar | Self) -> Self: ...
def __radd__(self, other: ArrayOrScalar | Self) -> Self: ...
def __sub__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ...
def __mul__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ...
def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ...


ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)


Expand Down Expand Up @@ -219,7 +244,7 @@ def is_array_container_type(cls: type) -> bool:
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]


def is_array_container(ary: Any) -> bool:
def is_array_container(ary: object) -> bool:
"""
:returns: *True* if the instance *ary* has a registered implementation of
:func:`serialize_container`.
Expand Down
56 changes: 41 additions & 15 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
.. currentmodule:: arraycontext
.. autofunction:: dataclass_array_container
"""
from __future__ import annotations


__copyright__ = """
Expand All @@ -30,6 +31,7 @@
THE SOFTWARE.
"""

from collections.abc import Mapping, Sequence
from dataclasses import Field, fields, is_dataclass
from typing import Union, get_args, get_origin

Expand Down Expand Up @@ -57,13 +59,21 @@ def dataclass_array_container(cls: type) -> type:
* a :class:`typing.Union` of array containers is considered an array container.
* other type annotations, e.g. :class:`typing.Optional`, are not considered
array containers, even if they wrap one.

.. note::

When type annotations are strings (e.g. because of
``from __future__ import annotations``),
this function relies on :func:`inspect.get_annotations`
(with ``eval_str=True``) to obtain type annotations. This
means that *cls* must live in a module that is importable.
"""

from types import GenericAlias, UnionType

assert is_dataclass(cls)

def is_array_field(f: Field) -> bool:
def is_array_field(f: Field, field_type: type) -> bool:
# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
Expand All @@ -76,17 +86,17 @@ def is_array_field(f: Field) -> bool:
#
# This is not set in stone, but mostly driven by current usage!

origin = get_origin(f.type)
origin = get_origin(field_type)
# NOTE: `UnionType` is returned when using `Type1 | Type2`
if origin in (Union, UnionType):
if all(is_array_type(arg) for arg in get_args(f.type)):
if all(is_array_type(arg) for arg in get_args(field_type)):
return True
else:
raise TypeError(
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")

if isinstance(f.type, str):
if isinstance(field_type, str):
raise TypeError(
f"String annotation on field '{f.name}' not supported. "
"(this may be due to 'from __future__ import annotations')")
Expand All @@ -104,33 +114,49 @@ def is_array_field(f: Field) -> bool:
_BaseGenericAlias,
_SpecialForm,
)
if isinstance(f.type, GenericAlias | _BaseGenericAlias | _SpecialForm):
if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"Typing annotation not supported on field '{f.name}': "
f"'{f.type!r}'")
f"'{field_type!r}'")

if not isinstance(f.type, type):
if not isinstance(field_type, type):
raise TypeError(
f"Field '{f.name}' not an instance of 'type': "
f"'{f.type!r}'")
f"'{field_type!r}'")

return is_array_type(field_type)

from inspect import get_annotations

return is_array_type(f.type)
array_fields: list[Field] = []
non_array_fields: list[Field] = []
cls_ann: Mapping[str, type] | None = None
for field in fields(cls):
field_type_or_str = field.type
if isinstance(field_type_or_str, str):
if cls_ann is None:
cls_ann = get_annotations(cls, eval_str=True)
field_type = cls_ann[field.name]
else:
field_type = field_type_or_str

from pytools import partition
array_fields, non_array_fields = partition(is_array_field, fields(cls))
if is_array_field(field, field_type):
array_fields.append(field)
else:
non_array_fields.append(field)

if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")

return inject_dataclass_serialization(cls, array_fields, non_array_fields)
return _inject_dataclass_serialization(cls, array_fields, non_array_fields)


def inject_dataclass_serialization(
def _inject_dataclass_serialization(
cls: type,
array_fields: tuple[Field, ...],
non_array_fields: tuple[Field, ...]) -> type:
array_fields: Sequence[Field],
non_array_fields: Sequence[Field]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.

Expand Down
Loading
Loading