Feat: adds support for reading mosaic mds written dataset #210

Merged · 28 commits · Jul 8, 2024
Changes from 14 commits

Commits (28)
298e1a5
chore: update test.txt with mosaicml-streaming dependency
bhimrazy Jul 6, 2024
3f15c36
feat: add load_index_file function with supports for mds config
bhimrazy Jul 6, 2024
eb05ee7
chore: replaces indexfile loading with reusable fn
bhimrazy Jul 6, 2024
66da1e5
feat: updates config to load indexfile
bhimrazy Jul 6, 2024
615c894
feat: adds fn to deserialize mds written bytes data
bhimrazy Jul 6, 2024
0e690f6
feat: adds tests to test the functionality to read mds writer dataset
bhimrazy Jul 6, 2024
1584532
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 6, 2024
13428ff
fix: import path for `load_index_file` fn
bhimrazy Jul 6, 2024
3fe0b56
fixes:type
bhimrazy Jul 6, 2024
e3773b0
fix: `load_index_file` input dir
bhimrazy Jul 6, 2024
c8c88b9
fix: return type
bhimrazy Jul 6, 2024
8248410
feat: adds default missing return case
bhimrazy Jul 6, 2024
2dedd0b
Merge branch 'main' into feat/adds-mosaic-mds-support
bhimrazy Jul 6, 2024
281329f
Merge branch 'main' into feat/adds-mosaic-mds-support
bhimrazy Jul 6, 2024
d0aa2ab
Merge branch 'main' into feat/adds-mosaic-mds-support
bhimrazy Jul 7, 2024
5a38800
Update README.md: fix typo in parallelize
bhimrazy Jul 7, 2024
bfe93a8
chore: updates test for mds dataset
bhimrazy Jul 7, 2024
f9dfe25
refactor: Improve index file loading and adapt MDS shards to chunks f…
bhimrazy Jul 7, 2024
1f76ffd
chore: Add unit test for adapting MDS shards to chunks format
bhimrazy Jul 7, 2024
f26dc32
refactor: Skip test_dataset_with_mosaic_mds_data on Windows
bhimrazy Jul 7, 2024
930eb1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2024
54c32d5
refactor: Improve index file loading and adapt MDS shards to chunks f…
bhimrazy Jul 7, 2024
23e6fdf
test streamingDataset features for mosaic mds
deependujha Jul 8, 2024
ce6db09
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2024
b25eeff
fix pre-commit-ci errors
deependujha Jul 8, 2024
30a6ce8
fix pre-commit-ci list comprehension with yield
deependujha Jul 8, 2024
575615c
fix failing tests bcoz of generators
deependujha Jul 8, 2024
6b09b80
fix: docs for the fn
bhimrazy Jul 8, 2024
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -1,4 +1,5 @@
 coverage ==7.5.3
+mosaicml-streaming==0.7.6
 pytest ==8.2.*
 pytest-cov ==5.0.0
 pytest-timeout ==2.3.1
4 changes: 2 additions & 2 deletions src/litdata/processing/data_processor.py
@@ -51,6 +51,7 @@
 from litdata.streaming.resolver import _resolve_dir
 from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads
 from litdata.utilities.broadcast import broadcast_object
+from litdata.utilities.dataset_utilities import load_index_file
 from litdata.utilities.packing import _pack_greedily
 
 if _TQDM_AVAILABLE:
@@ -788,8 +789,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul
         self._upload_index(output_dir, cache_dir, num_nodes, node_rank)
 
         if num_nodes == node_rank + 1:
-            with open(os.path.join(cache_dir, _INDEX_FILENAME)) as f:
-                config = json.load(f)
+            config = load_index_file(cache_dir)
 
             size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
             num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
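
For reference, the aggregation above treats MDS-adapted chunks (whose `dim` is `None`, see `load_index_file` below) the same as native litdata chunks. A minimal sketch with hypothetical chunk entries:

    # Hypothetical chunk entries illustrating the size/num_bytes computation above.
    chunks = [
        {"dim": 1024, "chunk_size": 4, "chunk_bytes": 8192},  # native litdata chunk
        {"dim": None, "chunk_size": 4, "chunk_bytes": 8192},  # MDS-adapted chunk: dim is None
    ]
    size = sum(c["dim"] if c["dim"] is not None else c["chunk_size"] for c in chunks)  # 1024 + 4 = 1028
    num_bytes = sum(c["chunk_bytes"] for c in chunks)  # 16384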
22 changes: 11 additions & 11 deletions src/litdata/streaming/config.py
@@ -11,7 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import os
 from typing import Any, Dict, List, Optional, Tuple
 
@@ -22,6 +21,7 @@
 from litdata.streaming.sampler import ChunkedIndex
 from litdata.streaming.serializers import Serializer
 from litdata.utilities._pytree import tree_unflatten, treespec_loads
+from litdata.utilities.dataset_utilities import load_index_file
 
 
 class ChunksConfig:
@@ -53,18 +53,18 @@ def __init__(
         self._remote_dir = remote_dir
         self._item_loader = item_loader or PyTreeLoader()
 
-        with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
-            data = json.load(f)
-            _original_chunks = data["chunks"]
-            self._config = data["config"]
-            self._validate_item_loader()
+        # load data from `index.json` file
+        data = load_index_file(self._cache_dir)
+        _original_chunks = data["chunks"]
+        self._config = data["config"]
+        self._validate_item_loader()
 
-            assert _original_chunks is not None
+        assert _original_chunks is not None
 
-            if subsampled_files is None:
-                self._chunks = _original_chunks
-            else:
-                self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)
+        if subsampled_files is None:
+            self._chunks = _original_chunks
+        else:
+            self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)
 
         self._config["data_spec"] = treespec_loads(self._config["data_spec"])
24 changes: 24 additions & 0 deletions src/litdata/streaming/item_loader.py
@@ -142,8 +142,32 @@ def load_item_from_chunk(
         fp.seek(begin)
         data = fp.read(end - begin)
 
+        # check for mosaic mds format
+        if "format" in self._config and self._config["format"] == "mds":
+            return self.mds_deserialize(data, chunk_index)
         return self.deserialize(data)
 
+    def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree":
+        """Deserialize the mds raw bytes into their python equivalent."""
+        idx = 0
+        sizes = []
+        column_sizes = self._chunks[chunk_index]["column_sizes"]
+        # adapted from: MDSReader.deserialize: https://github.com/mosaicml/streaming/blob/main/streaming/base/format/mds/reader.py
+        for size in column_sizes:
+            if size:
+                sizes.append(size)
+            else:
+                (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32)
+                sizes.append(size)
+                idx += 4
+        data = []
+        for size, data_format in zip(sizes, self._data_format):
+            serializer = self._serializers[data_format]
+            data_bytes = raw_item_data[idx : idx + size]
+            data.append(serializer.deserialize(data_bytes))
+            idx += size
+        return tree_unflatten(data, self._config["data_spec"])
+
     def deserialize(self, raw_item_data: bytes) -> "PyTree":
         """Deserialize the raw bytes into their python equivalent."""
         idx = self._shift_idx
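
For context, `mds_deserialize` relies on the MDS sample layout: every variable-size column contributes a 4-byte `uint32` length at the front of the sample, and all column payloads follow in column order. A minimal sketch with hypothetical payloads, assuming `column_sizes = [8, None]` as in the test fixture below:

    import numpy as np

    class_bytes = np.int64(7).tobytes()  # fixed-size column ("class", 8 bytes)
    image_bytes = b"hypothetical-jpeg"   # variable-size column ("image")

    # Encoded sample: [uint32 image size][class payload][image payload]
    sample = np.uint32(len(image_bytes)).tobytes() + class_bytes + image_bytes

    # Decoding mirrors the first loop in mds_deserialize:
    idx, sizes = 0, []
    for size in [8, None]:      # column_sizes
        if size:
            sizes.append(size)  # fixed size, no prefix stored
        else:
            (size,) = np.frombuffer(sample[idx : idx + 4], np.uint32)
            sizes.append(int(size))
            idx += 4            # consume the uint32 prefix
    assert sizes == [8, len(image_bytes)]  # payloads start at offset idx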
59 changes: 56 additions & 3 deletions src/litdata/utilities/dataset_utilities.py
@@ -51,9 +51,8 @@ def subsample_streaming_dataset(
 
     if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
         # load chunks from `index.json` file
-        with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f:
-            data = json.load(f)
-            original_chunks = data["chunks"]
+        data = load_index_file(input_dir.path)
+        original_chunks = data["chunks"]
     else:
         raise ValueError(
             f"The provided dataset `{input_dir.path}` doesn't contain any {_INDEX_FILENAME} file."
@@ -115,3 +114,57 @@ def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoa
     roi.append((0, end))
 
     return roi
+
+
+def load_index_file(input_dir: str) -> Dict[str, Any]:
+    """Load index file from the input_dir."""
+
+    index_filepath = os.path.join(input_dir, _INDEX_FILENAME)
+    try:
+        # load index.json file
+        with open(index_filepath) as f:
+            data = json.load(f)
+        if "chunks" not in data:
+            raise KeyError(f"'chunks' not found in the index file at {index_filepath}.")
+        return data
+    except KeyError as e:
+        # Verify the presence of MDS shards.
+        # For more details, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming
+        if "shards" in data:
+            # adapt mosaic index to litdata index
+            chunks = []
+            shards = data["shards"]
+            for shard in shards:
+                chunks.append(
+                    {
+                        "chunk_bytes": shard["zip_data"]["bytes"],
+                        "chunk_size": shard["samples"],
+                        "column_sizes": shard["column_sizes"],
+                        "dim": None,
+                        "filename": shard["zip_data"]["basename"],
+                    }
+                )
+            data["chunks"] = chunks
+            # TODO: create a robust data_spec
+            data_spec = [
+                1,
+                {
+                    "type": "builtins.dict",
+                    "context": json.dumps(shards[0]["column_names"]),
+                    "children_spec": [
+                        {"type": None, "context": None, "children_spec": []} for _ in shards[0]["column_names"]
+                    ],
+                },
+            ]
+            data["config"] = {
+                "chunk_bytes": sum([shard["zip_data"]["bytes"] for shard in shards]),
+                "chunk_size": sum([shard["samples"] for shard in shards]),
+                "compression": shards[0]["compression"],
+                "data_format": shards[0]["column_encodings"],
+                "format": shards[0]["format"],
+                "data_spec": json.dumps(data_spec),
+            }
+            return data
+        raise e
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Index file not found at {index_filepath}.")
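
To make the mapping concrete, this is the chunk entry the adapter derives from the single shard in the `mosaic_mds_index_data` fixture below (values copied from that fixture):

    # Shard -> chunk adaptation, worked through for the test fixture's shard:
    chunk = {
        "chunk_bytes": 63407,                # shard["zip_data"]["bytes"] (compressed shard size)
        "chunk_size": 100,                   # shard["samples"]
        "column_sizes": [8, None],           # fixed byte widths; None = variable-size column
        "dim": None,
        "filename": "shard.00000.mds.zstd",  # shard["zip_data"]["basename"]
    }

Note the adapter reads `zip_data` unconditionally, so as written it expects compressed shards; an index whose shards carry only `raw_data` would need the analogous fields.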
16 changes: 7 additions & 9 deletions src/litdata/utilities/train_test_split.py
@@ -1,4 +1,3 @@
-import json
 import logging
 import os
 from copy import deepcopy
@@ -8,6 +7,7 @@
 
 from litdata import StreamingDataset
 from litdata.constants import _INDEX_FILENAME
+from litdata.utilities.dataset_utilities import load_index_file
 from litdata.utilities.subsample import shuffle_lists_together, subsample_filenames_and_roi
 
@@ -55,14 +55,12 @@ def train_test_split(
 
     if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
         # load chunks from `index.json` file
-        with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f:
-            data = json.load(f)
-            original_chunks = data["chunks"]
-            subsampled_chunks = [
-                _org_chunk
-                for _org_chunk in original_chunks
-                if _org_chunk["filename"] in dummy_subsampled_chunk_filename
-            ]
+        data = load_index_file(input_dir.path)
+
+        original_chunks = data["chunks"]
+        subsampled_chunks = [
+            _org_chunk for _org_chunk in original_chunks if _org_chunk["filename"] in dummy_subsampled_chunk_filename
+        ]
     else:
         raise ValueError("Couldn't load original chunk file.")
22 changes: 22 additions & 0 deletions tests/conftest.py
@@ -8,3 +8,25 @@ def teardown_process_group():  # noqa: PT004
     yield
     if torch.distributed.is_available() and torch.distributed.is_initialized():
         torch.distributed.destroy_process_group()
+
+
+@pytest.fixture()
+def mosaic_mds_index_data():
+    return {
+        "shards": [
+            {
+                "column_encodings": ["int", "jpeg"],
+                "column_names": ["class", "image"],
+                "column_sizes": [8, None],
+                "compression": "zstd",
+                "format": "mds",
+                "hashes": [],
+                "raw_data": {"basename": "shard.00000.mds", "bytes": 125824, "hashes": {}},
+                "samples": 100,
+                "size_limit": 67108864,
+                "version": 2,
+                "zip_data": {"basename": "shard.00000.mds.zstd", "bytes": 63407, "hashes": {}},
+            }
+        ],
+        "version": 2,
+    }
24 changes: 24 additions & 0 deletions tests/streaming/test_dataset.py
@@ -995,3 +995,27 @@ def test_subsample_streaming_dataset_with_token_loader(tmpdir, monkeypatch):
     )
 
     assert len(dataset2) == int(len(dataset1) * 0.4)
+
+
+def test_dataset_with_mosaic_mds_data(tmpdir):
+    from PIL import Image
+    from streaming import MDSWriter
+    # example taken from: https://github.com/mosaicml/streaming
+
+    # A dictionary mapping input fields to their data types
+    columns = {"image": "jpeg", "class": "int"}
+    # Shard compression, if any
+    compression = "zstd"
+    # Save the samples as shards using MDSWriter
+    with MDSWriter(out=str(tmpdir), columns=columns, compression=compression) as out:
+        for i in range(100):
+            sample = {
+                "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)),
+                "class": i,
+            }
+            out.write(sample)
+    dataset = StreamingDataset(input_dir=str(tmpdir))
+    assert len(dataset) == 100
+    for i in range(100):
+        sample = dataset[i]
+        assert sample["class"] == i
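
The test doubles as the user-facing recipe; a minimal sketch (the directory path is hypothetical):

    from litdata import StreamingDataset

    # A directory written by MosaicML's MDSWriter (index.json plus shard files)
    # can now be read directly:
    dataset = StreamingDataset(input_dir="./mds_out")
    print(len(dataset), dataset[0]["class"])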
12 changes: 12 additions & 0 deletions tests/utilities/test_dataset_utilities.py
@@ -1,10 +1,13 @@
+import json
 import os
 from unittest import mock
 
+from litdata.constants import _INDEX_FILENAME
 from litdata.utilities.dataset_utilities import (
     _should_replace_path,
     _try_create_cache_dir,
     generate_roi,
+    load_index_file,
 )
@@ -44,3 +47,12 @@ def test_generate_roi():
     my_roi = generate_roi(my_chunks)
 
     assert my_roi == [(0, 30), (0, 50), (0, 20), (0, 10)]
+
+
+def test_load_index_file(tmpdir, mosaic_mds_index_data):
+    with open(os.path.join(tmpdir, _INDEX_FILENAME), "w") as f:
+        f.write(json.dumps(mosaic_mds_index_data))
+    index_data = load_index_file(tmpdir)
+    assert "chunks" in index_data
+    assert "config" in index_data
+    assert len(mosaic_mds_index_data["shards"]) == len(index_data["chunks"])