-
Notifications
You must be signed in to change notification settings - Fork 631
/
Copy pathconftest.py
80 lines (61 loc) · 2.77 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import shutil
from typing import Generator
import pytest
from _pytest.fixtures import SubRequest
import huggingface_hub
from huggingface_hub import constants
from huggingface_hub.utils import SoftTemporaryDirectory, logging
from .testing_utils import set_write_permission_and_retry
@pytest.fixture(autouse=True, scope="function")
def patch_constants(mocker):
    """Point huggingface_hub's home, cache and token locations at a temporary directory.

    Autouse and function-scoped, so it wraps every test. Each path constant on
    `huggingface_hub.constants` is patched through the `mocker` fixture
    (pytest-mock), which is expected to revert the patches at teardown. The
    temporary directory itself only lives as long as the `with` block below,
    i.e. until the test has finished and the generator resumes.
    """
    with SoftTemporaryDirectory() as tmp_home:
        # Both cache constants share the same "hub" sub-directory.
        hub_dir = os.path.join(tmp_home, "hub")
        overrides = {
            "HF_HOME": tmp_home,
            "HF_HUB_CACHE": hub_dir,
            "HUGGINGFACE_HUB_CACHE": hub_dir,
            "HF_ASSETS_CACHE": os.path.join(tmp_home, "assets"),
            "HF_TOKEN_PATH": os.path.join(tmp_home, "token"),
            "HF_STORED_TOKENS_PATH": os.path.join(tmp_home, "stored_tokens"),
        }
        # Dict insertion order keeps the patches applied in the original sequence.
        for attr_name, new_value in overrides.items():
            mocker.patch.object(constants, attr_name, new_value)
        yield
# Module-level logger built with huggingface_hub's logging helper (imported from
# `huggingface_hub.utils`).
# NOTE(review): `logger` is not referenced anywhere in the visible part of this
# file; if it is needed, it would conventionally sit right after the imports.
logger = logging.get_logger(__name__)
@pytest.fixture
def fx_cache_dir(request: SubRequest) -> Generator[None, None, None]:
    """Attach a fresh temporary directory to the requesting test class as `cache_dir`.

    A new directory is created for every test that uses the fixture (default
    function scope). It is stored on `request.cls`, so this fixture is meant for
    class-based tests; for a plain test function `request.cls` is None and the
    assignment fails with `AttributeError`. After the test, the directory is
    removed explicitly with a permission-fixing error handler.

    NOTE(review): the example below annotates `cache_dir` as a `Path`; that holds
    only if `SoftTemporaryDirectory` yields a `Path` — confirm in its definition.

    Example:
    ```py
    @pytest.mark.usefixtures("fx_cache_dir")
    class TestWithCache(unittest.TestCase):
        cache_dir: Path

        def test_cache_dir(self) -> None:
            self.assertTrue(self.cache_dir.is_dir())
    ```
    """
    with SoftTemporaryDirectory() as tmp_dir:
        setattr(request.cls, "cache_dir", tmp_dir)
        yield
        # Delete eagerly with `set_write_permission_and_retry` as the error hook:
        # TemporaryDirectory is not super robust on Windows when a git repository
        # is cloned in it.
        # See https://www.scivision.dev/python-tempfile-permission-error-windows/.
        shutil.rmtree(tmp_dir, onerror=set_write_permission_and_retry)
@pytest.fixture(autouse=True)
def disable_symlinks_on_windows_ci(monkeypatch: pytest.MonkeyPatch) -> None:
    """Force the "no symlink support" code path on Windows CI when requested.

    Only active when running on Windows (`os.name == "nt"`) and the
    `DISABLE_SYMLINKS_IN_WINDOWS_TESTS` environment variable is set to a
    non-empty value. In that case `huggingface_hub.file_download.
    _are_symlinks_supported_in_dir` is replaced for the duration of the test by
    a stub that reports every directory as already checked and unsupported.
    """
    wants_no_symlinks = os.name == "nt" and os.environ.get("DISABLE_SYMLINKS_IN_WINDOWS_TESTS")
    if not wants_no_symlinks:
        return

    class AlwaysUnsupportedCache(dict):
        # Stand-in for what is presumably a per-directory cache dict.
        # NOTE(review): only `in` and `[...]` are overridden; a `.get()` lookup
        # would still hit the (empty) real dict.

        def __contains__(self, _dir: object) -> bool:
            # Every `cache_dir` counts as already probed...
            return True

        def __getitem__(self, _dir: str) -> bool:
            # ...and the cached answer is always "symlinks not supported".
            return False

    monkeypatch.setattr(
        huggingface_hub.file_download,
        "_are_symlinks_supported_in_dir",
        AlwaysUnsupportedCache(),
    )
@pytest.fixture(autouse=True)
def disable_experimental_warnings(monkeypatch: pytest.MonkeyPatch) -> None:
    """Set `HF_HUB_DISABLE_EXPERIMENTAL_WARNING` to True for every test.

    Presumably this keeps huggingface_hub's experimental-feature warnings out of
    test output. `monkeypatch` restores the previous value after each test.
    """
    # `constants` is the same module object as `huggingface_hub.constants`.
    monkeypatch.setattr(constants, "HF_HUB_DISABLE_EXPERIMENTAL_WARNING", True)
def _ensure_yaml_suffix(path: str) -> str:
    """Return `path` with a `.yaml` extension, appending it only if missing."""
    # Idempotent on purpose: an unconditional `path + ".yaml"` turns an explicitly
    # named cassette such as "foo.yaml" into "foo.yaml.yaml".
    return path if path.endswith(".yaml") else path + ".yaml"


@pytest.fixture(scope="module")
def vcr_config():
    """Return the VCR (vcrpy) options shared by the recorded-HTTP tests of a module.

    Presumably consumed by a VCR pytest plugin through the `vcr_config` fixture name.

    - `filter_headers`: request headers stripped from recorded cassettes, so
      credentials and client-identifying values are not stored.
    - `ignore_localhost`: requests to localhost are not recorded.
    - `path_transformer`: cassette paths get a `.yaml` extension, without
      doubling it when the name already has one.
    """
    return {
        "filter_headers": ["authorization", "user-agent", "cookie"],
        "ignore_localhost": True,
        "path_transformer": _ensure_yaml_suffix,
    }