Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
dwyatte committed Mar 2, 2023
1 parent b429ca2 commit 50f2f64
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
12 changes: 6 additions & 6 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,24 +330,24 @@ def _request_with_retry(

def fsspec_head(url, timeout=10.0):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
fs, _, paths = fsspec.get_fs_token_paths(url)
fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
if len(paths) > 1:
raise ValueError("HEAD can be called with at most one path but was called with {paths}")
return fs.info(paths[0], timeout=timeout)
raise ValueError(f"HEAD can be called with at most one path but was called with {paths}")
return fs.info(paths[0])


def fsspec_get(url, temp_file, timeout=10.0, desc=None):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
fs, _, paths = fsspec.get_fs_token_paths(url)
fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
if len(paths) > 1:
raise ValueError("GET can be called with at most one path but was called with {paths}")
raise ValueError(f"GET can be called with at most one path but was called with {paths}")
callback = fsspec.callbacks.TqdmCallback(
tqdm_kwargs={
"desc": desc or "Downloading",
"disable": logging.is_progress_bar_enabled(),
}
)
fs.get(paths[0], temp_file, timeout=timeout, callback=callback)
fs.get_file(paths[0], temp_file.name, callback=callback)


def ftp_head(url, timeout=10.0):
Expand Down
4 changes: 4 additions & 0 deletions tests/fixtures/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def info(self, path, *args, **kwargs):
out["name"] = out["name"][len(self.local_root_dir) :]
return out

def get_file(self, rpath, lpath, *args, **kwargs):
rpath = posixpath.join(self.local_root_dir, self._strip_protocol(rpath))
return self._fs.get_file(rpath, lpath, *args, **kwargs)

def cp_file(self, path1, path2, *args, **kwargs):
path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1))
path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2))
Expand Down
21 changes: 20 additions & 1 deletion tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
fsspec_head,
ftp_get,
ftp_head,
get_from_cache,
http_get,
http_head,
)
Expand All @@ -22,16 +23,25 @@
Text data.
Second line of data."""

FILE_PATH = "file"


@pytest.fixture(scope="session")
def zstd_path(tmp_path_factory):
path = tmp_path_factory.mktemp("data") / "file.zstd"
path = tmp_path_factory.mktemp("data") / FILE_PATH
data = bytes(FILE_CONTENT, "utf-8")
with zstd.open(path, "wb") as f:
f.write(data)
return path


@pytest.fixture
def mockfs_file(mockfs):
with open(os.path.join(mockfs.local_root_dir, FILE_PATH), "w") as f:
f.write(FILE_CONTENT)
return mockfs


@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
Expand Down Expand Up @@ -89,6 +99,15 @@ def test_cached_path_missing_local(tmp_path):
cached_path(missing_file)


def test_get_from_cache_fsspec(mockfs_file):
with patch("datasets.utils.file_utils.fsspec.get_fs_token_paths") as mock_get_fs_token_paths:
mock_get_fs_token_paths.return_value = (mockfs_file, "", [FILE_PATH])
output_path = get_from_cache("mock://huggingface.co")
with open(output_path) as f:
output_file_content = f.read()
assert output_file_content == FILE_CONTENT


@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_cached_path_offline():
with pytest.raises(OfflineModeIsEnabled):
Expand Down

0 comments on commit 50f2f64

Please sign in to comment.