-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
467 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
use std::{ffi::c_void, sync::Arc}; | ||
|
||
use derive_more::Display; | ||
use dlpark::{ffi::Device, ShapeAndStrides, ToTensor}; | ||
use thiserror::Error; | ||
use zarrs_data_type::DataType; | ||
|
||
use super::{ChunkRepresentation, RawBytes}; | ||
|
||
mod array_dlpack_ext_async; | ||
mod array_dlpack_ext_sync; | ||
|
||
pub use array_dlpack_ext_async::AsyncArrayDlPackExt; | ||
pub use array_dlpack_ext_sync::ArrayDlPackExt; | ||
|
||
/// [`RawBytes`] for use in a [`dlpark::ManagerCtx`]. | ||
pub struct RawBytesDlPack { | ||
bytes: Arc<RawBytes<'static>>, | ||
dtype: dlpark::ffi::DataType, | ||
shape: Vec<i64>, | ||
} | ||
|
||
/// Errors related to [`[Async]ArrayDlPackExt`](ArrayDlPackExt) methods. | ||
#[derive(Clone, Debug, Error, Display)] | ||
#[non_exhaustive] | ||
pub enum ArrayDlPackExtError { | ||
/// The Zarr data type is not supported by `DLPack`. | ||
UnsupportedDataType, | ||
} | ||
|
||
impl RawBytesDlPack { | ||
/// Create a new [`RawBytesDlPack`]. | ||
/// | ||
/// # Errors | ||
/// Returns [`ArrayDlPackExtError::UnsupportedDataType`] if the data type is not supported. | ||
/// | ||
/// # Panics | ||
/// Panics if an element in the shape cannot be encoded in a `i64`. | ||
pub fn new( | ||
bytes: Arc<RawBytes<'static>>, | ||
representation: &ChunkRepresentation, | ||
) -> Result<Self, ArrayDlPackExtError> { | ||
let dtype = match representation.data_type() { | ||
DataType::Bool => dlpark::ffi::DataType::BOOL, | ||
DataType::Int8 => dlpark::ffi::DataType::I8, | ||
DataType::Int16 => dlpark::ffi::DataType::I16, | ||
DataType::Int32 => dlpark::ffi::DataType::I32, | ||
DataType::Int64 => dlpark::ffi::DataType::I64, | ||
DataType::UInt8 => dlpark::ffi::DataType::U8, | ||
DataType::UInt16 => dlpark::ffi::DataType::U16, | ||
DataType::UInt32 => dlpark::ffi::DataType::U32, | ||
DataType::UInt64 => dlpark::ffi::DataType::U64, | ||
DataType::Float16 => dlpark::ffi::DataType::F16, | ||
DataType::Float32 => dlpark::ffi::DataType::F32, | ||
DataType::Float64 => dlpark::ffi::DataType::F64, | ||
DataType::BFloat16 => dlpark::ffi::DataType::BF16, | ||
// TODO: Support extension data types with fallback? | ||
_ => Err(ArrayDlPackExtError::UnsupportedDataType)?, | ||
}; | ||
let shape = representation | ||
.shape() | ||
.iter() | ||
.map(|s| i64::try_from(s.get()).unwrap()) | ||
.collect(); | ||
Ok(Self { | ||
bytes, | ||
dtype, | ||
shape, | ||
}) | ||
} | ||
} | ||
|
||
impl ToTensor for RawBytesDlPack { | ||
fn data_ptr(&self) -> *mut c_void { | ||
self.bytes.as_ptr().cast::<c_void>().cast_mut() | ||
} | ||
|
||
fn byte_offset(&self) -> u64 { | ||
0 | ||
} | ||
|
||
fn device(&self) -> Device { | ||
Device::CPU | ||
} | ||
|
||
fn dtype(&self) -> dlpark::ffi::DataType { | ||
self.dtype | ||
} | ||
|
||
fn shape_and_strides(&self) -> ShapeAndStrides { | ||
ShapeAndStrides::new_contiguous(&self.shape) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use dlpark::{IntoDLPack, ManagedTensor}; | ||
use zarrs_data_type::{DataType, FillValue}; | ||
use zarrs_storage::store::MemoryStore; | ||
|
||
use crate::{ | ||
array::{codec::CodecOptions, ArrayBuilder, ArrayDlPackExt}, | ||
array_subset::ArraySubset, | ||
}; | ||
|
||
#[test] | ||
fn array_dlpack_ext_sync() { | ||
let store = MemoryStore::new(); | ||
let array = ArrayBuilder::new( | ||
vec![4, 4], | ||
DataType::Float32, | ||
vec![2, 2].try_into().unwrap(), | ||
FillValue::from(-1.0f32), | ||
) | ||
.build(store.into(), "/") | ||
.unwrap(); | ||
array | ||
.store_chunk_elements::<f32>(&[0, 0], &[0.0, 1.0, 2.0, 3.0]) | ||
.unwrap(); | ||
let tensor = array | ||
.retrieve_chunks_dlpack( | ||
&ArraySubset::new_with_shape(vec![1, 2]), | ||
&CodecOptions::default(), | ||
) | ||
.unwrap(); | ||
|
||
assert_eq!( | ||
ManagedTensor::new(tensor.into_dlpack()).as_slice::<f32>(), | ||
&[0.0, 1.0, -1.0, -1.0, 2.0, 3.0, -1.0, -1.0] | ||
); | ||
} | ||
} |
161 changes: 161 additions & 0 deletions
161
zarrs/src/array/array_dlpack_ext/array_dlpack_ext_async.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
use std::{num::NonZeroU64, sync::Arc}; | ||
|
||
use dlpark::ManagerCtx; | ||
use zarrs_storage::AsyncReadableStorageTraits; | ||
|
||
use crate::array::{codec::CodecOptions, Array, ArrayError, ChunkRepresentation}; | ||
use crate::array_subset::ArraySubset; | ||
|
||
use super::RawBytesDlPack; | ||
|
||
#[cfg(doc)] | ||
use super::ArrayDlPackExtError; | ||
|
||
/// An async [`Array`] extension trait with methods that return `DLPack` managed tensors. | ||
#[cfg_attr(feature = "async", async_trait::async_trait)] | ||
pub trait AsyncArrayDlPackExt<TStorage: ?Sized + AsyncReadableStorageTraits + 'static>: | ||
private::Sealed | ||
{ | ||
/// Read and decode the `array_subset` of array into a `DLPack` tensor. | ||
/// | ||
/// See [`Array::retrieve_array_subset_opt`]. | ||
/// | ||
/// # Errors | ||
/// Returns a [`super::ArrayDlPackExtError`] if the chunk cannot be represented as a `DLPack` tensor. | ||
/// Otherwise returns standard [`Array::retrieve_array_subset_opt`] errors. | ||
async fn retrieve_array_subset_dlpack( | ||
&self, | ||
array_subset: &ArraySubset, | ||
options: &CodecOptions, | ||
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError>; | ||
|
||
/// Read and decode the chunk at `chunk_indices` into a `DLPack` tensor if it exists. | ||
/// | ||
/// See [`Array::retrieve_chunk_if_exists_opt`]. | ||
/// | ||
/// # Errors | ||
/// Returns a [`ArrayDlPackExtError`] if the chunk cannot be represented as a `DLPack` tensor. | ||
/// Otherwise returns standard [`Array::retrieve_chunk_if_exists_opt`] errors. | ||
async fn retrieve_chunk_if_exists_dlpack( | ||
&self, | ||
chunk_indices: &[u64], | ||
options: &CodecOptions, | ||
) -> Result<Option<ManagerCtx<RawBytesDlPack>>, ArrayError>; | ||
|
||
/// Read and decode the chunk at `chunk_indices` into a `DLPack` tensor. | ||
/// | ||
/// See [`Array::retrieve_chunk_opt`]. | ||
/// | ||
/// # Errors | ||
/// Returns a [`ArrayDlPackExtError`] if the chunk cannot be represented as a `DLPack` tensor. | ||
/// Otherwise returns standard [`Array::retrieve_chunk_opt`] errors. | ||
async fn retrieve_chunk_dlpack( | ||
&self, | ||
chunk_indices: &[u64], | ||
options: &CodecOptions, | ||
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError>; | ||
|
||
/// Read and decode the chunks at `chunks` into a `DLPack` tensor. | ||
/// | ||
/// See [`Array::retrieve_chunks_opt`]. | ||
/// | ||
/// # Errors | ||
/// Returns a [`ArrayDlPackExtError`] if the chunk cannot be represented as a `DLPack` tensor. | ||
/// Otherwise returns standard [`Array::retrieve_chunks_opt`] errors. | ||
async fn retrieve_chunks_dlpack( | ||
&self, | ||
chunks: &ArraySubset, | ||
options: &CodecOptions, | ||
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError>; | ||
} | ||
|
||
#[cfg_attr(feature = "async", async_trait::async_trait)] | ||
impl<TStorage: ?Sized + AsyncReadableStorageTraits + 'static> AsyncArrayDlPackExt<TStorage> | ||
for Array<TStorage> | ||
{ | ||
async fn retrieve_array_subset_dlpack( | ||
&self, | ||
array_subset: &ArraySubset, | ||
options: &CodecOptions, | ||
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError> { | ||
let bytes = self | ||
.async_retrieve_array_subset_opt(array_subset, options) | ||
.await? | ||
.into_owned(); | ||
let bytes = Arc::new(bytes.into_fixed()?); | ||
|
||
let representation = unsafe { | ||
// SAFETY: the data type and fill value are confirmed compatible | ||
ChunkRepresentation::new_unchecked( | ||
array_subset | ||
.shape() | ||
.iter() | ||
.map(|s| NonZeroU64::new(*s)) | ||
.collect::<Option<Vec<_>>>() | ||
.ok_or(ArrayError::InvalidArraySubset( | ||
array_subset.clone(), | ||
self.shape().to_vec(), | ||
))?, | ||
self.data_type().clone(), | ||
self.fill_value().clone(), | ||
) | ||
}; | ||
|
||
Ok(ManagerCtx::new( | ||
RawBytesDlPack::new(bytes, &representation).map_err(ArrayError::DlPackError)?, | ||
)) | ||
} | ||
|
||
async fn retrieve_chunk_if_exists_dlpack( | ||
&self, | ||
chunk_indices: &[u64], | ||
options: &CodecOptions, | ||
) -> Result<Option<ManagerCtx<RawBytesDlPack>>, ArrayError> { | ||
let Some(bytes) = self | ||
.async_retrieve_chunk_if_exists_opt(chunk_indices, options) | ||
.await? | ||
else { | ||
return Ok(None); | ||
}; | ||
let bytes = bytes.into_owned(); | ||
let bytes = Arc::new(bytes.into_fixed()?); | ||
let representation = self.chunk_array_representation(chunk_indices)?; | ||
Ok(Some(ManagerCtx::new( | ||
RawBytesDlPack::new(bytes, &representation).map_err(ArrayError::DlPackError)?, | ||
))) | ||
} | ||
|
||
async fn retrieve_chunk_dlpack( | ||
&self, | ||
chunk_indices: &[u64], | ||
options: &CodecOptions, | ||
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError> { | ||
let bytes = self | ||
.async_retrieve_chunk_opt(chunk_indices, options) | ||
.await? | ||
.into_owned(); | ||
let bytes = Arc::new(bytes.into_fixed()?); | ||
let representation = self.chunk_array_representation(chunk_indices)?; | ||
Ok(ManagerCtx::new( | ||
RawBytesDlPack::new(bytes, &representation).map_err(ArrayError::DlPackError)?, | ||
)) | ||
} | ||
|
||
async fn retrieve_chunks_dlpack( | ||
&self, | ||
chunks: &ArraySubset, | ||
options: &CodecOptions, | ||
) -> Result<ManagerCtx<RawBytesDlPack>, ArrayError> { | ||
let array_subset = self.chunks_subset(chunks)?; | ||
self.retrieve_array_subset_dlpack(&array_subset, options) | ||
.await | ||
} | ||
} | ||
|
||
mod private { | ||
use super::{Array, AsyncReadableStorageTraits}; | ||
|
||
pub trait Sealed {} | ||
|
||
impl<TStorage: ?Sized + AsyncReadableStorageTraits + 'static> Sealed for Array<TStorage> {} | ||
} |
Oops, something went wrong.