diff --git a/README.md b/README.md index bab69ed9..470cace5 100644 --- a/README.md +++ b/README.md @@ -217,9 +217,8 @@ Additionally, you can inject client connection settings for [S3](https://boto3.a from litdata import StreamingDataset storage_options = { - "endpoint_url": "your_endpoint_url", - "aws_access_key_id": "your_access_key_id", - "aws_secret_access_key": "your_secret_access_key", + "key": "your_access_key_id", + "secret": "your_secret_access_key", } dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options) @@ -264,7 +263,7 @@ for batch in val_dataloader:   -The StreamingDataset supports reading optimized datasets from common cloud providers. +The StreamingDataset supports reading optimized datasets from common cloud providers. ```python import os @@ -272,25 +271,39 @@ import litdata as ld # Read data from AWS S3 aws_storage_options={ - "AWS_ACCESS_KEY_ID": os.environ['AWS_ACCESS_KEY_ID'], - "AWS_SECRET_ACCESS_KEY": os.environ['AWS_SECRET_ACCESS_KEY'], + "key": os.environ['AWS_ACCESS_KEY_ID'], + "secret": os.environ['AWS_SECRET_ACCESS_KEY'], } dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options) # Read data from GCS gcp_storage_options={ - "project": os.environ['PROJECT_ID'], + "token": { + # dumped from cat ~/.config/gcloud/application_default_credentials.json + "account": "", + "client_id": "your_client_id", + "client_secret": "your_client_secret", + "quota_project_id": "your_quota_project_id", + "refresh_token": "your_refresh_token", + "type": "authorized_user", + "universe_domain": "googleapis.com", + } } dataset = ld.StreamingDataset("gs://my-bucket/my-data", storage_options=gcp_storage_options) # Read data from Azure azure_storage_options={ - "account_url": f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net", - "credential": os.environ['AZURE_ACCOUNT_ACCESS_KEY'] + "account_name": "azure_account_name", + "account_key": os.environ['AZURE_ACCOUNT_ACCESS_KEY'] } dataset = ld.StreamingDataset("azure://my-bucket/my-data", storage_options=azure_storage_options) ``` +- For more details on which storage options are supported, please refer to: + - [AWS S3 storage options](https://github.com/fsspec/s3fs/blob/main/s3fs/core.py#L176) + - [GCS storage options](https://github.com/fsspec/gcsfs/blob/main/gcsfs/core.py#L154) + - [Azure storage options](https://github.com/fsspec/adlfs/blob/main/adlfs/spec.py#L124) +
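The `storage_options` dictionaries in the README above are handed directly to `fsspec.filesystem()` by the new downloader (see `FsspecDownloader` later in this diff), so their keys follow fsspec/s3fs conventions (`key`/`secret`) rather than boto3's. A minimal sketch for sanity-checking S3 credentials outside of LitData — the bucket name is a placeholder:

```python
import fsspec

storage_options = {
    "key": "your_access_key_id",
    "secret": "your_secret_access_key",
}

# These are the same kwargs LitData now passes through internally; listing the
# bucket is a quick way to confirm the credentials before building a StreamingDataset.
fs = fsspec.filesystem("s3", **storage_options)
print(fs.ls("my-bucket"))  # placeholder bucket name
```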
diff --git a/requirements.txt b/requirements.txt index 06a629a0..ec443722 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,7 @@ torch lightning-utilities filelock numpy -boto3 +# boto3 requests +fsspec +fsspec[s3] # aws s3 diff --git a/requirements/extras.txt b/requirements/extras.txt index 385e2e81..33d42446 100644 --- a/requirements/extras.txt +++ b/requirements/extras.txt @@ -5,3 +5,5 @@ pyarrow tqdm lightning-sdk ==0.1.17 # Must be pinned to ensure compatibility google-cloud-storage +fsspec[gs] # google cloud storage +fsspec[abfs] # azure blob diff --git a/src/litdata/constants.py b/src/litdata/constants.py index a6a714c7..efe2e248 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -85,3 +85,4 @@ _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ" _IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None)) _ENABLE_STATUS = bool(int(os.getenv("ENABLE_STATUS_REPORT", "0"))) +_SUPPORTED_CLOUD_PROVIDERS = ["s3", "gs", "azure", "abfs"] diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index f1af9afa..fae806b5 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -32,8 +32,6 @@ from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from urllib import parse -import boto3 -import botocore import numpy as np import torch @@ -42,14 +40,21 @@ _ENABLE_STATUS, _INDEX_FILENAME, _IS_IN_STUDIO, + _SUPPORTED_CLOUD_PROVIDERS, _TQDM_AVAILABLE, ) from litdata.processing.readers import BaseReader, StreamingDataLoaderReader -from litdata.processing.utilities import _create_dataset, download_directory_from_S3, remove_uuid_from_filename +from litdata.processing.utilities import _create_dataset, remove_uuid_from_filename from litdata.streaming import Cache from litdata.streaming.cache import Dir -from litdata.streaming.client import S3Client from litdata.streaming.dataloader import StreamingDataLoader +from litdata.streaming.downloader import ( + does_file_exist, + download_file_or_directory, + get_cloud_provider, + remove_file_or_directory, + upload_file_or_directory, +) from litdata.streaming.item_loader import BaseItemLoader from litdata.streaming.resolver import _resolve_dir from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads @@ -96,14 +101,22 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str: return os.path.join(cache_dir, name.lstrip("/")) -def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: int = 2) -> Any: - """This function check.""" +def _wait_for_file_to_exist( + remote_filepath: str, sleep_time: int = 2, wait_for_count: int = 5, storage_options: Optional[Dict] = {} +) -> Any: + """This function check if a file exists on the remote storage. + + If not, it waits for a while and tries again. 
+ + """ + cloud_provider = get_cloud_provider(remote_filepath) while True: try: - return s3.client.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/")) - except botocore.exceptions.ClientError as e: - if "the HeadObject operation: Not Found" in str(e): + return does_file_exist(remote_filepath, cloud_provider, storage_options=storage_options) + except Exception as e: + if wait_for_count > 0: sleep(sleep_time) + wait_for_count -= 1 else: raise e @@ -118,10 +131,10 @@ def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: return -def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None: - """Download data from a remote directory to a cache directory to optimise reading.""" - s3 = S3Client() - +def _download_data_target( + input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue, storage_options: Optional[Dict] = {} +) -> None: + """This function is used to download data from a remote directory to a cache directory to optimise reading.""" while True: # 2. Fetch from the queue r: Optional[Tuple[int, List[str]]] = queue_in.get() @@ -156,13 +169,11 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue obj = parse.urlparse(path) - if obj.scheme == "s3": + if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: dirpath = os.path.dirname(local_path) os.makedirs(dirpath, exist_ok=True) - - with open(local_path, "wb") as f: - s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) + download_file_or_directory(path, local_path, storage_options=storage_options) elif os.path.isfile(path): if not path.startswith("/teamspace/studios/this_studio"): @@ -198,12 +209,13 @@ def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None: os.remove(path) -def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None: - """Upload optimised chunks from a local to remote dataset directory.""" +def _upload_fn( + upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir, storage_options: Optional[Dict] = {} +) -> None: + """This function is used to upload optimised chunks from a local to remote dataset directory.""" obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path) - if obj.scheme == "s3": - s3 = S3Client() + is_remote = obj.scheme in _SUPPORTED_CLOUD_PROVIDERS while True: data: Optional[Union[str, Tuple[str, str]]] = upload_queue.get() @@ -223,7 +235,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ if not local_filepath.startswith(cache_dir): local_filepath = os.path.join(cache_dir, local_filepath) - if obj.scheme == "s3": + if is_remote: try: output_filepath = str(obj.path).lstrip("/") @@ -235,12 +247,8 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ output_filepath = os.path.join(output_filepath, local_filepath.replace(tmpdir, "")[1:]) output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints - - s3.client.upload_file( - local_filepath, - obj.netloc, - output_filepath, - ) + remote_filepath = str(obj.scheme) + "://" + str(obj.netloc) + "/" + output_filepath + upload_file_or_directory(local_filepath, remote_filepath, storage_options=storage_options) except Exception as e: print(e) @@ -417,6 +425,7 @@ def __init__( checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = None, checkpoint_next_index: Optional[int] = None, item_loader: Optional[BaseItemLoader] = None, + storage_options: Optional[Dict] = {}, ) -> 
None: """The BaseWorker is responsible to process the user data.""" self.worker_index = worker_index @@ -451,6 +460,7 @@ def __init__( self.use_checkpoint: bool = use_checkpoint self.checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = checkpoint_chunks_info self.checkpoint_next_index: Optional[int] = checkpoint_next_index + self.storage_options = storage_options def run(self) -> None: try: @@ -627,6 +637,7 @@ def _start_downloaders(self) -> None: self.cache_data_dir, to_download_queue, self.ready_to_process_queue, + self.storage_options, ), ) p.start() @@ -666,6 +677,7 @@ def _start_uploaders(self) -> None: self.remove_queue, self.cache_chunks_dir, self.output_dir, + self.storage_options, ), ) p.start() @@ -767,6 +779,7 @@ def __init__( chunk_bytes: Optional[Union[int, str]] = None, compression: Optional[str] = None, encryption: Optional[Encryption] = None, + storage_options: Optional[Dict] = {}, ): super().__init__() if chunk_size is not None and chunk_bytes is not None: @@ -776,6 +789,7 @@ def __init__( self.chunk_bytes = 1 << 26 if chunk_size is None and chunk_bytes is None else chunk_bytes self.compression = compression self.encryption = encryption + self.storage_options = storage_options @abstractmethod def prepare_structure(self, input_dir: Optional[str]) -> List[T]: @@ -842,10 +856,12 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra else: local_filepath = os.path.join(cache_dir, _INDEX_FILENAME) - if obj.scheme == "s3": - s3 = S3Client() - s3.client.upload_file( - local_filepath, obj.netloc, os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)) + if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: + remote_filepath = str(obj.scheme) + "://" + str(obj.netloc) + "/" + upload_file_or_directory( + local_filepath, + remote_filepath + os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)), + storage_options=self.storage_options, ) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath))) @@ -863,11 +879,13 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra assert output_dir_path remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}") node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath)) - if obj.scheme == "s3": - obj = parse.urlparse(remote_filepath) - _wait_for_file_to_exist(s3, obj) - with open(node_index_filepath, "wb") as f: - s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) + if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: + _wait_for_file_to_exist(remote_filepath, storage_options=self.storage_options) + download_file_or_directory( + remote_filepath, + node_index_filepath, + storage_options=self.storage_options, + ) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(remote_filepath, node_index_filepath) @@ -908,6 +926,7 @@ def __init__( use_checkpoint: bool = False, item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, + storage_options: Optional[Dict] = {}, ): """Provides an efficient way to process data across multiple machine into chunks to make training faster. @@ -932,6 +951,7 @@ def __init__( the format in which the data is stored and optimized for loading. start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. + storage_options: The storage options used by the cloud provider. 
""" # spawn doesn't work in IPython @@ -968,6 +988,7 @@ def __init__( self.item_loader = item_loader self.state_dict = state_dict or {rank: 0 for rank in range(self.num_workers)} + self.storage_options = storage_options if self.reader is not None and self.weights is not None: raise ValueError("Either the reader or the weights needs to be defined.") @@ -1119,7 +1140,11 @@ def run(self, data_recipe: DataRecipe) -> None: # Exit early if all the workers are done. # This means there were some kinda of errors. if all(not w.is_alive() for w in self.workers): - raise RuntimeError("One of the worker has failed") + try: + error = self.error_queue.get(timeout=0.001) + self._exit_on_error(error) + except Empty: + break if _TQDM_AVAILABLE: pbar.close() @@ -1186,6 +1211,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L self.checkpoint_chunks_info[worker_idx] if self.checkpoint_chunks_info else None, self.checkpoint_next_index[worker_idx] if self.checkpoint_next_index else None, self.item_loader, + storage_options=self.storage_options, ) worker.start() workers.append(worker) @@ -1237,21 +1263,14 @@ def _cleanup_checkpoints(self) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") - - s3 = boto3.client("s3") - - prefix = obj.path.lstrip("/").rstrip("/") + "/" - - # Delete all the files (including the index file in overwrite mode) - bucket_name = obj.netloc - s3 = boto3.resource("s3") - - checkpoint_prefix = os.path.join(prefix, ".checkpoints") - - for obj in s3.Bucket(bucket_name).objects.filter(Prefix=checkpoint_prefix): - s3.Object(bucket_name, obj.key).delete() + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." + ) + with suppress(FileNotFoundError): + remove_file_or_directory( + os.path.join(self.output_dir.url, ".checkpoints"), storage_options=self.storage_options + ) def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: if not self.use_checkpoint: @@ -1277,24 +1296,20 @@ def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") - - # TODO: Add support for all cloud providers - - s3 = S3Client() - - prefix = obj.path.lstrip("/").rstrip("/") + "/" + ".checkpoints/" + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." + ) # write config.json file to temp directory and upload it to s3 with tempfile.TemporaryDirectory() as temp_dir: temp_file_name = os.path.join(temp_dir, "config.json") with open(temp_file_name, "w") as f: json.dump(config, f) - s3.client.upload_file( + upload_file_or_directory( temp_file_name, - obj.netloc, - os.path.join(prefix, "config.json"), + os.path.join(self.output_dir.url, ".checkpoints", "config.json"), + storage_options=self.storage_options, ) except Exception as e: print(e) @@ -1345,26 +1360,25 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. 
Found {self.output_dir.path}.") - - # TODO: Add support for all cloud providers - - prefix = obj.path.lstrip("/").rstrip("/") + "/" + ".checkpoints/" - - # Delete all the files (including the index file in overwrite mode) - bucket_name = obj.netloc + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." + ) # download all the checkpoint files in tempdir and read them with tempfile.TemporaryDirectory() as temp_dir: - saved_file_dir = download_directory_from_S3(bucket_name, prefix, temp_dir) - - if not os.path.exists(os.path.join(saved_file_dir, "config.json")): + try: + download_file_or_directory( + os.path.join(self.output_dir.url, ".checkpoints/"), temp_dir, storage_options=self.storage_options + ) + except FileNotFoundError: + return + if not os.path.exists(os.path.join(temp_dir, "config.json")): # if the config.json file doesn't exist, we don't have any checkpoint saved return # read the config.json file - with open(os.path.join(saved_file_dir, "config.json")) as f: + with open(os.path.join(temp_dir, "config.json")) as f: config = json.load(f) if config["num_workers"] != self.num_workers: @@ -1378,11 +1392,11 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: checkpoint_file_names = [f"checkpoint-{worker_idx}.json" for worker_idx in range(self.num_workers)] for i, checkpoint_file_name in enumerate(checkpoint_file_names): - if not os.path.exists(os.path.join(saved_file_dir, checkpoint_file_name)): + if not os.path.exists(os.path.join(temp_dir, checkpoint_file_name)): # if the checkpoint file doesn't exist, we don't have any checkpoint saved for this worker continue - with open(os.path.join(saved_file_dir, checkpoint_file_name)) as f: + with open(os.path.join(temp_dir, checkpoint_file_name)) as f: checkpoint = json.load(f) self.checkpoint_chunks_info[i] = checkpoint["chunks"] diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index dd62909a..e83c8c1b 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -27,7 +27,7 @@ import torch -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_CLOUD_PROVIDERS from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe from litdata.processing.readers import BaseReader from litdata.processing.utilities import ( @@ -36,8 +36,8 @@ optimize_dns_context, read_index_file_content, ) -from litdata.streaming.client import S3Client from litdata.streaming.dataloader import StreamingDataLoader +from litdata.streaming.downloader import copy_file_or_directory, upload_file_or_directory from litdata.streaming.item_loader import BaseItemLoader from litdata.streaming.resolver import ( Dir, @@ -53,7 +53,7 @@ def _is_remote_file(path: str) -> bool: obj = parse.urlparse(path) - return obj.scheme in ["s3", "gcs"] + return obj.scheme in _SUPPORTED_CLOUD_PROVIDERS def _get_indexed_paths(data: Any) -> Dict[int, str]: @@ -151,8 +151,15 @@ def __init__( compression: Optional[str], encryption: Optional[Encryption] = None, existing_index: Optional[Dict[str, Any]] = None, + storage_options: Optional[Dict] = {}, ): - super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption) + super().__init__( + chunk_size=chunk_size, + chunk_bytes=chunk_bytes, + compression=compression, + 
encryption=encryption, + storage_options=storage_options, + ) self._fn = fn self._inputs = inputs self.is_generator = False @@ -199,6 +206,7 @@ def map( error_when_not_empty: bool = False, reader: Optional[BaseReader] = None, batch_size: Optional[int] = None, + storage_options: Optional[Dict] = {}, ) -> None: """Maps a callable over a collection of inputs, possibly in a distributed way. @@ -220,6 +228,7 @@ def map( error_when_not_empty: Whether we should error if the output folder isn't empty. reader: The reader to use when reading the data. By default, it uses the `BaseReader`. batch_size: Group the inputs into batches of batch_size length. + storage_options: The storage options used by the cloud provider. """ if isinstance(inputs, StreamingDataLoader) and batch_size is not None: @@ -258,7 +267,7 @@ def map( ) if error_when_not_empty: - _assert_dir_is_empty(_output_dir) + _assert_dir_is_empty(_output_dir, storage_options=storage_options) if not isinstance(inputs, StreamingDataLoader): input_dir = input_dir or _get_input_dir(inputs) @@ -282,6 +291,7 @@ def map( reorder_files=reorder_files, weights=weights, reader=reader, + storage_options=storage_options, ) with optimize_dns_context(True): return data_processor.run(LambdaDataTransformRecipe(fn, inputs)) @@ -315,6 +325,7 @@ def optimize( use_checkpoint: bool = False, item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, + storage_options: Optional[Dict] = {}, ) -> None: """This function converts a dataset into chunks, possibly in a distributed way. @@ -349,6 +360,7 @@ def optimize( the format in which the data is stored and optimized for loading. start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. + storage_options: The storage options used by the cloud provider. """ if mode is not None and mode not in ["append", "overwrite"]: @@ -403,7 +415,9 @@ def optimize( "\n HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." ) - _assert_dir_has_index_file(_output_dir, mode=mode, use_checkpoint=use_checkpoint) + _assert_dir_has_index_file( + _output_dir, mode=mode, use_checkpoint=use_checkpoint, storage_options=storage_options + ) if not isinstance(inputs, StreamingDataLoader): resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs)) @@ -419,7 +433,9 @@ def optimize( num_workers = num_workers or _get_default_num_workers() state_dict = {rank: 0 for rank in range(num_workers)} - existing_index_file_content = read_index_file_content(_output_dir) if mode == "append" else None + existing_index_file_content = ( + read_index_file_content(_output_dir, storage_options=storage_options) if mode == "append" else None + ) if existing_index_file_content is not None: for chunk in existing_index_file_content["chunks"]: @@ -441,6 +457,7 @@ def optimize( use_checkpoint=use_checkpoint, item_loader=item_loader, start_method=start_method, + storage_options=storage_options, ) with optimize_dns_context(True): @@ -453,6 +470,7 @@ def optimize( compression=compression, encryption=encryption, existing_index=existing_index_file_content, + storage_options=storage_options, ) ) return None @@ -521,12 +539,14 @@ class CopyInfo: new_filename: str -def merge_datasets(input_dirs: List[str], output_dir: str) -> None: - """Enables to merge multiple existing optimized datasets into a single optimized dataset. 
+def merge_datasets(input_dirs: List[str], output_dir: str, storage_options: Optional[Dict] = {}) -> None: + """The merge_datasets utility enables to merge multiple existing optimized datasets into a single optimized + dataset. Args: input_dirs: A list of directories pointing to the existing optimized datasets. output_dir: The directory where the merged dataset would be stored. + storage_options: A dictionary of storage options to be passed to the fsspec library. """ if len(input_dirs) == 0: @@ -541,12 +561,14 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs): raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.") - input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs] + input_dirs_file_content = [ + read_index_file_content(input_dir, storage_options=storage_options) for input_dir in resolved_input_dirs + ] if any(file_content is None for file_content in input_dirs_file_content): raise ValueError("One of the provided input_dir doesn't have an index file.") - output_dir_file_content = read_index_file_content(resolved_output_dir) + output_dir_file_content = read_index_file_content(resolved_output_dir, storage_options=storage_options) if output_dir_file_content is not None: raise ValueError("The output_dir already contains an optimized dataset") @@ -581,12 +603,12 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: _tqdm = _get_tqdm_iterator_if_available() for copy_info in _tqdm(copy_infos): - _apply_copy(copy_info, resolved_output_dir) + _apply_copy(copy_info, resolved_output_dir, storage_options=storage_options) - _save_index(index_json, resolved_output_dir) + _save_index(index_json, resolved_output_dir, storage_options=storage_options) -def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: +def _apply_copy(copy_info: CopyInfo, output_dir: Dir, storage_options: Optional[Dict] = {}) -> None: if output_dir.url is None and copy_info.input_dir.url is None: assert copy_info.input_dir.path assert output_dir.path @@ -596,20 +618,15 @@ def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: shutil.copyfile(input_filepath, output_filepath) elif output_dir.url and copy_info.input_dir.url: - input_obj = parse.urlparse(os.path.join(copy_info.input_dir.url, copy_info.old_filename)) - output_obj = parse.urlparse(os.path.join(output_dir.url, copy_info.new_filename)) - - s3 = S3Client() - s3.client.copy( - {"Bucket": input_obj.netloc, "Key": input_obj.path.lstrip("/")}, - output_obj.netloc, - output_obj.path.lstrip("/"), - ) + input_obj = os.path.join(copy_info.input_dir.url, copy_info.old_filename) + output_obj = os.path.join(output_dir.url, copy_info.new_filename) + + copy_file_or_directory(input_obj, output_obj, storage_options=storage_options) else: raise NotImplementedError -def _save_index(index_json: Dict, output_dir: Dir) -> None: +def _save_index(index_json: Dict, output_dir: Dir, storage_options: Optional[Dict] = {}) -> None: if output_dir.url is None: assert output_dir.path with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f: @@ -620,11 +637,6 @@ def _save_index(index_json: Dict, output_dir: Dir) -> None: f.flush() - obj = parse.urlparse(os.path.join(output_dir.url, _INDEX_FILENAME)) - - s3 = S3Client() - s3.client.upload_file( - f.name, - obj.netloc, - obj.path.lstrip("/"), + upload_file_or_directory( + f.name, os.path.join(output_dir.url, _INDEX_FILENAME), 
storage_options=storage_options ) diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index a50b9652..a13e863d 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -21,11 +21,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from urllib import parse -import boto3 -import botocore - -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_CLOUD_PROVIDERS from litdata.streaming.cache import Dir +from litdata.streaming.downloader import download_file_or_directory def _create_dataset( @@ -183,7 +181,7 @@ def _get_work_dir() -> str: return f"s3://{bucket_name}/projects/{project_id}/lightningapps/{app_id}/artifacts/{work_id}/content/" -def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]: +def read_index_file_content(output_dir: Dir, storage_options: Optional[Dict] = {}) -> Optional[Dict[str, Any]]: """Read the index file content.""" if not isinstance(output_dir, Dir): raise ValueError("The provided output_dir should be a Dir object.") @@ -201,27 +199,26 @@ def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]: # download the index file from s3, and read it obj = parse.urlparse(output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {output_dir.path}.") - - # TODO: Add support for all cloud providers - s3 = boto3.client("s3") - - prefix = obj.path.lstrip("/").rstrip("/") + "/" + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.path}." + ) # Check the index file exists try: # Create a temporary file with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file: temp_file_name = temp_file.name - s3.download_file(obj.netloc, os.path.join(prefix, _INDEX_FILENAME), temp_file_name) + download_file_or_directory( + os.path.join(output_dir.url, _INDEX_FILENAME), temp_file_name, storage_options=storage_options + ) # Read data from the temporary file with open(temp_file_name) as temp_file: data = json.load(temp_file) # Delete the temporary file os.remove(temp_file_name) return data - except botocore.exceptions.ClientError: + except Exception as _e: return None @@ -256,21 +253,3 @@ def remove_uuid_from_filename(filepath: str) -> str: # uuid is of 32 characters, '.json' is 5 characters and '-' is 1 character return filepath[:-38] + ".json" - - -def download_directory_from_S3(bucket_name: str, remote_directory_name: str, local_directory_name: str) -> str: - s3_resource = boto3.resource("s3") - bucket = s3_resource.Bucket(bucket_name) - - saved_file_dir = "." - - for obj in bucket.objects.filter(Prefix=remote_directory_name): - local_filename = os.path.join(local_directory_name, obj.key) - - if not os.path.exists(os.path.dirname(local_filename)): - os.makedirs(os.path.dirname(local_filename)) - with open(local_filename, "wb") as f: - s3_resource.meta.client.download_fileobj(bucket_name, obj.key, f) - saved_file_dir = os.path.dirname(local_filename) - - return saved_file_dir diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py deleted file mode 100644 index d24803c3..00000000 --- a/src/litdata/streaming/client.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright The Lightning AI team. 
-# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from time import time -from typing import Any, Dict, Optional - -import boto3 -import botocore -from botocore.credentials import InstanceMetadataProvider -from botocore.utils import InstanceMetadataFetcher - -from litdata.constants import _IS_IN_STUDIO - - -class S3Client: - # TODO: Generalize to support more cloud providers. - - def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None: - self._refetch_interval = refetch_interval - self._last_time: Optional[float] = None - self._client: Optional[Any] = None - self._storage_options: dict = storage_options or {} - - def _create_client(self) -> None: - has_shared_credentials_file = ( - os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" - ) - - if has_shared_credentials_file or not _IS_IN_STUDIO: - self._client = boto3.client( - "s3", - **{ - "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), - **self._storage_options, - }, - ) - else: - provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)) - credentials = provider.load() - self._client = boto3.client( - "s3", - aws_access_key_id=credentials.access_key, - aws_secret_access_key=credentials.secret_key, - aws_session_token=credentials.token, - config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), - ) - - @property - def client(self) -> Any: - if self._client is None: - self._create_client() - self._last_time = time() - - # Re-generate credentials for EC2 - if self._last_time is None or (time() - self._last_time) > self._refetch_interval: - self._create_client() - self._last_time = time() - - return self._client diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index ea82ce7a..a6f70bf5 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -155,7 +155,8 @@ def set_epoch(self, current_epoch: int) -> None: def _create_cache(self, worker_env: _WorkerEnv) -> Cache: if _should_replace_path(self.input_dir.path): cache_path = _try_create_cache_dir( - input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url + input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url, + storage_options=self.storage_options, ) if cache_path is not None: self.input_dir.path = cache_path @@ -438,7 +439,8 @@ def _validate_state_dict(self) -> None: # In this case, validate the cache folder is the same. 
if _should_replace_path(state["input_dir_path"]): cache_path = _try_create_cache_dir( - input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"] + input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"], + storage_options=self.storage_options, ) if cache_path != self.input_dir.path: raise ValueError( diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 41e4a6a9..463ab576 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -16,24 +16,32 @@ import shutil import subprocess from abc import ABC -from contextlib import suppress -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from urllib import parse +import fsspec from filelock import FileLock, Timeout -from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME -from litdata.streaming.client import S3Client +from litdata.constants import _INDEX_FILENAME + +# from litdata.streaming.client import S3Client + +_USE_S5CMD_FOR_S3 = True class Downloader(ABC): def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} + self, + cloud_provider: str, + remote_dir: str, + cache_dir: str, + chunks: List[Dict[str, Any]], + storage_options: Optional[Dict] = {}, ): self._remote_dir = remote_dir self._cache_dir = cache_dir self._chunks = chunks - self._storage_options = storage_options or {} + self.fs = fsspec.filesystem(cloud_provider, **storage_options) def download_chunk_from_index(self, chunk_index: int) -> None: chunk_filename = self._chunks[chunk_index]["filename"] @@ -45,157 +53,195 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: pass -class S3Downloader(Downloader): - def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} - ): - super().__init__(remote_dir, cache_dir, chunks, storage_options) - self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 +class LocalDownloader(Downloader): + def download_file(self, remote_filepath: str, local_filepath: str) -> None: + if not os.path.exists(remote_filepath): + raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}") + + try: + with FileLock(local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0): + if remote_filepath != local_filepath and not os.path.exists(local_filepath): + # make an atomic operation to be safe + temp_file_path = local_filepath + ".tmp" + shutil.copy(remote_filepath, temp_file_path) + os.rename(temp_file_path, local_filepath) + with contextlib.suppress(Exception): + os.remove(local_filepath + ".lock") + except Timeout: + pass - if not self._s5cmd_available: - self._client = S3Client(storage_options=self._storage_options) +class LocalDownloaderWithCache(LocalDownloader): def download_file(self, remote_filepath: str, local_filepath: str) -> None: - obj = parse.urlparse(remote_filepath) + remote_filepath = remote_filepath.replace("local:", "") + super().download_file(remote_filepath, local_filepath) - if obj.scheme != "s3": - raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}") - if os.path.exists(local_filepath): - return +def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> None: + _s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 - with 
suppress(Timeout), FileLock( - local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0 - ): - if self._s5cmd_available: - proc = subprocess.Popen( - f"s5cmd cp {remote_filepath} {local_filepath}", - shell=True, - stdout=subprocess.PIPE, - ) - proc.wait() - else: - from boto3.s3.transfer import TransferConfig - - extra_args: Dict[str, Any] = {} - - # try: - # with FileLock(local_filepath + ".lock", timeout=1): - if not os.path.exists(local_filepath): - # Issue: https://github.com/boto/boto3/issues/3113 - self._client.client.download_file( - obj.netloc, - obj.path.lstrip("/"), - local_filepath, - ExtraArgs=extra_args, - Config=TransferConfig(use_threads=False), - ) - - -class GCPDownloader(Downloader): - def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} - ): - if not _GOOGLE_STORAGE_AVAILABLE: - raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) + if _s5cmd_available is False: + raise ModuleNotFoundError(str(_s5cmd_available)) - super().__init__(remote_dir, cache_dir, chunks, storage_options) + obj = parse.urlparse(remote_filepath) - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - from google.cloud import storage + if obj.scheme != "s3": + raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for {remote_filepath}") - obj = parse.urlparse(remote_filepath) + if os.path.exists(local_filepath): + return - if obj.scheme != "gs": - raise ValueError(f"Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote_filepath}") + try: + with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): + proc = subprocess.Popen( + f"s5cmd cp {remote_filepath} {local_filepath}", + shell=True, + stdout=subprocess.PIPE, + ) + proc.wait() + except Timeout: + # another process is responsible to download that file, continue + pass - if os.path.exists(local_filepath): - return - with suppress(Timeout), FileLock( - local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0 - ): - bucket_name = obj.netloc - key = obj.path - # Remove the leading "/": - if key[0] == "/": - key = key[1:] +_DOWNLOADERS = { + "s3://": "s3", + "gs://": "gs", + "azure://": "abfs", + "abfs://": "abfs", + "local:": "file", + "": "file", +} + +_DEFAULT_STORAGE_OPTIONS = { + "s3": {"config_kwargs": {"retries": {"max_attempts": 1000, "mode": "adaptive"}}}, +} + - client = storage.Client(**self._storage_options) - bucket = client.bucket(bucket_name) - blob = bucket.blob(key) - blob.download_to_filename(local_filepath) +def get_complete_storage_options(cloud_provider: str, storage_options: Optional[Dict] = {}) -> Dict: + if storage_options is None: + storage_options = {} + if cloud_provider in _DEFAULT_STORAGE_OPTIONS: + return {**_DEFAULT_STORAGE_OPTIONS[cloud_provider], **storage_options} + return storage_options -class AzureDownloader(Downloader): +class FsspecDownloader(Downloader): def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} + self, + cloud_provider: str, + remote_dir: str, + cache_dir: str, + chunks: List[Dict[str, Any]], + storage_options: Optional[Dict] = {}, ): - if not _AZURE_STORAGE_AVAILABLE: - raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE)) - - super().__init__(remote_dir, cache_dir, chunks, storage_options) + remote_dir = remote_dir.replace("local:", "") + self.is_local = False + storage_options = 
get_complete_storage_options(cloud_provider, storage_options) + super().__init__(cloud_provider, remote_dir, cache_dir, chunks, storage_options) + self.cloud_provider = cloud_provider + self.use_s5cmd = cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0 def download_file(self, remote_filepath: str, local_filepath: str) -> None: - from azure.storage.blob import BlobServiceClient - - obj = parse.urlparse(remote_filepath) - - if obj.scheme != "azure": - raise ValueError( - f"Expected obj.scheme to be `azure`, instead, got {obj.scheme} for remote={remote_filepath}" - ) - - if os.path.exists(local_filepath): + if os.path.exists(local_filepath) or remote_filepath == local_filepath: + return + if self.use_s5cmd and _USE_S5CMD_FOR_S3: + download_s3_file_via_s5cmd(remote_filepath, local_filepath) return + try: + with FileLock(local_filepath + ".lock", timeout=3): + self.fs.get(remote_filepath, local_filepath, recursive=True) + # remove the lock file + if os.path.exists(local_filepath + ".lock"): + os.remove(local_filepath + ".lock") + except Timeout: + # another process is responsible to download that file, continue + pass + + +def does_file_exist( + remote_filepath: str, cloud_provider: Union[str, None] = None, storage_options: Optional[Dict] = {} +) -> bool: + if cloud_provider is None: + cloud_provider = get_cloud_provider(remote_filepath) + storage_options = get_complete_storage_options(cloud_provider, storage_options) + fs = fsspec.filesystem(cloud_provider, **storage_options) + return fs.exists(remote_filepath) + + +def list_directory( + remote_directory: str, + detail: bool = False, + cloud_provider: Optional[str] = None, + storage_options: Optional[Dict] = {}, +) -> List[str]: + """Returns a list of filenames in a remote directory.""" + if cloud_provider is None: + cloud_provider = get_cloud_provider(remote_directory) + storage_options = get_complete_storage_options(cloud_provider, storage_options) + fs = fsspec.filesystem(cloud_provider, **storage_options) + return fs.ls(remote_directory, detail=detail) # just return the filenames + + +def download_file_or_directory(remote_filepath: str, local_filepath: str, storage_options: Optional[Dict] = {}) -> None: + """Download a file from the remote cloud storage.""" + fs_cloud_provider = get_cloud_provider(remote_filepath) + use_s5cmd = fs_cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0 + if use_s5cmd and _USE_S5CMD_FOR_S3: + download_s3_file_via_s5cmd(remote_filepath, local_filepath) + return + try: + with FileLock(local_filepath + ".lock", timeout=3): + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.get(remote_filepath, local_filepath, recursive=True) + except Timeout: + # another process is responsible to download that file, continue + pass - with suppress(Timeout), FileLock( - local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0 - ): - service = BlobServiceClient(**self._storage_options) - blob_client = service.get_blob_client(container=obj.netloc, blob=obj.path.lstrip("/")) - with open(local_filepath, "wb") as download_file: - blob_data = blob_client.download_blob() - blob_data.readinto(download_file) +def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: + """Upload a file to the remote cloud storage.""" + try: + with FileLock(local_filepath + ".lock", timeout=3): + fs_cloud_provider = 
get_cloud_provider(remote_filepath) + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.put(local_filepath, remote_filepath, recursive=True) + except Timeout: + # another process is responsible to upload that file, continue + pass -class LocalDownloader(Downloader): - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - if not os.path.exists(remote_filepath): - raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}") - with suppress(Timeout), FileLock( - local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0 - ): - if remote_filepath == local_filepath or os.path.exists(local_filepath): - return - # make an atomic operation to be safe - temp_file_path = local_filepath + ".tmp" - shutil.copy(remote_filepath, temp_file_path) - os.rename(temp_file_path, local_filepath) - with contextlib.suppress(Exception): - os.remove(local_filepath + ".lock") +def copy_file_or_directory( + remote_filepath_src: str, remote_filepath_tg: str, storage_options: Optional[Dict] = {} +) -> None: + """Copy a file from src to target on the remote cloud storage.""" + fs_cloud_provider = get_cloud_provider(remote_filepath_src) + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.copy(remote_filepath_src, remote_filepath_tg, recursive=True) -class LocalDownloaderWithCache(LocalDownloader): - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - remote_filepath = remote_filepath.replace("local:", "") - super().download_file(remote_filepath, local_filepath) +def remove_file_or_directory(remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: + """Remove a file from the remote cloud storage.""" + fs_cloud_provider = get_cloud_provider(remote_filepath) + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.rm(remote_filepath, recursive=True) -_DOWNLOADERS = { - "s3://": S3Downloader, - "gs://": GCPDownloader, - "azure://": AzureDownloader, - "local:": LocalDownloaderWithCache, - "": LocalDownloader, -} +def get_cloud_provider(remote_filepath: str) -> str: + for k, fs_cloud_provider in _DOWNLOADERS.items(): + if str(remote_filepath).startswith(k): + return fs_cloud_provider + raise ValueError(f"The provided `remote_filepath` {remote_filepath} doesn't have a downloader associated.") def get_downloader_cls( remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} ) -> Downloader: - for k, cls in _DOWNLOADERS.items(): + for k, fs_cloud_provider in _DOWNLOADERS.items(): if str(remote_dir).startswith(k): - return cls(remote_dir, cache_dir, chunks, storage_options) + return FsspecDownloader(fs_cloud_provider, remote_dir, cache_dir, chunks, storage_options) raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.") diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 98ce5fef..a1781ccc 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -20,13 +20,15 @@ from dataclasses import dataclass from pathlib import Path from time import sleep -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, 
Union from urllib import parse -import boto3 -import botocore - -from litdata.constants import _LIGHTNING_SDK_AVAILABLE +from litdata.constants import _LIGHTNING_SDK_AVAILABLE, _SUPPORTED_CLOUD_PROVIDERS +from litdata.streaming.downloader import ( + does_file_exist, + list_directory, + remove_file_or_directory, +) if TYPE_CHECKING: from lightning_sdk import Machine @@ -52,9 +54,15 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir: assert isinstance(dir_path, str) - cloud_prefixes = ("s3://", "gs://", "azure://") - if dir_path.startswith(cloud_prefixes): - return Dir(path=None, url=dir_path) + cloud_prefixes = _SUPPORTED_CLOUD_PROVIDERS + dir_scheme = parse.urlparse(dir_path).scheme + if bool(dir_scheme) and dir_scheme not in ["c", "d", "e", "f"]: # prevent windows `c:\\` and `d:\\` + if any(dir_path.startswith(cloud_prefix) for cloud_prefix in cloud_prefixes): + return Dir(path=None, url=dir_path) + raise ValueError( + f"The provided dir_path `{dir_path}` is not supported.", + f" HINT: Only the following cloud providers are supported: {_SUPPORTED_CLOUD_PROVIDERS}.", + ) if dir_path.startswith("local:"): return Dir(path=None, url=dir_path) @@ -88,14 +96,11 @@ def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspa if target_id is not None and cloudspace.id == target_id: return True - if ( + return bool( cloudspace.display_name is not None and target_name is not None and cloudspace.display_name.lower() == target_name.lower() - ): - return True - - return False + ) def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Optional[str]) -> Dir: @@ -209,7 +214,9 @@ def _resolve_datasets(dir_path: str) -> Dir: ) -def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool = False) -> None: +def _assert_dir_is_empty( + output_dir: Dir, append: bool = False, overwrite: bool = False, storage_options: Optional[Dict] = {} +) -> None: if not isinstance(output_dir, Dir): raise ValueError("The provided output_dir isn't a `Dir` Object.") @@ -218,20 +225,16 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool obj = parse.urlparse(output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {output_dir.url}.") - - s3 = boto3.client("s3") + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError(f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.url}.") - objects = s3.list_objects_v2( - Bucket=obj.netloc, - Delimiter="/", - Prefix=obj.path.lstrip("/").rstrip("/") + "/", - ) + try: + object_list = list_directory(output_dir.url, storage_options=storage_options) + except FileNotFoundError: + return # We aren't alloweing to add more data - # TODO: Add support for `append` and `overwrite`. - if objects["KeyCount"] > 0: + if object_list is not None and len(object_list) > 0: raise RuntimeError( f"The provided output_dir `{output_dir.path}` already contains data and datasets are meant to be immutable." "\n HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?" 
@@ -239,7 +242,10 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool def _assert_dir_has_index_file( - output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None, use_checkpoint: bool = False + output_dir: Dir, + mode: Optional[Literal["append", "overwrite"]] = None, + use_checkpoint: bool = False, + storage_options: Optional[Dict] = {}, ) -> None: if mode is not None and mode not in ["append", "overwrite"]: raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. Found {mode}.") @@ -283,29 +289,19 @@ def _assert_dir_has_index_file( obj = parse.urlparse(output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {output_dir.url}.") - - s3 = boto3.client("s3") + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError(f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.url}.") - prefix = obj.path.lstrip("/").rstrip("/") + "/" - - objects = s3.list_objects_v2( - Bucket=obj.netloc, - Delimiter="/", - Prefix=prefix, - ) + objects_list = [] + with suppress(FileNotFoundError): + objects_list = list_directory(output_dir.url, storage_options=storage_options) # No files are found in this folder - if objects["KeyCount"] == 0: + if objects_list is None or len(objects_list) == 0: return # Check the index file exists - try: - s3.head_object(Bucket=obj.netloc, Key=os.path.join(prefix, "index.json")) - has_index_file = True - except botocore.exceptions.ClientError: - has_index_file = False + has_index_file = does_file_exist(os.path.join(output_dir.url, "index.json"), storage_options=storage_options) if has_index_file and mode is None: raise RuntimeError( @@ -314,13 +310,8 @@ def _assert_dir_has_index_file( "\n HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`." 
) - # Delete all the files (including the index file in overwrite mode) - bucket_name = obj.netloc - s3 = boto3.resource("s3") - if mode == "overwrite" or (mode is None and not use_checkpoint): - for obj in s3.Bucket(bucket_name).objects.filter(Prefix=prefix): - s3.Object(bucket_name, obj.key).delete() + remove_file_or_directory(output_dir.url, storage_options=storage_options) def _get_lightning_cloud_url() -> str: diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index cc3ef9b2..ef111541 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -109,20 +109,16 @@ def fn(*_, **__): remove_queue = mock.MagicMock() - s3_client = mock.MagicMock() - called = False - def copy_file(local_filepath, *args): + def copy_file(local_filepath, *args, **kwargs): nonlocal called called = True from shutil import copyfile copyfile(local_filepath, os.path.join(remote_output_dir, os.path.basename(local_filepath))) - s3_client.client.upload_file = copy_file - - monkeypatch.setattr(data_processor_module, "S3Client", mock.MagicMock(return_value=s3_client)) + monkeypatch.setattr(data_processor_module, "upload_file_or_directory", copy_file) assert os.listdir(remote_output_dir) == [] @@ -217,32 +213,28 @@ def test_wait_for_disk_usage_higher_than_threshold(): @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") -def test_wait_for_file_to_exist(): - import botocore - - s3 = mock.MagicMock() - obj = mock.MagicMock() +def test_wait_for_file_to_exist(monkeypatch): raise_error = [True, True, False] def fn(*_, **__): value = raise_error.pop(0) if value: - raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") + raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception return - s3.client.head_object = fn + monkeypatch.setattr(data_processor_module, "does_file_exist", fn) - _wait_for_file_to_exist(s3, obj, sleep_time=0.01) + _wait_for_file_to_exist("s3://some-dummy-bucket/some-dummy-key", sleep_time=0.01) assert len(raise_error) == 0 def fn(*_, **__): raise ValueError("HERE") - s3.client.head_object = fn + monkeypatch.setattr(data_processor_module, "does_file_exist", fn) with pytest.raises(ValueError, match="HERE"): - _wait_for_file_to_exist(s3, obj, sleep_time=0.01) + _wait_for_file_to_exist("s3://some-dummy-bucket/some-dummy-key", sleep_time=0.01) def test_cache_dir_cleanup(tmpdir, monkeypatch): @@ -1024,11 +1016,10 @@ def test_data_processing_map_non_absolute_path(monkeypatch, tmpdir): @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") def test_map_error_when_not_empty(monkeypatch): - boto3 = mock.MagicMock() - client_s3_mock = mock.MagicMock() - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} - boto3.client.return_value = client_s3_mock - monkeypatch.setattr(resolver, "boto3", boto3) + def mock_list_directory(*args, **kwargs): + return ["a.txt", "b.txt"] + + monkeypatch.setattr(resolver, "list_directory", mock_list_directory) with pytest.raises(RuntimeError, match="data and datasets are meant to be immutable"): map( diff --git a/tests/streaming/test_client.py b/tests/streaming/test_client.py deleted file mode 100644 index 78ea919d..00000000 --- a/tests/streaming/test_client.py +++ /dev/null @@ -1,97 +0,0 @@ -import sys -from time import sleep, time -from unittest import mock - -import pytest -from litdata.streaming import client - - -def 
test_s3_client_with_storage_options(monkeypatch): - boto3 = mock.MagicMock() - monkeypatch.setattr(client, "boto3", boto3) - - botocore = mock.MagicMock() - monkeypatch.setattr(client, "botocore", botocore) - - storage_options = { - "region_name": "us-west-2", - "endpoint_url": "https://custom.endpoint", - "config": botocore.config.Config(retries={"max_attempts": 100}), - } - s3_client = client.S3Client(storage_options=storage_options) - - assert s3_client.client - - boto3.client.assert_called_with( - "s3", - region_name="us-west-2", - endpoint_url="https://custom.endpoint", - config=botocore.config.Config(retries={"max_attempts": 100}), - ) - - s3_client = client.S3Client() - - assert s3_client.client - - boto3.client.assert_called_with( - "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) - ) - - -def test_s3_client_without_cloud_space_id(monkeypatch): - boto3 = mock.MagicMock() - monkeypatch.setattr(client, "boto3", boto3) - - botocore = mock.MagicMock() - monkeypatch.setattr(client, "botocore", botocore) - - instance_metadata_provider = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) - - instance_metadata_fetcher = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) - - s3 = client.S3Client(1) - assert s3.client - assert s3.client - assert s3.client - assert s3.client - assert s3.client - - boto3.client.assert_called_once() - - -@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows") -@pytest.mark.parametrize("use_shared_credentials", [False, True, None]) -def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch): - boto3 = mock.MagicMock() - monkeypatch.setattr(client, "boto3", boto3) - - botocore = mock.MagicMock() - monkeypatch.setattr(client, "botocore", botocore) - - if isinstance(use_shared_credentials, bool): - monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") - monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials") - monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials") - - instance_metadata_provider = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) - - instance_metadata_fetcher = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) - - s3 = client.S3Client(1) - assert s3.client - assert s3.client - boto3.client.assert_called_once() - sleep(1 - (time() - s3._last_time)) - assert s3.client - assert s3.client - assert len(boto3.client._mock_mock_calls) == 6 - sleep(1 - (time() - s3._last_time)) - assert s3.client - assert s3.client - assert len(boto3.client._mock_mock_calls) == 9 - - assert instance_metadata_provider._mock_call_count == 0 if use_shared_credentials else 3 diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index ef93021c..b87f9ffd 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -889,7 +889,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) -@pytest.mark.timeout(60) +@pytest.mark.timeout(120) @pytest.mark.parametrize("shuffle", [True, False]) def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): """Tests resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" 
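The downloader tests that follow drop the per-provider mocks (`S3Downloader`, `GCPDownloader`, `AzureDownloader`) in favour of the single fsspec-backed path. One way to exercise the new `FsspecDownloader` without cloud credentials is fsspec's in-memory filesystem; this is only an illustrative sketch, and the `memory://` wiring is an assumption rather than part of this diff:

```python
import os
import tempfile

import fsspec

from litdata.streaming.downloader import FsspecDownloader

# Stage a fake chunk in fsspec's process-wide in-memory filesystem (assumption: used here
# purely for illustration; real usage would target "s3", "gs" or "abfs").
fs = fsspec.filesystem("memory")
with fs.open("/bucket/chunk-0-0.bin", "wb") as f:
    f.write(b"dummy chunk bytes")

cache_dir = tempfile.mkdtemp()
chunks = [{"filename": "chunk-0-0.bin"}]

# "memory" plays the role of the cloud provider string normally returned by get_cloud_provider().
downloader = FsspecDownloader("memory", "memory://bucket", cache_dir, chunks)
downloader.download_chunk_from_index(0)

assert os.path.exists(os.path.join(cache_dir, "chunk-0-0.bin"))
```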
diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py index 7c79afe5..97368c0c 100644 --- a/tests/streaming/test_downloader.py +++ b/tests/streaming/test_downloader.py @@ -1,84 +1,19 @@ import os -from unittest import mock from unittest.mock import MagicMock from litdata.streaming.downloader import ( - AzureDownloader, - GCPDownloader, LocalDownloaderWithCache, - S3Downloader, shutil, - subprocess, ) -def test_s3_downloader_fast(tmpdir, monkeypatch): - monkeypatch.setattr(os, "system", MagicMock(return_value=0)) - popen_mock = MagicMock() - monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock)) - downloader = S3Downloader(tmpdir, tmpdir, []) - downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt")) - popen_mock.wait.assert_called() - - -@mock.patch("litdata.streaming.downloader._GOOGLE_STORAGE_AVAILABLE", True) -def test_gcp_downloader(tmpdir, monkeypatch, google_mock): - # Create mock objects - mock_client = MagicMock() - mock_bucket = MagicMock() - mock_blob = MagicMock() - mock_blob.download_to_filename = MagicMock() - - # Patch the storage client to return the mock client - google_mock.cloud.storage.Client = MagicMock(return_value=mock_client) - - # Configure the mock client to return the mock bucket and blob - mock_client.bucket = MagicMock(return_value=mock_bucket) - mock_bucket.blob = MagicMock(return_value=mock_blob) - - # Initialize the downloader - storage_options = {"project": "DUMMY_PROJECT"} - downloader = GCPDownloader("gs://random_bucket", tmpdir, [], storage_options) - local_filepath = os.path.join(tmpdir, "a.txt") - downloader.download_file("gs://random_bucket/a.txt", local_filepath) - - # Assert that the correct methods were called - google_mock.cloud.storage.Client.assert_called_with(**storage_options) - mock_client.bucket.assert_called_with("random_bucket") - mock_bucket.blob.assert_called_with("a.txt") - mock_blob.download_to_filename.assert_called_with(local_filepath) - - -@mock.patch("litdata.streaming.downloader._AZURE_STORAGE_AVAILABLE", True) -def test_azure_downloader(tmpdir, monkeypatch, azure_mock): - mock_blob = MagicMock() - mock_blob_data = MagicMock() - mock_blob.download_blob.return_value = mock_blob_data - service_mock = MagicMock() - service_mock.get_blob_client.return_value = mock_blob - - azure_mock.storage.blob.BlobServiceClient = MagicMock(return_value=service_mock) - - # Initialize the downloader - storage_options = {"project": "DUMMY_PROJECT"} - downloader = AzureDownloader("azure://random_bucket", tmpdir, [], storage_options) - local_filepath = os.path.join(tmpdir, "a.txt") - downloader.download_file("azure://random_bucket/a.txt", local_filepath) - - # Assert that the correct methods were called - azure_mock.storage.blob.BlobServiceClient.assert_called_with(**storage_options) - service_mock.get_blob_client.assert_called_with(container="random_bucket", blob="a.txt") - mock_blob.download_blob.assert_called() - mock_blob_data.readinto.assert_called() - - def test_download_with_cache(tmpdir, monkeypatch): # Create a file to download/cache with open("a.txt", "w") as f: f.write("hello") try: - local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, []) + local_downloader = LocalDownloaderWithCache("file", tmpdir, tmpdir, []) shutil_mock = MagicMock() os_mock = MagicMock() monkeypatch.setattr(shutil, "copy", shutil_mock) diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 90729ffb..699a39c4 100644 --- a/tests/streaming/test_resolver.py 
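The provider-specific downloader tests removed above (S3 via a subprocess call, GCP and Azure via SDK mocks) are superseded by a single fsspec-backed code path. A rough sketch of how that path can be exercised without cloud credentials, using a local stand-in whose signature only approximates litdata's `download_file_or_directory` helper:

```python
from unittest.mock import MagicMock

import fsspec


def test_fsspec_download_call(monkeypatch, tmp_path):
    """Sketch: check the filesystem is resolved from the URL scheme and fs.get() is used."""
    fs_mock = MagicMock()
    monkeypatch.setattr(fsspec, "filesystem", MagicMock(return_value=fs_mock))

    def download_file_or_directory(remote_path, local_path, storage_options=None):
        # Stand-in for the litdata helper; the real signature may differ.
        protocol = remote_path.split("://", 1)[0]
        fs = fsspec.filesystem(protocol, **(storage_options or {}))
        fs.get(remote_path, local_path, recursive=True)

    local_filepath = str(tmp_path / "a.txt")
    download_file_or_directory(
        "s3://random_bucket/a.txt", local_filepath, storage_options={"key": "k", "secret": "s"}
    )

    fsspec.filesystem.assert_called_once_with("s3", key="k", secret="s")
    fs_mock.get.assert_called_once_with("s3://random_bucket/a.txt", local_filepath, recursive=True)
```

The real helpers exposed by `litdata.streaming.downloader` (`download_file_or_directory`, `does_file_exist`, `upload_file_or_directory`, `remove_file_or_directory`) can be patched the same way, which is what the resolver tests below do.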
+++ b/tests/streaming/test_resolver.py @@ -302,52 +302,54 @@ def print_fn(msg, file=None): def test_assert_dir_is_empty(monkeypatch): - boto3 = mock.MagicMock() - client_s3_mock = mock.MagicMock() - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + def mock_list_directory(*args, **kwargs): + return ["a.txt", "b.txt"] + + def mock_empty_list_directory(*args, **kwargs): + return [] + + monkeypatch.setattr(resolver, "list_directory", mock_list_directory) with pytest.raises(RuntimeError, match="The provided output_dir"): resolver._assert_dir_is_empty(resolver.Dir(path="/teamspace/...", url="s3://")) - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 0, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + monkeypatch.setattr(resolver, "list_directory", mock_empty_list_directory) resolver._assert_dir_is_empty(resolver.Dir(path="/teamspace/...", url="s3://")) def test_assert_dir_has_index_file(monkeypatch): - boto3 = mock.MagicMock() - client_s3_mock = mock.MagicMock() - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + def mock_list_directory_0(*args, **kwargs): + return [] - with pytest.raises(RuntimeError, match="The provided output_dir"): - resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + def mock_list_directory_1(*args, **kwargs): + return ["a.txt", "b.txt"] - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 0, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + def mock_list_directory_2(*args, **kwargs): + return ["index.json"] - resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + def mock_does_file_exist_1(*args, **kwargs): + raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} + def mock_does_file_exist_2(*args, **kwargs): + return True - def head_object(*args, **kwargs): - import botocore + def mock_remove_file_or_directory(*args, **kwargs): + return - raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") - - client_s3_mock.head_object = head_object - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + monkeypatch.setattr(resolver, "list_directory", mock_list_directory_0) + monkeypatch.setattr(resolver, "does_file_exist", mock_does_file_exist_1) + monkeypatch.setattr(resolver, "remove_file_or_directory", mock_remove_file_or_directory) resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) - boto3.resource.assert_called() + monkeypatch.setattr(resolver, "list_directory", mock_list_directory_2) + monkeypatch.setattr(resolver, "does_file_exist", mock_does_file_exist_2) + + with pytest.raises(RuntimeError, match="The provided output_dir"): + resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + + resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"), mode="overwrite") def test_resolve_dir_absolute(tmp_path, monkeypatch): @@ -367,3 +369,10 @@ def test_resolve_dir_absolute(tmp_path, monkeypatch): link.symlink_to(src) assert link.resolve() == src assert resolver._resolve_dir(str(link)).path == str(src) + + +def 
test_resolve_dir_unsupported_cloud_provider(monkeypatch, tmp_path):
+    """Test that an unsupported cloud provider is handled correctly."""
+    test_dir = "some-random-cloud-provider://some-random-bucket"
+    with pytest.raises(ValueError, match="The provided dir_path"):
+        resolver._resolve_dir(test_dir)
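The new test relies on `_resolve_dir` rejecting URL schemes outside the providers litdata supports. A rough sketch of the validation it exercises, assuming the check is driven by `_SUPPORTED_CLOUD_PROVIDERS` from `litdata.constants`; the actual logic and message in `resolver._resolve_dir` may differ beyond the matched `The provided dir_path` prefix:

```python
from urllib import parse

from litdata.constants import _SUPPORTED_CLOUD_PROVIDERS


def _check_cloud_provider(dir_path: str) -> None:
    # Illustrative only: a scheme such as "some-random-cloud-provider" is rejected,
    # while plain local paths (empty scheme) and supported providers pass through.
    scheme = parse.urlparse(dir_path).scheme
    if scheme and scheme not in _SUPPORTED_CLOUD_PROVIDERS:
        raise ValueError(
            f"The provided dir_path `{dir_path}` doesn't use a supported cloud provider. "
            f"Supported providers are: {_SUPPORTED_CLOUD_PROVIDERS}."
        )
```

Under that assumption, `gs://my-bucket/my-data` resolves normally, while the `some-random-cloud-provider://` URL from the test raises the `ValueError` the `pytest.raises` block expects.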