import tempfile
import warnings

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.utils.data import IterableDataset, RandomSampler
from torch.utils.data.datasets import \
    (CollateIterableDataset, BatchIterableDataset, ListDirFilesIterableDataset,
     LoadFilesFromDiskIterableDataset, SamplerIterableDataset)
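
# This file tests the experimental IterableDataset building blocks in
# torch.utils.data.datasets: listing files in a directory, loading them from
# disk, and the Collate/Batch/Sampler functional wrappers.
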
def create_temp_dir_and_files():
    # The temp dir and the files within it will be released and deleted in tearDown().
    # Adding `noqa: P201` to silence the linter warning about not releasing the handles within this function.
    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
    temp_dir_path = temp_dir.name
    temp_file1 = tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False)  # noqa: P201
    temp_file2 = tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False)  # noqa: P201
    temp_file3 = tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False)  # noqa: P201
    return (temp_dir, temp_file1.name, temp_file2.name, temp_file3.name)
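
# Note: the temp files are created with `delete=False` so they outlive the
# NamedTemporaryFile objects created above; TemporaryDirectory.cleanup() in
# tearDown() removes the whole directory tree.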

class TestIterableDatasetBasic(TestCase):

    def setUp(self):
        ret = create_temp_dir_and_files()
        self.temp_dir = ret[0]
        self.temp_files = ret[1:]

    def tearDown(self):
        try:
            self.temp_dir.cleanup()
        except Exception as e:
            warnings.warn("TestIterableDatasetBasic was not able to cleanup temp dir due to {}".format(str(e)))

    def test_listdirfiles_iterable_dataset(self):
        temp_dir = self.temp_dir.name
        dataset = ListDirFilesIterableDataset(temp_dir, '')
        for pathname in dataset:
            self.assertTrue(pathname in self.temp_files)

    def test_loadfilesfromdisk_iterable_dataset(self):
        temp_dir = self.temp_dir.name
        dataset1 = ListDirFilesIterableDataset(temp_dir, '')
        dataset2 = LoadFilesFromDiskIterableDataset(dataset1)
        for rec in dataset2:
            self.assertTrue(rec[0] in self.temp_files)
            # Compare the loaded stream against the raw file contents; a context
            # manager closes the comparison handle promptly.
            with open(rec[0], 'rb') as f:
                self.assertTrue(rec[1].read() == f.read())
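
# Helper datasets for the functional tests below. The two variants differ only
# in whether they expose __len__, so the tests can verify how the functional
# wrappers behave with and without a known source length.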
class IterDatasetWithoutLen(IterableDataset):
    def __init__(self, ds):
        super().__init__()
        self.ds = ds

    def __iter__(self):
        for i in self.ds:
            yield i


class IterDatasetWithLen(IterableDataset):
    def __init__(self, ds):
        super().__init__()
        self.ds = ds
        self.length = len(ds)

    def __iter__(self):
        for i in self.ds:
            yield i

    def __len__(self):
        return self.length


class TestFunctionalIterableDataset(TestCase):

    def test_collate_dataset(self):
        arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        ds_len = IterDatasetWithLen(arrs)
        ds_nolen = IterDatasetWithoutLen(arrs)

        def _collate_fn(batch):
            return torch.tensor(sum(batch), dtype=torch.float)

        # A custom collate_fn is applied to each item, and __len__ is forwarded
        # from the source dataset.
        collate_ds = CollateIterableDataset(ds_len, collate_fn=_collate_fn)
        self.assertEqual(len(ds_len), len(collate_ds))
        ds_iter = iter(ds_len)
        for x in collate_ds:
            y = next(ds_iter)
            self.assertEqual(x, torch.tensor(sum(y), dtype=torch.float))

        # Without __len__ on the source, len() raises, but iteration with the
        # default collate_fn still works.
        collate_ds_nolen = CollateIterableDataset(ds_nolen)  # type: ignore
        with self.assertRaises(NotImplementedError):
            len(collate_ds_nolen)
        ds_nolen_iter = iter(ds_nolen)
        for x in collate_ds_nolen:
            y = next(ds_nolen_iter)
            self.assertEqual(x, torch.tensor(y))

    def test_batch_dataset(self):
        arrs = range(10)
        ds = IterDatasetWithLen(arrs)
        # A batch size of zero is rejected.
        with self.assertRaises(AssertionError):
            batch_ds0 = BatchIterableDataset(ds, batch_size=0)

        # By default, the last (partial) batch is kept.
        batch_ds1 = BatchIterableDataset(ds, batch_size=3)
        self.assertEqual(len(batch_ds1), 4)
        batch_iter = iter(batch_ds1)
        value = 0
        for i in range(len(batch_ds1)):
            batch = next(batch_iter)
            if i == 3:
                # The final batch holds the single leftover element.
                self.assertEqual(len(batch), 1)
                self.assertEqual(batch, [9])
            else:
                self.assertEqual(len(batch), 3)
                for x in batch:
                    self.assertEqual(x, value)
                    value += 1

        # With drop_last=True, the partial batch is discarded.
        batch_ds2 = BatchIterableDataset(ds, batch_size=3, drop_last=True)
        self.assertEqual(len(batch_ds2), 3)
        value = 0
        for batch in batch_ds2:
            self.assertEqual(len(batch), 3)
            for x in batch:
                self.assertEqual(x, value)
                value += 1

        # Ten elements divide evenly into batches of two, with or without drop_last.
        batch_ds3 = BatchIterableDataset(ds, batch_size=2)
        self.assertEqual(len(batch_ds3), 5)
        batch_ds4 = BatchIterableDataset(ds, batch_size=2, drop_last=True)
        self.assertEqual(len(batch_ds4), 5)

        # Without __len__ on the source, the batched length is unknown.
        ds_nolen = IterDatasetWithoutLen(arrs)
        batch_ds_nolen = BatchIterableDataset(ds_nolen, batch_size=5)
        with self.assertRaises(NotImplementedError):
            len(batch_ds_nolen)

    def test_sampler_dataset(self):
        arrs = range(10)
        ds = IterDatasetWithLen(arrs)
        # The default SequentialSampler yields elements in order.
        sampled_ds = SamplerIterableDataset(ds)  # type: ignore
        self.assertEqual(len(sampled_ds), 10)
        i = 0
        for x in sampled_ds:
            self.assertEqual(x, i)
            i += 1

        # A RandomSampler (with replacement) can be plugged in instead.
        random_sampled_ds = SamplerIterableDataset(ds, sampler=RandomSampler, replacement=True)  # type: ignore

        # Building a SamplerIterableDataset requires `__len__` on the source dataset.
        ds_nolen = IterDatasetWithoutLen(arrs)
        with self.assertRaises(AssertionError):
            sampled_ds = SamplerIterableDataset(ds_nolen)


if __name__ == '__main__':
    run_tests()
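
# A minimal usage sketch (not executed by the tests), assuming the same
# experimental torch.utils.data.datasets API exercised above; the directory
# path is a hypothetical placeholder:
#
#     dir_ds = ListDirFilesIterableDataset('/path/to/data', '')  # yields pathnames
#     file_ds = LoadFilesFromDiskIterableDataset(dir_ds)         # yields (pathname, stream) records
#     batched = BatchIterableDataset(IterDatasetWithLen(range(10)), batch_size=4)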