Skip to content

Commit

Permalink
azure storage options (#365)
Browse files: browse the repository at this point in the history
Co-authored-by: MohanReddy <[email protected]>
  • Loading branch information
mohanreddypmr and MohanReddy authored Sep 6, 2024
1 parent 92df8af commit 7efd761
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/litdata/utilities/dataset_utilities.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,7 +38,9 @@ def subsample_streaming_dataset(

# Make sure input_dir contains cache path and remote url
if _should_replace_path(input_dir.path):
cache_path = _try_create_cache_dir(input_dir=input_dir.path if input_dir.path else input_dir.url)
cache_path = _try_create_cache_dir(
input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options
)
if cache_path is not None:
input_dir.path = cache_path

Expand Down Expand Up @@ -96,7 +98,7 @@ def _should_replace_path(path: Optional[str]) -> bool:
return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/")


def _read_updated_at(input_dir: Optional[Dir]) -> str:
def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict] = {}) -> str:
"""Read last updated timestamp from index.json file."""
last_updation_timestamp = "0"
index_json_content = None
Expand All @@ -110,7 +112,7 @@ def _read_updated_at(input_dir: Optional[Dir]) -> str:
# download index.json file and read last_updation_timestamp
with tempfile.TemporaryDirectory() as tmp_directory:
temp_index_filepath = os.path.join(tmp_directory, _INDEX_FILENAME)
downloader = get_downloader_cls(input_dir.url, tmp_directory, [])
downloader = get_downloader_cls(input_dir.url, tmp_directory, [], storage_options)
downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), temp_index_filepath)

index_json_content = load_index_file(tmp_directory)
Expand All @@ -135,9 +137,9 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s
shutil.rmtree(input_dir_hash_filepath)


def _try_create_cache_dir(input_dir: Optional[str]) -> Optional[str]:
def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict] = {}) -> Optional[str]:
resolved_input_dir = _resolve_dir(input_dir)
updated_at = _read_updated_at(resolved_input_dir)
updated_at = _read_updated_at(resolved_input_dir, storage_options)

if updated_at == "0" and input_dir is not None:
updated_at = hashlib.md5(input_dir.encode()).hexdigest() # noqa: S324
Expand Down

0 comments on commit 7efd761

Please sign in to comment.