Skip to content

Commit

Permalink
(Array API support): Add __array_namespace_info__ and device (#101)
Browse files Browse the repository at this point in the history
Co-authored-by: Christian Bourjau <[email protected]>
  • Loading branch information
adityagoel4512 and cbourjau authored Jan 31, 2025
1 parent d9802b0 commit 4d70e5a
Show file tree
Hide file tree
Showing 15 changed files with 172 additions and 39 deletions.
4 changes: 2 additions & 2 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[submodule "array-api-tests"]
path = api-coverage-tests
[submodule "api-coverage-tests"]
path = array-api-tests
url = [email protected]:data-apis/array-api-tests.git
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ repos:
language: system
types: [python]
require_serial: true
exclude: ^(tests|api-coverage-tests)/
exclude: ^(tests|array-api-tests)/
# prettier
- id: prettier
name: prettier
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@ Changelog
0.10.0 (unreleased)
-------------------

**Breaking change**
- Removed the deprecated :func:`ndonnx.promote_nullable` function. Use :func:`ndonnx.additional.make_nullable` instead.

**Array API compliance**

- ndonnx now supports the :func:`ndonnx.__array_namespace_info__` function from the Array API standard.
- Arrays now expose the :meth:`ndonnx.Array.device` property to improve Array API compatibility. Note that serializing an ONNX model inherently postpones device placement decisions to the runtime so currently one abstract device is supported.


0.9.3 (2024-10-25)
------------------
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pytest tests -n auto

It has a couple of key features:

- It implements the [`Array API`](https://data-apis.org/array-api/) standard. Standard compliant code can be executed without changes across numerous backends such as like `NumPy`, `JAX` and now `ndonnx`.
- It implements the [`Array API`](https://data-apis.org/array-api/) standard. Standard compliant code can be executed without changes across numerous backends such as like NumPy, JAX and now ndonnx.

```python
import numpy as np
Expand Down Expand Up @@ -93,7 +93,7 @@ In the future we will be enabling a stable API for an extensible data type syste

## Array API coverage

Array API compatibility is tracked in `api-coverage-tests`. Missing coverage is tracked in the `skips.txt` file. Contributions are welcome!
Array API compatibility is tracked in `array-api-tests`. Missing coverage is tracked in the `skips.txt` file. Contributions are welcome!

Summary(1119 total):

Expand Down
1 change: 0 additions & 1 deletion api-coverage-tests
Submodule api-coverage-tests deleted from 4caff2
1 change: 1 addition & 0 deletions array-api-tests
Submodule array-api-tests added at dad773
2 changes: 2 additions & 0 deletions ndonnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@
take,
UnsupportedOperationError,
)
from ._info import __array_namespace_info__
from ._constants import (
e,
inf,
Expand All @@ -176,6 +177,7 @@


__all__ = [
"__array_namespace_info__",
"Array",
"array",
"from_spox_var",
Expand Down
33 changes: 31 additions & 2 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import typing
from collections.abc import Callable
from typing import Union
from typing import Any, Union

import numpy as np
import spox.opset.ai.onnx.v19 as op
Expand Down Expand Up @@ -254,6 +254,19 @@ def shape(self) -> tuple[int | None, ...]:
else:
return static_shape(self)

@property
def device(self):
return device

def to_device(
self, device: _Device, /, *, stream: int | Any | None = None
) -> Array:
if device != self.device:
raise ValueError("Cannot move Array to a different device")
if stream is not None:
raise ValueError("The 'stream' parameter is not supported in ndonnx.")
return self.copy()

@property
def values(self) -> Array:
"""Accessor for data in a ``Array`` with nullable datatype."""
Expand Down Expand Up @@ -579,7 +592,23 @@ def any(self, axis: int | None = 0, keepdims: bool | None = False) -> ndx.Array:
return ndx.any(self, axis=axis, keepdims=False)


class _Device:
# We would rather not give users the impression that their arrays
# are tied to a specific device when serializing an ONNX graph as
# such a concept does not exist in the ONNX standard.

def __str__(self):
return "ndonnx device"

def __eq__(self, other):
return type(other) is _Device


device = _Device()


__all__ = [
"Array",
"array",
"device",
]
6 changes: 5 additions & 1 deletion ndonnx/_data_types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -28,6 +28,8 @@
uint32,
uint64,
utf8,
canonical_name,
kinds,
)
from .classes import (
Floating,
Expand Down Expand Up @@ -145,4 +147,6 @@ def into_nullable(dtype: StructType | CoreType) -> NullableCore:
"CastMixin",
"CastError",
"Dtype",
"canonical_name",
"kinds",
]
57 changes: 56 additions & 1 deletion ndonnx/_data_types/aliases.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ndonnx import CoreType

from .classes import (
Boolean,
Expand Down Expand Up @@ -55,3 +61,52 @@
nuint32: NUInt32 = NUInt32()
nuint64: NUInt64 = NUInt64()
nutf8: NUtf8 = NUtf8()


_canonical_names = {
bool: "bool",
float32: "float32",
float64: "float64",
int8: "int8",
int16: "int16",
int32: "int32",
int64: "int64",
uint8: "uint8",
uint16: "uint16",
uint32: "uint32",
uint64: "uint64",
utf8: "utf8",
}


def canonical_name(dtype: CoreType) -> str:
"""Return the canonical name of the data type."""
if dtype in _canonical_names:
return _canonical_names[dtype]
else:
raise ValueError(f"Unknown data type: {dtype}")


_kinds = {
bool: ("bool",),
int8: ("signed integer", "integer", "numeric"),
int16: ("signed integer", "integer", "numeric"),
int32: ("signed integer", "integer", "numeric"),
int64: ("signed integer", "integer", "numeric"),
uint8: ("unsigned integer", "integer", "numeric"),
uint16: ("unsigned integer", "integer", "numeric"),
uint32: ("unsigned integer", "integer", "numeric"),
uint64: ("unsigned integer", "integer", "numeric"),
float32: ("floating", "numeric"),
float64: ("floating", "numeric"),
}


def kinds(dtype: CoreType) -> tuple[str, ...]:
"""Return the kinds of the data type."""
if dtype in _kinds:
return _kinds[dtype]
elif dtype == utf8:
raise ValueError(f"We don't yet define a kind for {dtype}")
else:
raise ValueError(f"Unknown data type: {dtype}")
15 changes: 1 addition & 14 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,20 +267,7 @@ def iinfo(dtype):

def isdtype(dtype, kind) -> bool:
if isinstance(kind, str):
if kind == "bool":
return dtype == dtypes.bool
elif kind == "signed integer":
return dtype in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
elif kind == "unsigned integer":
return dtype in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
elif kind == "integral":
return isinstance(dtype, dtypes.Integral)
elif kind == "real floating":
return isinstance(dtype, dtypes.Floating)
elif kind == "complex floating":
raise ValueError("'complex floating' is not supported")
elif kind == "numeric":
return isinstance(dtype, dtypes.Numerical)
return kind in dtypes.kinds(dtype)
elif isinstance(kind, dtypes.CoreType):
return dtype == kind
elif isinstance(kind, tuple):
Expand Down
62 changes: 62 additions & 0 deletions ndonnx/_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import ndonnx as ndx
from ndonnx._array import _Device, device
from ndonnx._data_types import canonical_name


class ArrayNamespaceInfo:
"""Namespace metadata for the Array API standard."""

_all_array_api_types = [
ndx.bool,
ndx.float32,
ndx.float64,
ndx.int8,
ndx.int16,
ndx.int32,
ndx.int64,
ndx.uint8,
ndx.uint16,
ndx.uint32,
ndx.uint64,
]

def capabilities(self) -> dict[str, bool]:
return {
"boolean indexing": True,
"data-dependent shapes": True,
}

def default_device(self) -> _Device:
return device

def devices(self) -> list[_Device]:
return [device]

def dtypes(
self, *, device=None, kind: str | tuple[str, ...] | None = None
) -> dict[str, ndx.CoreType]:
out: dict[str, ndx.CoreType] = {}
for dtype in self._all_array_api_types:
if kind is None or ndx.isdtype(dtype, kind):
out[canonical_name(dtype)] = dtype
return out

def default_dtypes(
self,
*,
device=None,
) -> dict[str, ndx.CoreType]:
return {
"real floating": ndx.float64,
"integral": ndx.int64,
"indexing": ndx.int64,
}


def __array_namespace_info__() -> ArrayNamespaceInfo: # noqa: N807
return ArrayNamespaceInfo()
2 changes: 1 addition & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ test = "pytest"
test-coverage = "pytest --cov=ndonnx --cov-report=xml --cov-report=term-missing"

[feature.test.tasks.arrayapitests]
cmd = "pytest api-coverage-tests/array_api_tests/ -v -rfX --json-report --json-report-file=api-coverage-tests.json -n auto --disable-deadline --disable-extension linalg --skips-file=skips.txt --xfails-file=xfails.txt"
cmd = "python -m pytest array-api-tests/array_api_tests/ -v -k 'not meta_tests' --disable-extension linalg --disable-deadline --skips-file=skips.txt --xfails-file=xfails.txt --json-report --json-report-file=api-coverage-tests.json -nauto"
[feature.test.tasks.arrayapitests.env]
ARRAY_API_TESTS_MODULE = "ndonnx"
ARRAY_API_TESTS_VERSION = "2023.12"
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ indent-style = "space"
python_version = '3.10'
no_implicit_optional = true
check_untyped_defs = true
exclude = ["api-coverage-tests", "tests"]
exclude = ["array-api-tests", "tests"]

[[tool.mypy.overrides]]
module = ["onnxruntime"]
ignore_missing_imports = true

[tool.pytest.ini_options]
addopts = "--ignore=api-coverage-tests"
addopts = "--ignore=array-api-tests"
filterwarnings = ["ignore:.*google.protobuf.pyext.*:DeprecationWarning"]

[tool.typos.default]
Expand Down
12 changes: 0 additions & 12 deletions xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_meshgrid
array_api_tests/test_data_type_functions.py::test_can_cast
array_api_tests/test_data_type_functions.py::test_isdtype
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
array_api_tests/test_has_names.py::test_has_names[array_method-__complex__]
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
array_api_tests/test_has_names.py::test_has_names[creation-meshgrid]
array_api_tests/test_has_names.py::test_has_names[elementwise-conj]
Expand All @@ -36,7 +34,6 @@ array_api_tests/test_has_names.py::test_has_names[fft-irfftn]
array_api_tests/test_has_names.py::test_has_names[fft-rfft]
array_api_tests/test_has_names.py::test_has_names[fft-rfftfreq]
array_api_tests/test_has_names.py::test_has_names[fft-rfftn]
array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__]
array_api_tests/test_has_names.py::test_has_names[linalg-cholesky]
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
array_api_tests/test_has_names.py::test_has_names[linalg-det]
Expand Down Expand Up @@ -65,8 +62,6 @@ array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot]
array_api_tests/test_has_names.py::test_has_names[manipulation-moveaxis]
array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
array_api_tests/test_inspection_functions.py::test_array_namespace_info
array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes
array_api_tests/test_linalg.py::test_matrix_transpose
array_api_tests/test_linalg.py::test_vecdot
array_api_tests/test_manipulation_functions.py::test_moveaxis
Expand Down Expand Up @@ -101,8 +96,6 @@ array_api_tests/test_set_functions.py::test_unique_values
array_api_tests/test_signatures.py::test_array_method_signature[__complex__]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
array_api_tests/test_signatures.py::test_array_method_signature[to_device]
array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__]
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_func_signature[conj]
array_api_tests/test_signatures.py::test_func_signature[copysign]
Expand All @@ -118,11 +111,6 @@ array_api_tests/test_signatures.py::test_func_signature[signbit]
array_api_tests/test_signatures.py::test_func_signature[tile]
array_api_tests/test_signatures.py::test_func_signature[unstack]
array_api_tests/test_signatures.py::test_func_signature[vecdot]
array_api_tests/test_signatures.py::test_info_func_signature[capabilities]
array_api_tests/test_signatures.py::test_info_func_signature[default_device]
array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes]
array_api_tests/test_signatures.py::test_info_func_signature[devices]
array_api_tests/test_signatures.py::test_info_func_signature[dtypes]
array_api_tests/test_sorting_functions.py::test_argsort
array_api_tests/test_special_cases.py::test_binary[copysign(x1_i is NaN and x2_i < 0) -> NaN]
array_api_tests/test_special_cases.py::test_binary[copysign(x1_i is NaN and x2_i > 0) -> NaN]
Expand Down

0 comments on commit 4d70e5a

Please sign in to comment.