Skip to content

Commit

Permalink
feat: pass storage options to s5cmd (#397)
Browse files Browse the repository at this point in the history
* pass storage options to s5cmd

* Update src/litdata/streaming/downloader.py

Co-authored-by: Deependu Jha <[email protected]>

* add mock tests to test_s3_downloader_with_s5cmd

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: consistent bucket name across the tests

* adds ignore

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reverted type:ignore

---------

Co-authored-by: Deependu Jha <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 14, 2024
1 parent 3f47b5e commit 7264fce
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/litdata/streaming/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,22 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0
):
if self._s5cmd_available:
env = None
if self._storage_options:
env = os.environ.copy()
env.update(self._storage_options)
proc = subprocess.Popen(
f"s5cmd cp {remote_filepath} {local_filepath}",
shell=True,
stdout=subprocess.PIPE,
env=env,
)
proc.wait()
else:
from boto3.s3.transfer import TransferConfig

extra_args: Dict[str, Any] = {}

# try:
# with FileLock(local_filepath + ".lock", timeout=1):
if not os.path.exists(local_filepath):
# Issue: https://github.com/boto/boto3/issues/3113
self._client.client.download_file(
Expand Down
58 changes: 57 additions & 1 deletion tests/streaming/test_downloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from unittest import mock
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

from litdata.streaming.downloader import (
AzureDownloader,
Expand All @@ -21,6 +21,62 @@ def test_s3_downloader_fast(tmpdir, monkeypatch):
popen_mock.wait.assert_called()


@patch("os.system")
@patch("subprocess.Popen")
def test_s3_downloader_with_s5cmd_no_storage_options(popen_mock, system_mock, tmpdir):
    """When s5cmd is available and no storage options are given, the downloader
    must spawn ``s5cmd cp`` with ``env=None`` (i.e. inherit the parent environment
    unchanged)."""
    # NOTE(review): `subprocess` is not present in the visible top-of-file imports
    # (only os / mock / MagicMock, patch) — without it, `subprocess.PIPE` below
    # raises NameError. The local import keeps this test self-contained.
    import subprocess

    system_mock.return_value = 0  # Simulates s5cmd being available (exit code 0)
    process_mock = MagicMock()
    popen_mock.return_value = process_mock

    # Initialize the S3Downloader without storage options
    downloader = S3Downloader("s3://random_bucket", str(tmpdir), [])

    # Action: Call the download_file method
    remote_filepath = "s3://random_bucket/sample_file.txt"
    local_filepath = os.path.join(tmpdir, "sample_file.txt")
    downloader.download_file(remote_filepath, local_filepath)

    # Assertion: Verify subprocess.Popen was called with correct arguments and no env variables
    popen_mock.assert_called_once_with(
        f"s5cmd cp {remote_filepath} {local_filepath}",
        shell=True,
        stdout=subprocess.PIPE,
        env=None,
    )
    process_mock.wait.assert_called_once()


@patch("os.system")
@patch("subprocess.Popen")
def test_s3_downloader_with_s5cmd_with_storage_options(popen_mock, system_mock, tmpdir):
    """When s5cmd is available and storage options are given, the downloader must
    spawn ``s5cmd cp`` with an environment equal to ``os.environ`` overlaid with
    the storage options (credentials passed via env vars)."""
    # NOTE(review): `subprocess` is not present in the visible top-of-file imports
    # (only os / mock / MagicMock, patch) — without it, `subprocess.PIPE` below
    # raises NameError. The local import keeps this test self-contained.
    import subprocess

    system_mock.return_value = 0  # Simulates s5cmd being available (exit code 0)
    process_mock = MagicMock()
    popen_mock.return_value = process_mock

    storage_options = {"AWS_ACCESS_KEY_ID": "dummy_key", "AWS_SECRET_ACCESS_KEY": "dummy_secret"}

    # Initialize the S3Downloader with storage options
    downloader = S3Downloader("s3://random_bucket", str(tmpdir), [], storage_options)

    # Action: Call the download_file method
    remote_filepath = "s3://random_bucket/sample_file.txt"
    local_filepath = os.path.join(tmpdir, "sample_file.txt")
    downloader.download_file(remote_filepath, local_filepath)

    # Create expected environment variables by merging the current env with storage_options
    expected_env = os.environ.copy()
    expected_env.update(storage_options)

    # Assertion: Verify subprocess.Popen was called with the correct arguments and environment variables
    popen_mock.assert_called_once_with(
        f"s5cmd cp {remote_filepath} {local_filepath}",
        shell=True,
        stdout=subprocess.PIPE,
        env=expected_env,
    )
    process_mock.wait.assert_called_once()


@mock.patch("litdata.streaming.downloader._GOOGLE_STORAGE_AVAILABLE", True)
def test_gcp_downloader(tmpdir, monkeypatch, google_mock):
# Create mock objects
Expand Down

0 comments on commit 7264fce

Please sign in to comment.