diff --git a/README.md b/README.md
index 0b6e6575..003eac48 100644
--- a/README.md
+++ b/README.md
@@ -174,7 +174,7 @@ ld.map(
**Key benefits:**
-✅ Paralellize processing: Reduce processing time by transforming data across multiple machines simultaneously.
+✅ Parallelize processing: Reduce processing time by transforming data across multiple machines simultaneously.
✅ Scale to large data: Increase the size of datasets you can efficiently handle.
✅ Flexible usecases: Resize images, create embeddings, scrape the internet, etc...
✅ Run local or cloud: Run on your own machines or auto-scale to 1000s of cloud GPUs with Lightning Studios.
@@ -638,7 +638,7 @@ Time to optimize 1.2 million ImageNet images (Faster is better):
----
-# Paralellize transforms and data optimization on cloud machines
+# Parallelize transforms and data optimization on cloud machines
diff --git a/requirements/test.txt b/requirements/test.txt
index d8240755..71de5307 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -1,4 +1,5 @@
coverage ==7.5.3
+mosaicml-streaming ==0.7.6
pytest ==8.2.*
pytest-cov ==5.0.0
pytest-timeout ==2.3.1
diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py
index d44662ae..d2220a00 100644
--- a/src/litdata/processing/data_processor.py
+++ b/src/litdata/processing/data_processor.py
@@ -51,6 +51,7 @@
from litdata.streaming.resolver import _resolve_dir
from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads
from litdata.utilities.broadcast import broadcast_object
+from litdata.utilities.dataset_utilities import load_index_file
from litdata.utilities.packing import _pack_greedily
if _TQDM_AVAILABLE:
@@ -788,8 +789,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul
self._upload_index(output_dir, cache_dir, num_nodes, node_rank)
if num_nodes == node_rank + 1:
- with open(os.path.join(cache_dir, _INDEX_FILENAME)) as f:
- config = json.load(f)
+ config = load_index_file(cache_dir)
size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py
index c5966552..b9744f0a 100644
--- a/src/litdata/streaming/config.py
+++ b/src/litdata/streaming/config.py
@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
from typing import Any, Dict, List, Optional, Tuple
@@ -22,6 +21,7 @@
from litdata.streaming.sampler import ChunkedIndex
from litdata.streaming.serializers import Serializer
from litdata.utilities._pytree import tree_unflatten, treespec_loads
+from litdata.utilities.dataset_utilities import load_index_file
class ChunksConfig:
@@ -53,18 +53,18 @@ def __init__(
self._remote_dir = remote_dir
self._item_loader = item_loader or PyTreeLoader()
- with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
- data = json.load(f)
- _original_chunks = data["chunks"]
- self._config = data["config"]
- self._validate_item_loader()
+ # load data from `index.json` file
+ data = load_index_file(self._cache_dir)
+ _original_chunks = data["chunks"]
+ self._config = data["config"]
+ self._validate_item_loader()
- assert _original_chunks is not None
+ assert _original_chunks is not None
- if subsampled_files is None:
- self._chunks = _original_chunks
- else:
- self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)
+ if subsampled_files is None:
+ self._chunks = _original_chunks
+ else:
+ self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)
self._config["data_spec"] = treespec_loads(self._config["data_spec"])
diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py
index d302214d..2213aaca 100644
--- a/src/litdata/streaming/item_loader.py
+++ b/src/litdata/streaming/item_loader.py
@@ -142,8 +142,32 @@ def load_item_from_chunk(
fp.seek(begin)
data = fp.read(end - begin)
+ # check for mosaic mds format
+ if "format" in self._config and self._config["format"] == "mds":
+ return self.mds_deserialize(data, chunk_index)
return self.deserialize(data)
+ def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree":
+ """Deserialize the mds raw bytes into their python equivalent."""
+ idx = 0
+ sizes = []
+ column_sizes = self._chunks[chunk_index]["column_sizes"]
+        # adapted from MDSReader.deserialize: https://github.com/mosaicml/streaming/blob/main/streaming/base/format/mds/reader.py
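+        # column_sizes (carried over from the MDS shard entry) gives a fixed byte size per
+        # column; a None entry marks a variable-length column whose uint32 size prefix is
+        # read from the start of the sample's raw bytes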
+ for size in column_sizes:
+ if size:
+ sizes.append(size)
+ else:
+ (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32)
+ sizes.append(size)
+ idx += 4
+ data = []
+ for size, data_format in zip(sizes, self._data_format):
+ serializer = self._serializers[data_format]
+ data_bytes = raw_item_data[idx : idx + size]
+ data.append(serializer.deserialize(data_bytes))
+ idx += size
+ return tree_unflatten(data, self._config["data_spec"])
+
def deserialize(self, raw_item_data: bytes) -> "PyTree":
"""Deserialize the raw bytes into their python equivalent."""
idx = self._shift_idx
diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py
index 65978e66..7e0c4838 100644
--- a/src/litdata/utilities/dataset_utilities.py
+++ b/src/litdata/utilities/dataset_utilities.py
@@ -51,9 +51,8 @@ def subsample_streaming_dataset(
if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
# load chunks from `index.json` file
- with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f:
- data = json.load(f)
- original_chunks = data["chunks"]
+ data = load_index_file(input_dir.path)
+ original_chunks = data["chunks"]
else:
raise ValueError(
f"The provided dataset `{input_dir.path}` doesn't contain any {_INDEX_FILENAME} file."
@@ -115,3 +114,76 @@ def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoa
roi.append((0, end))
return roi
+
+
+def load_index_file(input_dir: str) -> Dict[str, Any]:
+ """Load index file from the specified input directory.
+
+    Supports both litdata's chunk-based index files and MosaicML's MDS shard-based index files.
+    Shard-based index files are adapted in memory to the chunk-based format.
+
+ Args:
+ input_dir (str): The directory containing the index file.
+
+ Returns:
+ Dict[str, Any]: The loaded and possibly adapted index data.
+
+ Raises:
+ FileNotFoundError: If the index file does not exist in the input directory.
+
+ """
+ index_filepath = os.path.join(input_dir, _INDEX_FILENAME)
+ try:
+ with open(index_filepath) as f:
+ data = json.load(f)
+
+ if "chunks" not in data and "shards" in data:
+ # load mds shard-based index file and adapt to chunks format
+ return adapt_mds_shards_to_chunks(data)
+
+ return data
+ except FileNotFoundError:
+ raise FileNotFoundError(f"Index file not found at {index_filepath}.")
+
+
+def adapt_mds_shards_to_chunks(data: Dict[str, Any]) -> Dict[str, Any]:
+ """Adapt mds shard-based index data to chunk-based format for compatibility.
+ For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming
+
+ Args:
+ data (Dict[str, Any]): The original index data containing shards.
+
+ Returns:
+ Dict[str, Any]: Adapted index data with chunks format.
+ """
+ chunks = []
+ shards = data["shards"]
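+    # each MDS shard is mapped onto one litdata chunk; the compressed shard file
+    # (`zip_data`) is reused as the chunk file and `column_sizes` is kept so the
+    # item loader can split individual samples later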
+ for shard in shards:
+ chunks.append(
+ {
+ "chunk_bytes": shard["zip_data"]["bytes"],
+ "chunk_size": shard["samples"],
+ "column_sizes": shard["column_sizes"],
+ "dim": None,
+ "filename": shard["zip_data"]["basename"],
+ }
+ )
+ data["chunks"] = chunks
+
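+    # build a pytree `data_spec` describing a plain dict keyed by the shard's column
+    # names, so deserialized samples are returned as {column_name: value} dicts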
+ data_spec = [
+ 1,
+ {
+ "type": "builtins.dict",
+ "context": json.dumps(shards[0]["column_names"]),
+ "children_spec": [{"type": None, "context": None, "children_spec": []} for _ in shards[0]["column_names"]],
+ },
+ ]
+ data["config"] = {
+ "chunk_bytes": sum(shard["zip_data"]["bytes"] for shard in shards),
+ "chunk_size": sum(shard["samples"] for shard in shards),
+ "compression": shards[0]["compression"],
+ "data_format": shards[0]["column_encodings"],
+ "format": shards[0]["format"],
+ "data_spec": json.dumps(data_spec),
+ }
+ return data
diff --git a/src/litdata/utilities/train_test_split.py b/src/litdata/utilities/train_test_split.py
index 7f8fbef9..d31fb076 100644
--- a/src/litdata/utilities/train_test_split.py
+++ b/src/litdata/utilities/train_test_split.py
@@ -1,4 +1,3 @@
-import json
import logging
import os
from copy import deepcopy
@@ -8,6 +7,7 @@
from litdata import StreamingDataset
from litdata.constants import _INDEX_FILENAME
+from litdata.utilities.dataset_utilities import load_index_file
from litdata.utilities.subsample import shuffle_lists_together, subsample_filenames_and_roi
@@ -55,14 +55,12 @@ def train_test_split(
if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
# load chunks from `index.json` file
- with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f:
- data = json.load(f)
- original_chunks = data["chunks"]
- subsampled_chunks = [
- _org_chunk
- for _org_chunk in original_chunks
- if _org_chunk["filename"] in dummy_subsampled_chunk_filename
- ]
+ data = load_index_file(input_dir.path)
+
+ original_chunks = data["chunks"]
+ subsampled_chunks = [
+ _org_chunk for _org_chunk in original_chunks if _org_chunk["filename"] in dummy_subsampled_chunk_filename
+ ]
else:
raise ValueError("Couldn't load original chunk file.")
diff --git a/tests/conftest.py b/tests/conftest.py
index 4133e793..538d0bcb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,3 +8,25 @@ def teardown_process_group(): # noqa: PT004
yield
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
+
+
+@pytest.fixture()
+def mosaic_mds_index_data():
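+    # trimmed-down `index.json` payload in the MosaicML Streaming (MDS) format: a single
+    # zstd-compressed shard with a fixed-size int column and a variable-length jpeg column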
+ return {
+ "shards": [
+ {
+ "column_encodings": ["int", "jpeg"],
+ "column_names": ["class", "image"],
+ "column_sizes": [8, None],
+ "compression": "zstd",
+ "format": "mds",
+ "hashes": [],
+ "raw_data": {"basename": "shard.00000.mds", "bytes": 125824, "hashes": {}},
+ "samples": 100,
+ "size_limit": 67108864,
+ "version": 2,
+ "zip_data": {"basename": "shard.00000.mds.zstd", "bytes": 63407, "hashes": {}},
+ }
+ ],
+ "version": 2,
+ }
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index 193b3202..91de1a95 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -21,6 +21,7 @@
import numpy as np
import pytest
import torch
+from litdata import train_test_split
from litdata.constants import _ZSTD_AVAILABLE
from litdata.processing import functions
from litdata.streaming import Cache
@@ -995,3 +996,69 @@ def test_subsample_streaming_dataset_with_token_loader(tmpdir, monkeypatch):
)
assert len(dataset2) == int(len(dataset1) * 0.4)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows")
+def test_dataset_with_mosaic_mds_data(tmpdir):
+ from PIL import Image
+ from streaming import MDSWriter
+ # example taken from: https://github.com/mosaicml/streaming
+
+ # A dictionary mapping input fields to their data types
+ columns = {"image": "jpeg", "class": "int"}
+ # Shard compression, if any
+ compression = "zstd"
+ # Save the samples as shards using MDSWriter
+ with MDSWriter(out=str(tmpdir), columns=columns, compression=compression) as out:
+ for i in range(10):
+ sample = {
+ "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)),
+ "class": i,
+ }
+ out.write(sample)
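+
+    # the shards written above are consumed directly by StreamingDataset, which adapts
+    # the MDS `index.json` to litdata's chunk format when loading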
+ dataset = StreamingDataset(input_dir=str(tmpdir))
+ assert len(dataset) == 10
+ for i in range(10):
+ sample = dataset[i]
+ assert sample["class"] == i
+
+ assert [sample["class"] for sample in dataset[:]] == list(range(10)) # test slicing
+
+ # -------------- train_test_split --------------
+
+ train_ds, test_ds, val_ds = train_test_split(dataset, splits=[0.7, 0.2, 0.1])
+
+ assert len(train_ds) == 7
+ assert len(test_ds) == 2
+ assert len(val_ds) == 1
+
+ # -------------- subsample --------------
+
+ dataset = StreamingDataset(input_dir=str(tmpdir), subsample=0.4)
+ assert len(dataset) == 4
+ assert [sample["class"] for sample in dataset[:]] == [0, 1, 2, 3]
+
+ # -------------- works with dataloader --------------
+
+ dataset = StreamingDataset(input_dir=str(tmpdir))
+ dataloader = DataLoader(dataset, batch_size=4, drop_last=True)
+ i = 0
+ for batch in dataloader:
+ assert len(batch["class"]) == 4
+ assert len(batch["image"]) == 4
+ assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
+ i += 1
+
+ dataloader = DataLoader(dataset, batch_size=4, drop_last=False)
+ i = 0
+ for batch in dataloader:
+ if i == 2:
+ # last batch is smaller than batch_size
+ assert len(batch["class"]) == 2
+ assert len(batch["image"]) == 2
+ assert list(batch["class"]) == [4 * i, 4 * i + 1]
+ break
+ assert len(batch["class"]) == 4
+ assert len(batch["image"]) == 4
+ assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
+ i += 1
diff --git a/tests/utilities/test_dataset_utilities.py b/tests/utilities/test_dataset_utilities.py
index 40b3cfce..bb952fe9 100644
--- a/tests/utilities/test_dataset_utilities.py
+++ b/tests/utilities/test_dataset_utilities.py
@@ -1,10 +1,14 @@
+import json
import os
from unittest import mock
+from litdata.constants import _INDEX_FILENAME
from litdata.utilities.dataset_utilities import (
_should_replace_path,
_try_create_cache_dir,
+ adapt_mds_shards_to_chunks,
generate_roi,
+ load_index_file,
)
@@ -44,3 +48,19 @@ def test_generate_roi():
my_roi = generate_roi(my_chunks)
assert my_roi == [(0, 30), (0, 50), (0, 20), (0, 10)]
+
+
+def test_load_index_file(tmpdir, mosaic_mds_index_data):
+ with open(os.path.join(tmpdir, _INDEX_FILENAME), "w") as f:
+ f.write(json.dumps(mosaic_mds_index_data))
+ index_data = load_index_file(tmpdir)
+ assert "chunks" in index_data
+ assert "config" in index_data
+ assert len(mosaic_mds_index_data["shards"]) == len(index_data["chunks"])
+
+
+def test_adapt_mds_shards_to_chunks(mosaic_mds_index_data):
+ adapted_data = adapt_mds_shards_to_chunks(mosaic_mds_index_data)
+ assert "chunks" in adapted_data
+ assert "config" in adapted_data
+ assert len(mosaic_mds_index_data["shards"]) == len(adapted_data["chunks"])