Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Array API support): Add __array_namespace_info__ and device #101

Merged
merged 11 commits into from
Jan 31, 2025
3 changes: 2 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[submodule "array-api-tests"]
path = api-coverage-tests
url = [email protected]:data-apis/array-api-tests.git
url = [email protected]:adityagoel4512/array-api-tests.git
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved
branch = skip-unused-kinds-in-default
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@ Changelog
0.10.0 (unreleased)
-------------------

**Breaking change**
- Removed the deprecated :func:`ndonnx.promote_nullable` function. Use :func:`ndonnx.additional.make_nullable` instead.

**Array API compliance**

- ndonnx now supports the :func:`ndonnx.__array_namespace_info__` function from the Array API standard.
- Arrays now expose the :meth:`ndonnx.Array.device` property to improve Array API compatibility. Note that serializing an ONNX model inherently postpones device placement decisions to the runtime so currently one abstract device is supported.


0.9.3 (2024-10-25)
------------------
Expand Down
2 changes: 2 additions & 0 deletions ndonnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@
take,
UnsupportedOperationError,
)
from ._info import __array_namespace_info__
from ._constants import (
e,
inf,
Expand All @@ -176,6 +177,7 @@


__all__ = [
"__array_namespace_info__",
"Array",
"array",
"from_spox_var",
Expand Down
33 changes: 31 additions & 2 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import typing
from collections.abc import Callable
from typing import Union
from typing import Any, Union

import numpy as np
import spox.opset.ai.onnx.v19 as op
Expand Down Expand Up @@ -254,6 +254,19 @@ def shape(self) -> tuple[int | None, ...]:
else:
return static_shape(self)

@property
def device(self):
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved
return device

def to_device(
self, device: _Device, /, *, stream: int | Any | None = None
) -> Array:
if device is not device:
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("Cannot move Array to a different device")
if stream is not None:
raise ValueError("The 'stream' parameter is not supported in ndonnx.")
return self.copy()

@property
def values(self) -> Array:
"""Accessor for data in a ``Array`` with nullable datatype."""
Expand Down Expand Up @@ -579,7 +592,23 @@ def any(self, axis: int | None = 0, keepdims: bool | None = False) -> ndx.Array:
return ndx.any(self, axis=axis, keepdims=False)


class _Device:
# We would rather not give users the impression that their arrays
# are tied to a specific device when serializing an ONNX graph as
# such a concept does not exist in the ONNX standard.

def __str__(self):
return "ndonnx device"

def __eq__(self, other):
return type(other) is _Device


device = _Device()


__all__ = [
"Array",
"array",
"device",
]
6 changes: 5 additions & 1 deletion ndonnx/_data_types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -28,6 +28,8 @@
uint32,
uint64,
utf8,
canonical_name,
kinds,
)
from .classes import (
Floating,
Expand Down Expand Up @@ -145,4 +147,6 @@ def into_nullable(dtype: StructType | CoreType) -> NullableCore:
"CastMixin",
"CastError",
"Dtype",
"canonical_name",
"kinds",
]
57 changes: 56 additions & 1 deletion ndonnx/_data_types/aliases.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ndonnx import CoreType

from .classes import (
Boolean,
Expand Down Expand Up @@ -55,3 +61,52 @@
nuint32: NUInt32 = NUInt32()
nuint64: NUInt64 = NUInt64()
nutf8: NUtf8 = NUtf8()


_canonical_names = {
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved
bool: "bool",
float32: "float32",
float64: "float64",
int8: "int8",
int16: "int16",
int32: "int32",
int64: "int64",
uint8: "uint8",
uint16: "uint16",
uint32: "uint32",
uint64: "uint64",
utf8: "utf8",
}


def canonical_name(dtype: CoreType) -> str:
"""Return the canonical name of the data type."""
if dtype in _canonical_names:
return _canonical_names[dtype]
else:
raise ValueError(f"Unknown data type: {dtype}")


_kinds = {
bool: ("bool",),
int8: ("signed integer", "integer", "numeric"),
int16: ("signed integer", "integer", "numeric"),
int32: ("signed integer", "integer", "numeric"),
int64: ("signed integer", "integer", "numeric"),
uint8: ("unsigned integer", "integer", "numeric"),
uint16: ("unsigned integer", "integer", "numeric"),
uint32: ("unsigned integer", "integer", "numeric"),
uint64: ("unsigned integer", "integer", "numeric"),
float32: ("floating", "numeric"),
float64: ("floating", "numeric"),
}


def kinds(dtype: CoreType) -> tuple[str, ...]:
"""Return the kinds of the data type."""
if dtype in _kinds:
return _kinds[dtype]
elif dtype == utf8:
raise ValueError(f"We don't yet define a kind for {dtype}")
else:
raise ValueError(f"Unknown data type: {dtype}")
15 changes: 1 addition & 14 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,20 +267,7 @@ def iinfo(dtype):

def isdtype(dtype, kind) -> bool:
if isinstance(kind, str):
if kind == "bool":
return dtype == dtypes.bool
elif kind == "signed integer":
return dtype in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
elif kind == "unsigned integer":
return dtype in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
elif kind == "integral":
return isinstance(dtype, dtypes.Integral)
elif kind == "real floating":
return isinstance(dtype, dtypes.Floating)
elif kind == "complex floating":
raise ValueError("'complex floating' is not supported")
elif kind == "numeric":
return isinstance(dtype, dtypes.Numerical)
return kind in dtypes.kinds(dtype)
elif isinstance(kind, dtypes.CoreType):
return dtype == kind
elif isinstance(kind, tuple):
Expand Down
62 changes: 62 additions & 0 deletions ndonnx/_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import ndonnx as ndx
from ndonnx._array import device
from ndonnx._data_types import canonical_name


class ArrayNamespaceInfo:
"""Namespace metadata for the Array API standard."""

_all_array_api_types = [
ndx.bool,
ndx.float32,
ndx.float64,
ndx.int8,
ndx.int16,
ndx.int32,
ndx.int64,
ndx.uint8,
ndx.uint16,
ndx.uint32,
ndx.uint64,
]

def capabilities(self) -> dict:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def capabilities(self) -> dict:
def capabilities(self) -> dict[str, bool]:

return {
"boolean indexing": True,
"data-dependent shapes": True,
}

def default_device(self):
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved
return device

def devices(self) -> list:
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved
return [device]

def dtypes(
self, *, device=None, kind: str | tuple[str, ...] | None = None
) -> dict[str, ndx.CoreType]:
out: dict[str, ndx.CoreType] = {}
for dtype in self._all_array_api_types:
if kind is None or ndx.isdtype(dtype, kind):
out[canonical_name(dtype)] = dtype
return out

def default_dtypes(
self,
*,
device=None,
) -> dict[str, ndx.CoreType]:
return {
"real floating": ndx.float64,
"integral": ndx.int64,
"indexing": ndx.int64,
}


def __array_namespace_info__() -> ArrayNamespaceInfo: # noqa: N807
return ArrayNamespaceInfo()
12 changes: 0 additions & 12 deletions xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_meshgrid
array_api_tests/test_data_type_functions.py::test_can_cast
array_api_tests/test_data_type_functions.py::test_isdtype
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
array_api_tests/test_has_names.py::test_has_names[array_method-__complex__]
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
array_api_tests/test_has_names.py::test_has_names[creation-meshgrid]
array_api_tests/test_has_names.py::test_has_names[elementwise-conj]
Expand All @@ -36,7 +34,6 @@ array_api_tests/test_has_names.py::test_has_names[fft-irfftn]
array_api_tests/test_has_names.py::test_has_names[fft-rfft]
array_api_tests/test_has_names.py::test_has_names[fft-rfftfreq]
array_api_tests/test_has_names.py::test_has_names[fft-rfftn]
array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__]
array_api_tests/test_has_names.py::test_has_names[linalg-cholesky]
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
array_api_tests/test_has_names.py::test_has_names[linalg-det]
Expand Down Expand Up @@ -65,8 +62,6 @@ array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot]
array_api_tests/test_has_names.py::test_has_names[manipulation-moveaxis]
array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
array_api_tests/test_inspection_functions.py::test_array_namespace_info
array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes
array_api_tests/test_linalg.py::test_matrix_transpose
array_api_tests/test_linalg.py::test_vecdot
array_api_tests/test_manipulation_functions.py::test_moveaxis
Expand Down Expand Up @@ -101,8 +96,6 @@ array_api_tests/test_set_functions.py::test_unique_values
array_api_tests/test_signatures.py::test_array_method_signature[__complex__]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
array_api_tests/test_signatures.py::test_array_method_signature[to_device]
array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__]
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_func_signature[conj]
array_api_tests/test_signatures.py::test_func_signature[copysign]
Expand All @@ -118,11 +111,6 @@ array_api_tests/test_signatures.py::test_func_signature[signbit]
array_api_tests/test_signatures.py::test_func_signature[tile]
array_api_tests/test_signatures.py::test_func_signature[unstack]
array_api_tests/test_signatures.py::test_func_signature[vecdot]
array_api_tests/test_signatures.py::test_info_func_signature[capabilities]
array_api_tests/test_signatures.py::test_info_func_signature[default_device]
array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes]
array_api_tests/test_signatures.py::test_info_func_signature[devices]
array_api_tests/test_signatures.py::test_info_func_signature[dtypes]
array_api_tests/test_sorting_functions.py::test_argsort
array_api_tests/test_special_cases.py::test_binary[copysign(x1_i is NaN and x2_i < 0) -> NaN]
array_api_tests/test_special_cases.py::test_binary[copysign(x1_i is NaN and x2_i > 0) -> NaN]
Expand Down