Skip to content

Commit

Permalink
feat: add DLPack support (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
LDeakin authored Mar 3, 2025
1 parent e364702 commit bf041e7
Show file tree
Hide file tree
Showing 8 changed files with 467 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `ArrayBytesFixedDisjointView[CreateError]`
- Add support for data type extensions with `zarrs_data_type` 0.2.0
- Add `custom_data_type_fixed_size` and `custom_data_type_variable_size` examples
- Add `[Async]ArrayDlPackExt` traits that add methods to `Array` for `DLPack` tensor interop
- Gated by the `dlpack` feature

### Changed
- **Breaking**: change `ArraySubset::inbounds` to take another subset rather than a shape
Expand Down Expand Up @@ -40,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **Breaking**: `StorageTransformerPlugin` now uses a `Plugin`
- Add `DataTypeExtension` variant to `CodecError`
- `ArrayCreateError::DataTypeCreateError` now uses a `PluginCreateError` internally
- **Breaking**: `ArrayError` is now marked as non-exhaustive

### Fixed
- Fixed reserving one more element than necessary when retrieving `string` or `bytes` array elements
Expand Down
2 changes: 2 additions & 0 deletions zarrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ transpose = ["dep:ndarray"] # Enable the transpose codec
zfp = ["dep:zfp-sys"] # Enable the experimental zfp codec
zstd = ["dep:zstd"] # Enable the zstd codec
ndarray = ["dep:ndarray"] # Adds ndarray utility functions to Array
dlpack =["dep:dlpark"] # Adds dlpack utility functions to Array
async = ["dep:async-trait", "dep:futures", "zarrs_storage/async"] # Enable experimental async API

[lints]
Expand Down Expand Up @@ -74,6 +75,7 @@ zarrs_plugin = { workspace = true }
zarrs_storage = { workspace = true }
zfp-sys = {version = "0.3.0", features = ["static"], optional = true }
zstd = { version = "0.13.1", features = ["zstdmt"], optional = true }
dlpark = { version = "0.4.1", features = ["half"], optional = true }

[dependencies.num-complex]
version = "0.4.3"
Expand Down
12 changes: 10 additions & 2 deletions zarrs/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ mod element;
pub mod storage_transformer;
pub use crate::data_type; // re-export for zarrs < 0.20 compat

#[cfg(feature = "dlpack")]
mod array_dlpack_ext;
#[cfg(feature = "sharding")]
mod array_sharded_ext;
#[cfg(feature = "sharding")]
Expand Down Expand Up @@ -88,10 +90,15 @@ pub use chunk_cache::{
chunk_cache_lru::*, ChunkCache, ChunkCacheType, ChunkCacheTypeDecoded, ChunkCacheTypeEncoded,
};

#[cfg(feature = "dlpack")]
pub use array_dlpack_ext::{ArrayDlPackExt, ArrayDlPackExtError, RawBytesDlPack};
// The async DLPack extension trait is built on the async storage API, so it is only
// re-exported when the `async` feature is enabled alongside `dlpack`.
#[cfg(all(feature = "dlpack", feature = "async"))]
pub use array_dlpack_ext::AsyncArrayDlPackExt;
#[cfg(feature = "sharding")]
pub use array_sharded_ext::ArrayShardedExt;
#[cfg(feature = "sharding")]
pub use array_sync_sharded_readable_ext::{ArrayShardedReadableExt, ArrayShardedReadableExtCache};

use zarrs_metadata::v3::UnsupportedAdditionalFieldError;
// TODO: Add AsyncArrayShardedReadableExt and AsyncArrayShardedReadableExtCache

Expand Down Expand Up @@ -188,8 +195,9 @@ pub fn chunk_shape_to_array_shape(chunk_shape: &[std::num::NonZeroU64]) -> Array
/// - **Experimental**: `async_` prefix variants can be used with async stores (requires `async` feature).
///
/// Additional methods are offered by extension traits:
/// - [`ArrayShardedExt`] and [`ArrayShardedReadableExt`]: see [Reading Sharded Arrays](#reading-sharded-arrays)
/// - [`ArrayChunkCacheExt`]: see [Chunk Caching](#chunk-caching)
/// - [`ArrayShardedExt`] and [`ArrayShardedReadableExt`]: see [Reading Sharded Arrays](#reading-sharded-arrays).
/// - [`ArrayChunkCacheExt`]: see [Chunk Caching](#chunk-caching).
/// - [`[Async]ArrayDlPackExt`](ArrayDlPackExt): methods for [`DLPack`](https://arrow.apache.org/docs/python/dlpack.html) tensor interop.
///
/// ### Chunks and Array Subsets
/// Several convenience methods are available for querying the underlying chunk grid:
Expand Down
132 changes: 132 additions & 0 deletions zarrs/src/array/array_dlpack_ext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
use std::{ffi::c_void, sync::Arc};

use derive_more::Display;
use dlpark::{ffi::Device, ShapeAndStrides, ToTensor};
use thiserror::Error;
use zarrs_data_type::DataType;

use super::{ChunkRepresentation, RawBytes};

mod array_dlpack_ext_async;
mod array_dlpack_ext_sync;

pub use array_dlpack_ext_async::AsyncArrayDlPackExt;
pub use array_dlpack_ext_sync::ArrayDlPackExt;

/// [`RawBytes`] for use in a [`dlpark::ManagerCtx`].
///
/// Pairs a shared byte buffer with the `DLPack` data type and shape needed to describe it
/// as a tensor. The [`ToTensor`] implementation exposes it as a contiguous CPU tensor with
/// a zero byte offset.
pub struct RawBytesDlPack {
/// The shared, owned element bytes that the tensor data pointer refers to.
bytes: Arc<RawBytes<'static>>,
/// The `DLPack` data type of the elements.
dtype: dlpark::ffi::DataType,
/// The tensor shape, one `i64` extent per dimension.
shape: Vec<i64>,
}

/// Errors related to [`[Async]ArrayDlPackExt`](ArrayDlPackExt) methods.
///
/// Also returned directly by [`RawBytesDlPack::new`].
#[derive(Clone, Debug, Error, Display)]
#[non_exhaustive]
pub enum ArrayDlPackExtError {
/// The Zarr data type is not supported by `DLPack`.
///
/// [`RawBytesDlPack::new`] only accepts the `Bool`, `Int8`/`Int16`/`Int32`/`Int64`,
/// `UInt8`/`UInt16`/`UInt32`/`UInt64`, `Float16`/`Float32`/`Float64` and `BFloat16`
/// data types; any other data type results in this error.
UnsupportedDataType,
}

impl RawBytesDlPack {
    /// Create a new [`RawBytesDlPack`].
    ///
    /// The `DLPack` data type and the tensor shape are derived from `representation`, and
    /// `bytes` must hold exactly the elements described by it.
    ///
    /// # Errors
    /// Returns [`ArrayDlPackExtError::UnsupportedDataType`] if the data type is not supported.
    ///
    /// # Panics
    /// Panics if an element in the shape cannot be encoded in a `i64`, or if the length of
    /// `bytes` is not the number of elements in the shape multiplied by the element size.
    pub fn new(
        bytes: Arc<RawBytes<'static>>,
        representation: &ChunkRepresentation,
    ) -> Result<Self, ArrayDlPackExtError> {
        // Map the Zarr data type to its `DLPack` equivalent and its element size in bytes.
        let (dtype, element_size): (_, usize) = match representation.data_type() {
            DataType::Bool => (dlpark::ffi::DataType::BOOL, 1),
            DataType::Int8 => (dlpark::ffi::DataType::I8, 1),
            DataType::Int16 => (dlpark::ffi::DataType::I16, 2),
            DataType::Int32 => (dlpark::ffi::DataType::I32, 4),
            DataType::Int64 => (dlpark::ffi::DataType::I64, 8),
            DataType::UInt8 => (dlpark::ffi::DataType::U8, 1),
            DataType::UInt16 => (dlpark::ffi::DataType::U16, 2),
            DataType::UInt32 => (dlpark::ffi::DataType::U32, 4),
            DataType::UInt64 => (dlpark::ffi::DataType::U64, 8),
            DataType::Float16 => (dlpark::ffi::DataType::F16, 2),
            DataType::Float32 => (dlpark::ffi::DataType::F32, 4),
            DataType::Float64 => (dlpark::ffi::DataType::F64, 8),
            DataType::BFloat16 => (dlpark::ffi::DataType::BF16, 2),
            // TODO: Support extension data types with fallback?
            _ => return Err(ArrayDlPackExtError::UnsupportedDataType),
        };

        let shape: Vec<i64> = representation
            .shape()
            .iter()
            .map(|s| i64::try_from(s.get()).expect("shape element must be encodable in an i64"))
            .collect();

        // The tensor hands a raw pointer and this shape to `DLPack` consumers, so the buffer
        // must cover exactly the elements the shape describes. An overflowing product yields
        // `None` and therefore also fails the check.
        let expected_len = shape
            .iter()
            .try_fold(1usize, |count, &extent| {
                usize::try_from(extent)
                    .ok()
                    .and_then(|extent| count.checked_mul(extent))
            })
            .and_then(|count| count.checked_mul(element_size));
        assert_eq!(
            expected_len,
            Some(bytes.len()),
            "the length of the bytes must match the shape and data type of the representation"
        );

        Ok(Self {
            bytes,
            dtype,
            shape,
        })
    }
}

impl ToTensor for RawBytesDlPack {
    /// The `DLPack` data type chosen when this value was created.
    fn dtype(&self) -> dlpark::ffi::DataType {
        let Self { dtype, .. } = self;
        *dtype
    }

    /// The stored shape, described as a contiguous layout.
    fn shape_and_strides(&self) -> ShapeAndStrides {
        let Self { shape, .. } = self;
        ShapeAndStrides::new_contiguous(shape)
    }

    /// The tensor is reported as residing on the CPU.
    fn device(&self) -> Device {
        Device::CPU
    }

    /// The tensor data starts at the data pointer itself.
    fn byte_offset(&self) -> u64 {
        0
    }

    /// A pointer to the first byte of the held buffer.
    ///
    /// The trait requires a `*mut` pointer, hence the const-to-mut cast; nothing in this
    /// implementation writes through it.
    fn data_ptr(&self) -> *mut c_void {
        let first_byte = self.bytes.as_ptr();
        first_byte.cast::<c_void>().cast_mut()
    }
}

#[cfg(test)]
mod tests {
    use dlpark::{IntoDLPack, ManagedTensor};
    use zarrs_data_type::{DataType, FillValue};
    use zarrs_storage::store::MemoryStore;

    use crate::{
        array::{codec::CodecOptions, ArrayBuilder, ArrayDlPackExt},
        array_subset::ArraySubset,
    };

    /// Retrieves the first row of chunks of a 4x4 `f32` array (2x2 chunks) as a `DLPack`
    /// tensor. Only chunk `[0, 0]` is stored, so the rest of the region is the fill value.
    #[test]
    fn array_dlpack_ext_sync() {
        let array = ArrayBuilder::new(
            vec![4, 4],
            DataType::Float32,
            vec![2, 2].try_into().unwrap(),
            FillValue::from(-1.0f32),
        )
        .build(MemoryStore::new().into(), "/")
        .unwrap();

        let stored_chunk = [0.0, 1.0, 2.0, 3.0];
        array
            .store_chunk_elements::<f32>(&[0, 0], &stored_chunk)
            .unwrap();

        // One chunk along the first dimension and two along the second.
        let chunks = ArraySubset::new_with_shape(vec![1, 2]);
        let options = CodecOptions::default();
        let tensor = array.retrieve_chunks_dlpack(&chunks, &options).unwrap();
        let managed = ManagedTensor::new(tensor.into_dlpack());

        let expected: [f32; 8] = [0.0, 1.0, -1.0, -1.0, 2.0, 3.0, -1.0, -1.0];
        assert_eq!(managed.as_slice::<f32>(), &expected);
    }
}
161 changes: 161 additions & 0 deletions zarrs/src/array/array_dlpack_ext/array_dlpack_ext_async.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use std::{num::NonZeroU64, sync::Arc};

use dlpark::ManagerCtx;
use zarrs_storage::AsyncReadableStorageTraits;

use crate::array::{codec::CodecOptions, Array, ArrayError, ChunkRepresentation};
use crate::array_subset::ArraySubset;

use super::RawBytesDlPack;

#[cfg(doc)]
use super::ArrayDlPackExtError;

/// An async [`Array`] extension trait with methods that return `DLPack` managed tensors.
///
/// The trait is sealed and is implemented for [`Array`] with async readable storage.
#[cfg_attr(feature = "async", async_trait::async_trait)]
pub trait AsyncArrayDlPackExt<TStorage: ?Sized + AsyncReadableStorageTraits + 'static>:
private::Sealed
{
/// Read and decode the `array_subset` of array into a `DLPack` tensor.
///
/// See [`Array::async_retrieve_array_subset_opt`].
///
/// # Errors
/// Returns [`ArrayError::DlPackError`] holding an [`ArrayDlPackExtError`] if the array subset cannot be represented as a `DLPack` tensor.
/// Returns [`ArrayError::InvalidArraySubset`] if any dimension of `array_subset` has a length of zero.
/// Otherwise returns standard [`Array::async_retrieve_array_subset_opt`] errors.
async fn retrieve_array_subset_dlpack(
&self,
array_subset: &ArraySubset,
options: &CodecOptions,
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError>;

/// Read and decode the chunk at `chunk_indices` into a `DLPack` tensor if it exists.
///
/// Returns `Ok(None)` if the chunk does not exist.
///
/// See [`Array::async_retrieve_chunk_if_exists_opt`].
///
/// # Errors
/// Returns [`ArrayError::DlPackError`] holding an [`ArrayDlPackExtError`] if the chunk cannot be represented as a `DLPack` tensor.
/// Otherwise returns standard [`Array::async_retrieve_chunk_if_exists_opt`] errors.
async fn retrieve_chunk_if_exists_dlpack(
&self,
chunk_indices: &[u64],
options: &CodecOptions,
) -> Result<Option<ManagerCtx<RawBytesDlPack>>, ArrayError>;

/// Read and decode the chunk at `chunk_indices` into a `DLPack` tensor.
///
/// See [`Array::async_retrieve_chunk_opt`].
///
/// # Errors
/// Returns [`ArrayError::DlPackError`] holding an [`ArrayDlPackExtError`] if the chunk cannot be represented as a `DLPack` tensor.
/// Otherwise returns standard [`Array::async_retrieve_chunk_opt`] errors.
async fn retrieve_chunk_dlpack(
&self,
chunk_indices: &[u64],
options: &CodecOptions,
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError>;

/// Read and decode the chunks at `chunks` into a `DLPack` tensor.
///
/// The array subset spanned by `chunks` is retrieved as a single tensor.
///
/// See [`Array::retrieve_chunks_opt`].
///
/// # Errors
/// Returns [`ArrayError::DlPackError`] holding an [`ArrayDlPackExtError`] if the chunks cannot be represented as a `DLPack` tensor.
/// Otherwise returns the errors of [`Array::chunks_subset`] and [`AsyncArrayDlPackExt::retrieve_array_subset_dlpack`].
async fn retrieve_chunks_dlpack(
&self,
chunks: &ArraySubset,
options: &CodecOptions,
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError>;
}

#[cfg_attr(feature = "async", async_trait::async_trait)]
impl<TStorage: ?Sized + AsyncReadableStorageTraits + 'static> AsyncArrayDlPackExt<TStorage>
for Array<TStorage>
{
async fn retrieve_array_subset_dlpack(
&self,
array_subset: &ArraySubset,
options: &CodecOptions,
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError> {
let bytes = self
.async_retrieve_array_subset_opt(array_subset, options)
.await?
.into_owned();
let bytes = Arc::new(bytes.into_fixed()?);

let representation = unsafe {
// SAFETY: the data type and fill value are confirmed compatible
ChunkRepresentation::new_unchecked(
array_subset
.shape()
.iter()
.map(|s| NonZeroU64::new(*s))
.collect::<Option<Vec<_>>>()
.ok_or(ArrayError::InvalidArraySubset(
array_subset.clone(),
self.shape().to_vec(),
))?,
self.data_type().clone(),
self.fill_value().clone(),
)
};

Ok(ManagerCtx::new(
RawBytesDlPack::new(bytes, &representation).map_err(ArrayError::DlPackError)?,
))
}

async fn retrieve_chunk_if_exists_dlpack(
&self,
chunk_indices: &[u64],
options: &CodecOptions,
) -> Result<Option<ManagerCtx<RawBytesDlPack>>, ArrayError> {
let Some(bytes) = self
.async_retrieve_chunk_if_exists_opt(chunk_indices, options)
.await?
else {
return Ok(None);
};
let bytes = bytes.into_owned();
let bytes = Arc::new(bytes.into_fixed()?);
let representation = self.chunk_array_representation(chunk_indices)?;
Ok(Some(ManagerCtx::new(
RawBytesDlPack::new(bytes, &representation).map_err(ArrayError::DlPackError)?,
)))
}

async fn retrieve_chunk_dlpack(
&self,
chunk_indices: &[u64],
options: &CodecOptions,
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError> {
let bytes = self
.async_retrieve_chunk_opt(chunk_indices, options)
.await?
.into_owned();
let bytes = Arc::new(bytes.into_fixed()?);
let representation = self.chunk_array_representation(chunk_indices)?;
Ok(ManagerCtx::new(
RawBytesDlPack::new(bytes, &representation).map_err(ArrayError::DlPackError)?,
))
}

async fn retrieve_chunks_dlpack(
&self,
chunks: &ArraySubset,
options: &CodecOptions,
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError> {
let array_subset = self.chunks_subset(chunks)?;
self.retrieve_array_subset_dlpack(&array_subset, options)
.await
}
}

/// Sealing support for [`AsyncArrayDlPackExt`].
mod private {
    use super::{Array, AsyncReadableStorageTraits};

    /// Supertrait of the extension trait. It lives in a private module, so the extension
    /// trait cannot be implemented for types other than those given an impl here.
    pub trait Sealed {}

    impl<TStorage> Sealed for Array<TStorage> where
        TStorage: ?Sized + AsyncReadableStorageTraits + 'static
    {
    }
}
Loading

0 comments on commit bf041e7

Please sign in to comment.