-
Notifications
You must be signed in to change notification settings - Fork 631
/
Copy pathtest_utils_tqdm.py
235 lines (195 loc) · 9.17 KB
/
test_utils_tqdm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import time
import unittest
from pathlib import Path
from unittest.mock import patch
import pytest
from pytest import CaptureFixture
from huggingface_hub.utils import (
SoftTemporaryDirectory,
are_progress_bars_disabled,
disable_progress_bars,
enable_progress_bars,
tqdm,
tqdm_stream_file,
)
class CapsysBaseTest(unittest.TestCase):
    """Base test case exposing pytest's `capsys` fixture as `self.capsys`."""

    @pytest.fixture(autouse=True)
    def capsys(self, capsys: CaptureFixture) -> None:
        """Attach the pytest `capsys` fixture to the test instance.

        `unittest.TestCase` methods cannot request pytest fixtures as
        arguments, so this autouse fixture stores the capture object on the
        instance, letting tests read captured stdout/stderr through
        `self.capsys.readouterr()`.

        Background: https://waylonwalker.com/pytest-capsys/
        Adapted from
        https://github.com/pytest-dev/pytest/issues/2504#issuecomment-309475790.
        """
        # Rebinds the instance attribute `capsys` to the capture fixture itself.
        self.capsys = capsys
class TestTqdmUtils(CapsysBaseTest):
    """Tests for the global progress-bar helpers and the `tqdm` wrappers."""

    def setUp(self) -> None:
        """Record the current progress-bar state so `tearDown` can restore it."""
        super().setUp()
        self._previous_are_progress_bars_disabled = are_progress_bars_disabled()

    def tearDown(self) -> None:
        """Restore the progress-bar state that was recorded in `setUp`."""
        if self._previous_are_progress_bars_disabled:
            disable_progress_bars()
        else:
            enable_progress_bars()
        # Chain to the parent so base-class cleanup always runs.
        super().tearDown()

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_tqdm_helpers(self) -> None:
        """Helpers toggle the global state when the env variable is unset."""
        disable_progress_bars()
        assert are_progress_bars_disabled()

        enable_progress_bars()
        assert not are_progress_bars_disabled()

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", True)
    def test_cannot_enable_tqdm_when_env_variable_is_set(self) -> None:
        """Enabling warns and has no effect when `HF_HUB_DISABLE_PROGRESS_BARS` is True."""
        disable_progress_bars()
        assert are_progress_bars_disabled()

        with self.assertWarns(UserWarning):
            enable_progress_bars()
        assert are_progress_bars_disabled()  # Still disabled

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", False)
    def test_cannot_disable_tqdm_when_env_variable_is_set(self) -> None:
        """Disabling warns and has no effect when `HF_HUB_DISABLE_PROGRESS_BARS` is False."""
        enable_progress_bars()
        assert not are_progress_bars_disabled()

        with self.assertWarns(UserWarning):
            disable_progress_bars()
        assert not are_progress_bars_disabled()  # Still enabled

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_tqdm_disabled(self) -> None:
        """TQDM outputs nothing when progress bars are globally disabled."""
        disable_progress_bars()
        for _ in tqdm(range(10)):
            pass

        captured = self.capsys.readouterr()
        self.assertEqual(captured.out, "")
        self.assertEqual(captured.err, "")

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_tqdm_disabled_cannot_be_forced(self) -> None:
        """`disable=False` does not force a bar when globally disabled."""
        disable_progress_bars()
        for _ in tqdm(range(10), disable=False):
            pass

        captured = self.capsys.readouterr()
        self.assertEqual(captured.out, "")
        self.assertEqual(captured.err, "")

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_tqdm_can_be_disabled_when_globally_enabled(self) -> None:
        """`disable=True` silences a single bar even when globally enabled."""
        enable_progress_bars()
        for _ in tqdm(range(10), disable=True):
            pass

        captured = self.capsys.readouterr()
        self.assertEqual(captured.out, "")
        self.assertEqual(captured.err, "")

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_tqdm_enabled(self) -> None:
        """TQDM writes its progress log to stderr when globally enabled."""
        enable_progress_bars()
        for _ in tqdm(range(10)):
            pass

        captured = self.capsys.readouterr()
        self.assertEqual(captured.out, "")
        self.assertIn("10/10", captured.err)  # tqdm log

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_tqdm_stream_file(self) -> None:
        """`tqdm_stream_file` reports file name, bar and size while a file is read."""
        # The assertions below require a visible bar, so enable bars explicitly
        # instead of relying on the environment or on state left by other tests.
        enable_progress_bars()

        with SoftTemporaryDirectory() as tmpdir:
            filepath = Path(tmpdir) / "config.json"
            filepath.write_text("#" * 1000)

            with tqdm_stream_file(filepath) as f:
                while True:
                    data = f.read(100)
                    if not data:
                        break
                    time.sleep(0.001)  # Simulate a delay between each chunk

            captured = self.capsys.readouterr()
            self.assertEqual(captured.out, "")
            self.assertIn("config.json: 100%", captured.err)  # log file name
            self.assertIn("|█████████", captured.err)  # tqdm bar
            self.assertIn("1.00k/1.00k", captured.err)  # size in B
