Skip to content

Commit

Permalink
Feat: add support for custom cache dir in Streaming Dataset (#399)
Browse files Browse the repository at this point in the history
* adds default lightning cache dir

* adds support for cache dir

* adds test_try_create_cache_dir_with_custom_cache_dir

* fixed types

* simplified

* reverted change

* adds cache_dir_path in test with statedict
  • Loading branch information
bhimrazy authored Oct 28, 2024
1 parent 62907b3 commit 32194cd
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/litdata/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_DEFAULT_CHUNK_BYTES = 1 << 26  # 64 MB
_DEFAULT_FAST_DEV_RUN_ITEMS = 10
_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")
_DEFAULT_LIGHTNING_CACHE_DIR = os.path.join("/cache", "chunks")

# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
Expand Down
14 changes: 11 additions & 3 deletions src/litdata/streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class StreamingDataset(IterableDataset):
def __init__(
self,
input_dir: Union[str, "Dir"],
cache_dir: Optional[Union[str, "Dir"]] = None,
item_loader: Optional[BaseItemLoader] = None,
shuffle: bool = False,
drop_last: Optional[bool] = None,
Expand All @@ -61,6 +62,8 @@ def __init__(
Args:
input_dir: Path to the folder where the input data is stored.
cache_dir: Path to the folder where the cache data is stored. If not provided, the cache will be stored
in the default cache directory.
item_loader: The logic to load an item from a chunk.
shuffle: Whether to shuffle the data.
drop_last: If `True`, drops the last items to ensure that
Expand All @@ -84,12 +87,14 @@ def __init__(
raise ValueError("subsample must be a float with value between 0 and 1.")

input_dir = _resolve_dir(input_dir)
cache_dir = _resolve_dir(cache_dir)

self.input_dir = input_dir
self.cache_dir = cache_dir
self.subsampled_files: List[str] = []
self.region_of_interest: List[Tuple[int, int]] = []
self.subsampled_files, self.region_of_interest = subsample_streaming_dataset(
self.input_dir, item_loader, subsample, shuffle, seed, storage_options
self.input_dir, self.cache_dir, item_loader, subsample, shuffle, seed, storage_options
)

self.item_loader = item_loader
Expand Down Expand Up @@ -155,7 +160,8 @@ def set_epoch(self, current_epoch: int) -> None:
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
if _should_replace_path(self.input_dir.path):
cache_path = _try_create_cache_dir(
input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url
input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url,
cache_dir=self.cache_dir.path,
)
if cache_path is not None:
self.input_dir.path = cache_path
Expand Down Expand Up @@ -399,6 +405,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int
"current_epoch": self.current_epoch,
"input_dir_path": self.input_dir.path,
"input_dir_url": self.input_dir.url,
"cache_dir_path": self.cache_dir.path,
"item_loader": self.item_loader.state_dict() if self.item_loader else None,
"drop_last": self.drop_last,
"seed": self.seed,
Expand Down Expand Up @@ -438,7 +445,8 @@ def _validate_state_dict(self) -> None:
# In this case, validate the cache folder is the same.
if _should_replace_path(state["input_dir_path"]):
cache_path = _try_create_cache_dir(
input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"]
input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"],
cache_dir=state.get("cache_dir_path"),
)
if cache_path != self.input_dir.path:
raise ValueError(
Expand Down
17 changes: 12 additions & 5 deletions src/litdata/utilities/dataset_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from litdata.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME
from litdata.constants import _DEFAULT_CACHE_DIR, _DEFAULT_LIGHTNING_CACHE_DIR, _INDEX_FILENAME
from litdata.streaming.downloader import get_downloader_cls
from litdata.streaming.item_loader import BaseItemLoader, TokensLoader
from litdata.streaming.resolver import Dir, _resolve_dir
Expand All @@ -17,6 +17,7 @@

def subsample_streaming_dataset(
input_dir: Dir,
cache_dir: Optional[Dir] = None,
item_loader: Optional[BaseItemLoader] = None,
subsample: float = 1.0,
shuffle: bool = False,
Expand All @@ -39,7 +40,9 @@ def subsample_streaming_dataset(
# Make sure input_dir contains cache path and remote url
if _should_replace_path(input_dir.path):
cache_path = _try_create_cache_dir(
input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options
input_dir=input_dir.path if input_dir.path else input_dir.url,
cache_dir=cache_dir.path if cache_dir else None,
storage_options=storage_options,
)
if cache_path is not None:
input_dir.path = cache_path
Expand Down Expand Up @@ -137,7 +140,11 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s
shutil.rmtree(input_dir_hash_filepath)


def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict] = {}) -> Optional[str]:
def _try_create_cache_dir(
input_dir: Optional[str],
cache_dir: Optional[str] = None,
storage_options: Optional[Dict] = {},
) -> Optional[str]:
resolved_input_dir = _resolve_dir(input_dir)
updated_at = _read_updated_at(resolved_input_dir, storage_options)

Expand All @@ -147,13 +154,13 @@ def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Di
dir_url_hash = hashlib.md5((resolved_input_dir.url or "").encode()).hexdigest() # noqa: S324

if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
input_dir_hash_filepath = os.path.join(_DEFAULT_CACHE_DIR, dir_url_hash)
input_dir_hash_filepath = os.path.join(cache_dir or _DEFAULT_CACHE_DIR, dir_url_hash)
_clear_cache_dir_if_updated(input_dir_hash_filepath, updated_at)
cache_dir = os.path.join(input_dir_hash_filepath, updated_at)
os.makedirs(cache_dir, exist_ok=True)
return cache_dir

input_dir_hash_filepath = os.path.join("/cache", "chunks", dir_url_hash)
input_dir_hash_filepath = os.path.join(cache_dir or _DEFAULT_LIGHTNING_CACHE_DIR, dir_url_hash)
_clear_cache_dir_if_updated(input_dir_hash_filepath, updated_at)
cache_dir = os.path.join(input_dir_hash_filepath, updated_at)
os.makedirs(cache_dir, exist_ok=True)
Expand Down
28 changes: 28 additions & 0 deletions tests/streaming/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -410,6 +411,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -432,6 +434,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -447,6 +450,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -469,6 +473,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -484,6 +489,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -506,6 +512,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -521,6 +528,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -543,6 +551,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -558,6 +567,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -580,6 +590,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -595,6 +606,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -617,6 +629,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -632,6 +645,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -657,6 +671,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -672,6 +687,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -694,6 +710,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -709,6 +726,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -731,6 +749,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -746,6 +765,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -768,6 +788,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -783,6 +804,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -805,6 +827,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -820,6 +843,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -842,6 +866,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -857,6 +882,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -879,6 +905,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -894,6 +921,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand Down
20 changes: 20 additions & 0 deletions tests/utilities/test_dataset_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,26 @@ def test_try_create_cache_dir():
assert len(makedirs_mock.mock_calls) == 2


def test_try_create_cache_dir_with_custom_cache_dir(tmpdir):
    """Verify that ``_try_create_cache_dir`` honors a user-supplied ``cache_dir``.

    Covers both branches of the function:
    - outside the Lightning cloud (no cluster env vars): the returned path must
      live under the custom ``cache_dir`` instead of ``_DEFAULT_CACHE_DIR``;
    - inside the Lightning cloud (cluster env vars set): the custom ``cache_dir``
      must likewise replace ``_DEFAULT_LIGHTNING_CACHE_DIR``.
    """
    cache_dir = str(tmpdir.join("cache"))

    # No Lightning env vars -> "local" branch. md5("") is the dir-url hash for a
    # locally resolved input with no remote URL; the second component is the
    # ``updated_at`` marker read from the index.
    with mock.patch.dict(os.environ, {}, clear=True):
        assert os.path.join(
            cache_dir, "d41d8cd98f00b204e9800998ecf8427e", "100b8cad7cf2a56f6df78f171f97a1ec"
        ) in _try_create_cache_dir("any", cache_dir)

    with (
        mock.patch.dict("os.environ", {"LIGHTNING_CLUSTER_ID": "abc", "LIGHTNING_CLOUD_PROJECT_ID": "123"}),
        # Patch the real ``os.makedirs`` so no directories are created on disk.
        # (Patching it via an importing module, e.g. ``litdata.streaming.dataset.os.makedirs``,
        # would be misleading: the function under test lives in
        # ``litdata.utilities.dataset_utilities`` and only worked before because
        # ``os`` is a shared singleton module.)
        mock.patch("os.makedirs") as makedirs_mock,
    ):
        cache_dir_1 = _try_create_cache_dir("", cache_dir)
        cache_dir_2 = _try_create_cache_dir("ssdf", cache_dir)
        # Different inputs must map to different cache directories.
        assert cache_dir_1 != cache_dir_2
        assert cache_dir_1 == os.path.join(
            cache_dir, "d41d8cd98f00b204e9800998ecf8427e", "d41d8cd98f00b204e9800998ecf8427e"
        )
        # One ``os.makedirs`` call per ``_try_create_cache_dir`` invocation.
        assert len(makedirs_mock.mock_calls) == 2


def test_generate_roi():
my_chunks = [
{"chunk_size": 30},
Expand Down

0 comments on commit 32194cd

Please sign in to comment.