Skip to content

Commit

Permalink
[BUG] Add check for protocol S3A (s3 compatible) (#1761)
Browse files Browse the repository at this point in the history
* We were missing a check in the `daft-io` client that would allow the `s3a` protocol
(S3-compatible) to use the S3 client.
  • Loading branch information
samster25 authored Jan 5, 2024
1 parent 94bb370 commit 76b1086
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/daft-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ pub fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> {
match scheme.as_ref() {
"file" => Ok((SourceType::File, fixed_input)),
"http" | "https" => Ok((SourceType::Http, fixed_input)),
"s3" => Ok((SourceType::S3, fixed_input)),
"s3" | "s3a" => Ok((SourceType::S3, fixed_input)),
"az" | "abfs" => Ok((SourceType::AzureBlob, fixed_input)),
"gcs" | "gs" => Ok((SourceType::GCS, fixed_input)),
#[cfg(target_env = "msvc")]
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/io/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,6 @@ def minio_image_data_fixture(minio_io_config, image_data_folder) -> YieldFixture
@pytest.fixture(scope="session")
def small_images_s3_paths() -> list[str]:
"""Paths to small *.jpg files in a public S3 bucket"""
return [f"s3://daft-public-data/test_fixtures/small_images/rickroll{i}.jpg" for i in range(6)]
return [f"s3://daft-public-data/test_fixtures/small_images/rickroll{i}.jpg" for i in range(6)] + [
f"s3a://daft-public-data/test_fixtures/small_images/rickroll{i}.jpg" for i in range(6)
]
6 changes: 5 additions & 1 deletion tests/integration/io/parquet/test_reads_public_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ def get_filesystem_from_path(path: str, **kwargs) -> fsspec.AbstractFileSystem:
"parquet-benchmarking/mvp",
"s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet",
),
(
"parquet-benchmarking/s3a-mvp",
"s3a://daft-public-data/test_fixtures/parquet-dev/mvp.parquet",
),
(
"azure/mvp",
"az://public-anonymous/mvp.parquet",
Expand Down Expand Up @@ -198,7 +202,7 @@ def parquet_file(request) -> tuple[str, str]:

def read_parquet_with_pyarrow(path) -> pa.Table:
kwargs = {}
if get_protocol_from_path(path) == "s3":
if get_protocol_from_path(path) == "s3" or get_protocol_from_path(path) == "s3a":
kwargs["anon"] = True
if get_protocol_from_path(path) == "az":
kwargs["account_name"] = "dafttestdata"
Expand Down
17 changes: 17 additions & 0 deletions tests/integration/io/test_list_files_s3_minio.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def s3fs_recursive_list(fs, path) -> list:
[
# Exact filepath:
(f"s3://bucket/a.match", [{"type": "File", "path": "s3://bucket/a.match", "size": 0}]),
# Exact filepath but with s3a:
(f"s3a://bucket/a.match", [{"type": "File", "path": "s3a://bucket/a.match", "size": 0}]),
###
# `**`: recursive wildcard
###
Expand All @@ -55,6 +57,21 @@ def s3fs_recursive_list(fs, path) -> list:
{"type": "File", "path": "s3://bucket/nested2/c.match", "size": 0},
],
),
# All files with s3a and **
(
f"s3a://bucket/**",
[
{"type": "File", "path": "s3a://bucket/a.match", "size": 0},
{"type": "File", "path": "s3a://bucket/b.nomatch", "size": 0},
{"type": "File", "path": "s3a://bucket/c.match", "size": 0},
{"type": "File", "path": "s3a://bucket/nested1/a.match", "size": 0},
{"type": "File", "path": "s3a://bucket/nested1/b.nomatch", "size": 0},
{"type": "File", "path": "s3a://bucket/nested1/c.match", "size": 0},
{"type": "File", "path": "s3a://bucket/nested2/a.match", "size": 0},
{"type": "File", "path": "s3a://bucket/nested2/b.nomatch", "size": 0},
{"type": "File", "path": "s3a://bucket/nested2/c.match", "size": 0},
],
),
# Exact filepath after **
(
f"s3://bucket/**/a.match",
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/io/test_url_download_private_aws_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_url_download_aws_s3_public_bucket_with_creds(small_images_s3_paths, io_
df = df.with_column("data", df["urls"].url.download(use_native_downloader=True, io_config=io_config))

data = df.to_pydict()
assert len(data["data"]) == 6
assert len(data["data"]) == 12
for img_bytes in data["data"]:
assert img_bytes is not None

Expand Down
10 changes: 5 additions & 5 deletions tests/integration/io/test_url_download_public_aws_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_url_download_aws_s3_public_bucket_custom_s3fs(small_images_s3_paths):
)

data = df.to_pydict()
assert len(data["data"]) == 6
assert len(data["data"]) == 12
for img_bytes in data["data"]:
assert img_bytes is not None

Expand All @@ -28,7 +28,7 @@ def test_url_download_aws_s3_public_bucket_custom_s3fs_wrong_region(small_images
)

data = df.to_pydict()
assert len(data["data"]) == 6
assert len(data["data"]) == 12
for img_bytes in data["data"]:
assert img_bytes is not None

Expand All @@ -40,7 +40,7 @@ def test_url_download_aws_s3_public_bucket_native_downloader(aws_public_s3_confi
df = df.with_column("data", df["urls"].url.download(io_config=aws_public_s3_config, use_native_downloader=True))

data = df.to_pydict()
assert len(data["data"]) == 6
assert len(data["data"]) == 12
for img_bytes in data["data"]:
assert img_bytes is not None

Expand All @@ -54,15 +54,15 @@ def test_url_download_aws_s3_public_bucket_native_downloader_io_thread_change(
df = df.with_column("data", df["urls"].url.download(io_config=aws_public_s3_config, use_native_downloader=True))

data = df.to_pydict()
assert len(data["data"]) == 6
assert len(data["data"]) == 12
for img_bytes in data["data"]:
assert img_bytes is not None
daft.io.set_io_pool_num_threads(2)
df = daft.from_pydict(data)
df = df.with_column("data", df["urls"].url.download(io_config=aws_public_s3_config, use_native_downloader=True))

data = df.to_pydict()
assert len(data["data"]) == 6
assert len(data["data"]) == 12
for img_bytes in data["data"]:
assert img_bytes is not None

Expand Down

0 comments on commit 76b1086

Please sign in to comment.