class TestTqdmGroup(CapsysBaseTest):
    """Tests for enabling and disabling progress bars per named group."""

    def setUp(self) -> None:
        """Set up the initial condition for each test."""
        super().setUp()
        enable_progress_bars()  # Ensure all are enabled before each test

    def tearDown(self) -> None:
        """Clean up after each test by re-enabling progress bars."""
        super().tearDown()
        enable_progress_bars()

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_disable_specific_group(self) -> None:
        """Disabling a group affects that group and its subgroups only."""
        disable_progress_bars("peft.foo")

        assert not are_progress_bars_disabled("peft")
        assert not are_progress_bars_disabled("peft.something")
        assert are_progress_bars_disabled("peft.foo")
        assert are_progress_bars_disabled("peft.foo.bar")

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_enable_specific_subgroup(self) -> None:
        """Enabling a subgroup leaves its disabled parent disabled."""
        disable_progress_bars("peft.foo")
        enable_progress_bars("peft.foo.bar")

        assert are_progress_bars_disabled("peft.foo")
        assert not are_progress_bars_disabled("peft.foo.bar")

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", True)
    def test_disable_override_by_environment_variable(self) -> None:
        """Bars stay disabled, with a warning, when the env variable is True."""
        with self.assertWarns(UserWarning):
            enable_progress_bars()

        assert are_progress_bars_disabled("peft")
        assert are_progress_bars_disabled("peft.foo")

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", False)
    def test_enable_override_by_environment_variable(self) -> None:
        """Bars stay enabled, with a warning, when the env variable is False."""
        with self.assertWarns(UserWarning):
            disable_progress_bars("peft.foo")

        assert not are_progress_bars_disabled("peft.foo")

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_partial_group_name_not_affected(self) -> None:
        """A group that merely shares a name prefix is not affected."""
        disable_progress_bars("peft.foo")

        assert not are_progress_bars_disabled("peft.footprint")

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_nested_subgroup_behavior(self) -> None:
        """Alternating disable/enable/disable down a three-level hierarchy."""
        disable_progress_bars("peft")
        enable_progress_bars("peft.foo")
        disable_progress_bars("peft.foo.bar")

        assert are_progress_bars_disabled("peft")
        assert not are_progress_bars_disabled("peft.foo")
        assert are_progress_bars_disabled("peft.foo.bar")

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_empty_group_is_root(self) -> None:
        """Test the behavior with invalid or empty group names."""
        # NOTE(review): the test name says the empty group "is root", yet the
        # assertion below expects "peft" to stay enabled after disabling "".
        # Confirm the intended semantics of an empty group name.
        disable_progress_bars("")
        assert not are_progress_bars_disabled("peft")

        enable_progress_bars("123.invalid.name")
        assert not are_progress_bars_disabled("123.invalid.name")

    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_multiple_level_toggling(self) -> None:
        """Toggling at several depths, including a deeply nested group."""
        disable_progress_bars("peft")
        enable_progress_bars("peft.foo")
        disable_progress_bars("peft.foo.bar.something")

        assert are_progress_bars_disabled("peft")
        assert not are_progress_bars_disabled("peft.foo")
        assert are_progress_bars_disabled("peft.foo.bar.something")

    # Patched like every other group test so the outcome does not depend on
    # HF_HUB_DISABLE_PROGRESS_BARS being set in the environment.
    @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None)
    def test_progress_bar_respects_group(self) -> None:
        """A `tqdm` bar created with `name=` follows the state of its group."""
        disable_progress_bars("foo.bar")
        for _ in tqdm(range(10), name="foo.bar.something"):
            pass

        captured = self.capsys.readouterr()
        assert captured.out == ""
        assert captured.err == ""

        # Re-enabling the exact subgroup makes the bar visible again.
        enable_progress_bars("foo.bar.something")
        for _ in tqdm(range(10), name="foo.bar.something"):
            pass

        captured = self.capsys.readouterr()
        assert captured.out == ""
        assert "10/10" in captured.err