Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zstd Codec on the GPU #2863

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ remote = [
]
gpu = [
"cupy-cuda12x",
"nvidia-nvcomp-cu12",
]
# Development extras
test = [
Expand Down
2 changes: 2 additions & 0 deletions src/zarr/codecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle
from zarr.codecs.bytes import BytesCodec, Endian
from zarr.codecs.crc32c_ import Crc32cCodec
from zarr.codecs.gpu import NvcompZstdCodec
from zarr.codecs.gzip import GzipCodec
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation
from zarr.codecs.transpose import TransposeCodec
Expand All @@ -17,6 +18,7 @@
"Crc32cCodec",
"Endian",
"GzipCodec",
"NvcompZstdCodec",
"ShardingCodec",
"ShardingCodecIndexLocation",
"TransposeCodec",
Expand Down
195 changes: 195 additions & 0 deletions src/zarr/codecs/gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from __future__ import annotations

import asyncio
from collections.abc import Awaitable
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING

import numpy as np

from zarr.abc.codec import BytesBytesCodec
from zarr.core.common import JSON, parse_named_configuration
from zarr.registry import register_codec

if TYPE_CHECKING:
from collections.abc import Generator, Iterable
from typing import Any, Self

from zarr.core.array_spec import ArraySpec
from zarr.core.buffer import Buffer

try:
import cupy as cp
except ImportError:
cp = None

try:
from nvidia import nvcomp
except ImportError:
nvcomp = None


class AsyncCUDAEvent(Awaitable[None]):
"""An awaitable wrapper around a CuPy CUDA event for asynchronous waiting."""

def __init__(
self, event: cp.cuda.Event, initial_delay: float = 0.001, max_delay: float = 0.1
) -> None:
"""
Initialize the async CUDA event.

Args:
event (cp.cuda.Event): The CuPy CUDA event to wait on.
initial_delay (float): Initial polling delay in seconds (default: 0.001s).
max_delay (float): Maximum polling delay in seconds (default: 0.1s).
"""
self.event = event
self.initial_delay = initial_delay
self.max_delay = max_delay

def __await__(self) -> Generator[Any, None, None]:
"""Makes the event awaitable by yielding control until the event is complete."""
return self._wait().__await__()

async def _wait(self) -> None:
"""Polls the CUDA event asynchronously with exponential backoff until it completes."""
delay = self.initial_delay
while not self.event.query(): # `query()` returns True if the event is complete
await asyncio.sleep(delay) # Yield control to other async tasks
delay = min(delay * 2, self.max_delay) # Exponential backoff


def parse_zstd_level(data: JSON) -> int:
if isinstance(data, int):
if data >= 23:
raise ValueError(f"Value must be less than or equal to 22. Got {data} instead.")
return data
raise TypeError(f"Got value with type {type(data)}, but expected an int.")


def parse_checksum(data: JSON) -> bool:
if isinstance(data, bool):
return data
raise TypeError(f"Expected bool. Got {type(data)}.")


@dataclass(frozen=True)
class NvcompZstdCodec(BytesBytesCodec):
is_fixed_size = True

level: int = 0
checksum: bool = False

def __init__(self, *, level: int = 0, checksum: bool = False) -> None:
# TODO: Set CUDA device appropriately here and also set CUDA stream

level_parsed = parse_zstd_level(level)
checksum_parsed = parse_checksum(checksum)

object.__setattr__(self, "level", level_parsed)
object.__setattr__(self, "checksum", checksum_parsed)

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
_, configuration_parsed = parse_named_configuration(data, "zstd")
return cls(**configuration_parsed) # type: ignore[arg-type]

def to_dict(self) -> dict[str, JSON]:
return {
"name": "zstd",
"configuration": {"level": self.level, "checksum": self.checksum},
}

@cached_property
def _zstd_codec(self) -> nvcomp.Codec:
# config_dict = {algorithm = "Zstd", "level": self.level, "checksum": self.checksum}
# return Zstd.from_config(config_dict)
device = cp.cuda.Device() # Select the current default device
stream = cp.cuda.get_current_stream() # Use the current default stream
return nvcomp.Codec(
algorithm="Zstd",
bitstream_kind=nvcomp.BitstreamKind.RAW,
device_id=device.id,
cuda_stream=stream.ptr,
)

async def _convert_to_nvcomp_arrays(
self,
chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
) -> tuple[list[nvcomp.Array], list[int]]:
none_indices = [i for i, (b, _) in enumerate(chunks_and_specs) if b is None]
filtered_inputs = [b.as_array_like() for b, _ in chunks_and_specs if b is not None]
# TODO: add CUDA stream here
return nvcomp.as_arrays(filtered_inputs), none_indices

async def _convert_from_nvcomp_arrays(
self,
arrays: Iterable[nvcomp.Array],
chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
) -> Iterable[Buffer | None]:
return [
spec.prototype.buffer.from_array_like(cp.asarray(a, dtype=np.dtype("b"))) if a else None
for a, (_, spec) in zip(arrays, chunks_and_specs, strict=True)
]

async def decode(
self,
chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
) -> Iterable[Buffer | None]:
"""Decodes a batch of chunks.
Chunks can be None in which case they are ignored by the codec.

Parameters
----------
chunks_and_specs : Iterable[tuple[Buffer | None, ArraySpec]]
Ordered set of encoded chunks with their accompanying chunk spec.

Returns
-------
Iterable[Buffer | None]
"""
chunks_and_specs = list(chunks_and_specs)

# Convert to nvcomp arrays
filtered_inputs, none_indices = await self._convert_to_nvcomp_arrays(chunks_and_specs)

outputs = self._zstd_codec.decode(filtered_inputs) if len(filtered_inputs) > 0 else []
for index in none_indices:
outputs.insert(index, None)

return await self._convert_from_nvcomp_arrays(outputs, chunks_and_specs)

async def encode(
self,
chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
) -> Iterable[Buffer | None]:
"""Encodes a batch of chunks.
Chunks can be None in which case they are ignored by the codec.

Parameters
----------
chunks_and_specs : Iterable[tuple[Buffer | None, ArraySpec]]
Ordered set of to-be-encoded chunks with their accompanying chunk spec.

Returns
-------
Iterable[Buffer | None]
"""
# TODO: Make this actually async
chunks_and_specs = list(chunks_and_specs)

# Convert to nvcomp arrays
filtered_inputs, none_indices = await self._convert_to_nvcomp_arrays(chunks_and_specs)

outputs = self._zstd_codec.encode(filtered_inputs) if len(filtered_inputs) > 0 else []
for index in none_indices:
outputs.insert(index, None)

return await self._convert_from_nvcomp_arrays(outputs, chunks_and_specs)

def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:
raise NotImplementedError


register_codec("zstd", NvcompZstdCodec)
21 changes: 17 additions & 4 deletions src/zarr/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ def enable_gpu(self) -> ConfigSet:
Configure Zarr to use GPUs where possible.
"""
return self.set(
{"buffer": "zarr.core.buffer.gpu.Buffer", "ndbuffer": "zarr.core.buffer.gpu.NDBuffer"}
{
"buffer": "zarr.core.buffer.gpu.Buffer",
"ndbuffer": "zarr.core.buffer.gpu.NDBuffer",
"codecs": {"zstd": "zarr.codecs.gpu.NvcompZstdCodec"},
}
)


Expand Down Expand Up @@ -96,13 +100,22 @@ def enable_gpu(self) -> ConfigSet:
},
"v3_default_compressors": {
"numeric": [
{"name": "zstd", "configuration": {"level": 0, "checksum": False}},
{
"name": "zstd",
"configuration": {"level": 0, "checksum": False},
},
],
"string": [
{"name": "zstd", "configuration": {"level": 0, "checksum": False}},
{
"name": "zstd",
"configuration": {"level": 0, "checksum": False},
},
],
"bytes": [
{"name": "zstd", "configuration": {"level": 0, "checksum": False}},
{
"name": "zstd",
"configuration": {"level": 0, "checksum": False},
},
],
},
},
Expand Down
7 changes: 5 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import zarr.api.asynchronous
import zarr.core.group
from zarr import Array, Group
from zarr.abc.codec import Codec
from zarr.abc.store import Store
from zarr.api.synchronous import (
create,
Expand All @@ -23,6 +24,7 @@
save_array,
save_group,
)
from zarr.codecs import NvcompZstdCodec
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
from zarr.errors import MetadataValidationError
from zarr.storage import MemoryStore
Expand Down Expand Up @@ -1131,15 +1133,16 @@ def test_open_array_with_mode_r_plus(store: Store) -> None:
indirect=True,
)
@pytest.mark.parametrize("zarr_format", [None, 2, 3])
def test_gpu_basic(store: Store, zarr_format: ZarrFormat | None) -> None:
@pytest.mark.parametrize("codec", ["auto", NvcompZstdCodec()])
def test_gpu_basic(store: Store, zarr_format: ZarrFormat | None, codec: str | Codec) -> None:
import cupy as cp

if zarr_format == 2:
# Without this, the zstd codec attempts to convert the cupy
# array to bytes.
compressors = None
else:
compressors = "auto"
compressors = codec

with zarr.config.enable_gpu():
src = cp.random.uniform(size=(100, 100)) # allocate on the device
Expand Down