diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 43e981e9..9c58cabf 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -23,9 +23,7 @@ import numpy as np import torch -from litdata.constants import ( - _TORCH_DTYPES_MAPPING, -) +from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING from litdata.streaming.serializers import Serializer from litdata.utilities._pytree import PyTree, tree_unflatten from litdata.utilities.encryption import Encryption, EncryptionLevel @@ -281,7 +279,17 @@ def setup( region_of_interest: Optional[List[Tuple[int, int]]] = None, ) -> None: super().setup(config, chunks, serializers, region_of_interest) - self._dtype = _TORCH_DTYPES_MAPPING[int(config["data_format"][0].split(":")[1])] + + serializer_name, dtype_index = self._data_format[0].split(":") + if serializer_name not in ["no_header_numpy", "no_header_tensor"]: + raise ValueError("The provided data format isn't supported.") + + self._serializer_name = serializer_name + self._dtype = ( + _TORCH_DTYPES_MAPPING[int(dtype_index)] # type: ignore + if serializer_name == "no_header_tensor" + else _NUMPY_DTYPES_MAPPING[int(dtype_index)] + ) if all(chunk["dim"] is None for chunk in self._chunks): raise ValueError("The provided chunks isn't properly setup.") @@ -350,7 +358,12 @@ def load_item_from_chunk( buffer: bytes = self._buffers[chunk_index] offset = self._dtype.itemsize * (index - begin) * self._block_size - return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) + + if self._serializer_name == "no_header_tensor": + data = torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) + else: + data = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) # type: ignore + return data def delete(self, chunk_index: int, chunk_filepath: str) -> None: if os.path.exists(chunk_filepath): diff --git a/tests/streaming/test_item_loader.py b/tests/streaming/test_item_loader.py index ecb8e6f8..a5828b24 100644 --- a/tests/streaming/test_item_loader.py +++ b/tests/streaming/test_item_loader.py @@ -1,10 +1,11 @@ from unittest.mock import MagicMock +import numpy as np import torch -from litdata.constants import _TORCH_DTYPES_MAPPING +from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING from litdata.streaming import Cache from litdata.streaming.dataset import StreamingDataset -from litdata.streaming.item_loader import PyTreeLoader +from litdata.streaming.item_loader import PyTreeLoader, TokensLoader def test_serializer_setup(): @@ -38,3 +39,30 @@ def test_pytreeloader_with_no_header_tensor_serializer(tmpdir): item = dataset[i] assert torch.allclose(i * torch.ones(10).to(_TORCH_DTYPES_MAPPING[dtype_index_float]), item["float"]) assert torch.allclose(i * torch.ones(10).to(_TORCH_DTYPES_MAPPING[dtype_index_long]), item["long"]) + + +def test_tokensloader_with_no_header_numpy_serializer(tmpdir): + cache = Cache(str(tmpdir), chunk_size=512, item_loader=TokensLoader()) + assert isinstance(cache._reader._item_loader, TokensLoader) + + dtype_index_int32 = 3 + dtype = _NUMPY_DTYPES_MAPPING[dtype_index_int32] + + for i in range(10): + data = np.random.randint(0, 100, size=(256), dtype=dtype) + cache._add_item(i, data) + + data_format = [f"no_header_numpy:{dtype_index_int32}"] + assert cache._writer.get_config()["data_format"] == data_format + cache.done() + cache.merge() + + dataset = StreamingDataset( + input_dir=str(tmpdir), + drop_last=True, + item_loader=TokensLoader(block_size=256), + ) + + for data in dataset: + assert data.shape == (256,) + assert data.dtype == dtype