Feat/add support for numpy datatypes in tokensloader #401

src/litdata/streaming/item_loader.py (18 additions, 5 deletions)

@@ -23,9 +23,7 @@
 import numpy as np
 import torch
 
-from litdata.constants import (
-    _TORCH_DTYPES_MAPPING,
-)
+from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
 from litdata.streaming.serializers import Serializer
 from litdata.utilities._pytree import PyTree, tree_unflatten
 from litdata.utilities.encryption import Encryption, EncryptionLevel
@@ -281,7 +279,17 @@ def setup(
         region_of_interest: Optional[List[Tuple[int, int]]] = None,
     ) -> None:
         super().setup(config, chunks, serializers, region_of_interest)
-        self._dtype = _TORCH_DTYPES_MAPPING[int(config["data_format"][0].split(":")[1])]
+
+        serializer_name, dtype_index = self._data_format[0].split(":")
+        if serializer_name not in ["no_header_numpy", "no_header_tensor"]:
+            raise ValueError("The provided data format isn't supported.")
+
+        self._serializer_name = serializer_name
+        self._dtype = (
+            _TORCH_DTYPES_MAPPING[int(dtype_index)]  # type: ignore
+            if serializer_name == "no_header_tensor"
+            else _NUMPY_DTYPES_MAPPING[int(dtype_index)]
+        )
         if all(chunk["dim"] is None for chunk in self._chunks):
             raise ValueError("The provided chunks isn't properly setup.")
 
@@ -350,7 +358,12 @@ def load_item_from_chunk(
 
         buffer: bytes = self._buffers[chunk_index]
         offset = self._dtype.itemsize * (index - begin) * self._block_size
-        return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
+
+        if self._serializer_name == "no_header_tensor":
+            data = torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
+        else:
+            data = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)  # type: ignore
+        return data
 
     def delete(self, chunk_index: int, chunk_filepath: str) -> None:
         if os.path.exists(chunk_filepath):
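The new branch only changes which frombuffer variant is called; the addressing math is shared. Because no_header chunks store fixed-size token blocks back to back with no per-item metadata, an item is located by pure byte arithmetic. A minimal standalone sketch of that computation, using a synthetic buffer and int32 to match the dtype index used in the test below:

# Sketch of the block-offset arithmetic in load_item_from_chunk, assuming a
# chunk buffer that is just fixed-size blocks concatenated with no headers.
import numpy as np

block_size = 4
dtype = np.dtype("int32")  # what "no_header_numpy:3" resolves to in the test

# Fake chunk: three consecutive blocks of int32 tokens.
buffer = np.arange(3 * block_size, dtype=dtype).tobytes()

begin, index = 0, 1  # read the second block stored in this chunk
offset = dtype.itemsize * (index - begin) * block_size
block = np.frombuffer(buffer, dtype=dtype, count=block_size, offset=offset)
print(block)  # [4 5 6 7]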
tests/streaming/test_item_loader.py (30 additions, 2 deletions)

@@ -1,10 +1,11 @@
 from unittest.mock import MagicMock
 
+import numpy as np
 import torch
-from litdata.constants import _TORCH_DTYPES_MAPPING
+from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
 from litdata.streaming import Cache
 from litdata.streaming.dataset import StreamingDataset
-from litdata.streaming.item_loader import PyTreeLoader
+from litdata.streaming.item_loader import PyTreeLoader, TokensLoader
 
 
 def test_serializer_setup():
@@ -38,3 +39,30 @@ def test_pytreeloader_with_no_header_tensor_serializer(tmpdir):
         item = dataset[i]
         assert torch.allclose(i * torch.ones(10).to(_TORCH_DTYPES_MAPPING[dtype_index_float]), item["float"])
         assert torch.allclose(i * torch.ones(10).to(_TORCH_DTYPES_MAPPING[dtype_index_long]), item["long"])
+
+
+def test_tokensloader_with_no_header_numpy_serializer(tmpdir):
+    cache = Cache(str(tmpdir), chunk_size=512, item_loader=TokensLoader())
+    assert isinstance(cache._reader._item_loader, TokensLoader)
+
+    dtype_index_int32 = 3
+    dtype = _NUMPY_DTYPES_MAPPING[dtype_index_int32]
+
+    for i in range(10):
+        data = np.random.randint(0, 100, size=(256), dtype=dtype)
+        cache._add_item(i, data)
+
+    data_format = [f"no_header_numpy:{dtype_index_int32}"]
+    assert cache._writer.get_config()["data_format"] == data_format
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(
+        input_dir=str(tmpdir),
+        drop_last=True,
+        item_loader=TokensLoader(block_size=256),
+    )
+
+    for data in dataset:
+        assert data.shape == (256,)
+        assert data.dtype == dtype
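The test reads items one at a time. For completeness, here is a hypothetical consumer-side sketch of the same data through a batched loader; the input path, batch size, and reliance on torch's default collation (which converts numpy arrays to tensors and stacks them) are illustrative assumptions, not part of this diff:

# Hypothetical usage sketch: iterate numpy token blocks produced as in the
# test above, batched through litdata's streaming dataloader.
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.item_loader import TokensLoader

dataset = StreamingDataset(
    input_dir="path/to/token/chunks",  # assumed chunk location
    drop_last=True,
    item_loader=TokensLoader(block_size=256),
)
dataloader = StreamingDataLoader(dataset, batch_size=8)

for batch in dataloader:
    # each item is a 256-token block; default collation stacks them
    assert batch.shape == (8, 256)
    break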