import torch


class Sampler(object):
    """Base class for all Samplers.

    Every Sampler subclass has to provide an __iter__ method, providing a
    way to iterate over indices of dataset elements, and a __len__ method
    that returns the length of the returned iterators.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
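
# A minimal sketch of the Sampler protocol above; EvenIndexSampler is a
# hypothetical example, not part of this module: it yields only the even
# indices of its dataset.
#
# class EvenIndexSampler(Sampler):
#     def __init__(self, data_source):
#         self.data_source = data_source
#
#     def __iter__(self):
#         return iter(range(0, len(self.data_source), 2))
#
#     def __len__(self):
#         return (len(self.data_source) + 1) // 2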

class SequentialSampler(Sampler):
    """Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

class RandomSampler(Sampler):
    """Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())

    def __len__(self):
        return len(self.data_source)

class SubsetRandomSampler(Sampler):
    """Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (list): a list of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)
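
# Usage sketch with hypothetical values: restrict sampling to the first five
# indices of a dataset; the sampler yields them in a random order.
#
#   >>> sampler = SubsetRandomSampler(indices=[0, 1, 2, 3, 4])
#   >>> sorted(sampler)
#   [0, 1, 2, 3, 4]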

class WeightedRandomSampler(Sampler):
    """Samples elements from [0, .., len(weights)-1] with given probabilities (weights).

    Arguments:
        weights (list): a list of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when
            a sample index is drawn for a row, it cannot be drawn again for
            that row.
    """

    def __init__(self, weights, num_samples, replacement=True):
        self.weights = torch.DoubleTensor(weights)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))

    def __len__(self):
        return self.num_samples
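
# Usage sketch with hypothetical weights: index 2 carries ten times the
# weight of the others, so it dominates the draws when sampling with
# replacement.
#
#   >>> sampler = WeightedRandomSampler(weights=[1, 1, 10, 1],
#   ...                                 num_samples=4, replacement=True)
#   >>> len(list(sampler))
#   4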

class BatchSampler(object):
    """Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        k (int or None): if not ``None``, enables a growing-batch schedule:
            batches start at ``batch_size`` indices and grow by 4 after each
            yielded batch until the size would exceed ``k``; from then on
            every batch is capped at ``k`` indices.
        drop_last (bool): If ``True``, the sampler will drop the last batch
            if its size would be less than ``batch_size``.

    Example:
        >>> list(BatchSampler(range(10), batch_size=3, k=None, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(range(10), batch_size=3, k=None, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, k, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.k = k
    def __iter__(self):
        batch = []
        if self.k is not None:
            # Growing-batch schedule: start at batch_size and increase the
            # batch size by 4 after every yielded batch until the running
            # size exceeds k; from then on, batches are capped at k indices.
            k_sum = self.batch_size
            b_size = self.batch_size
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == b_size and k_sum <= self.k:
                    yield batch
                    batch = []
                    k_sum = k_sum + 4
                    b_size = k_sum
                elif len(batch) == int(self.k):
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch
        else:
            # Standard behaviour: fixed-size batches of batch_size indices.
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch
    # Alternative (unused) schedule: grow the batch size multiplicatively,
    # by 10% of batch_size per yielded batch, capped at k * batch_size.
    # def __iter__k(self):
    #     batch = []
    #     k_sum = 1.0
    #     b_size = int(k_sum * self.batch_size)
    #     for idx in self.sampler:
    #         batch.append(idx)
    #         if len(batch) == b_size and k_sum <= self.k:
    #             yield batch
    #             batch = []
    #             k_sum = k_sum + 0.1
    #             b_size = int(k_sum * self.batch_size)
    #         elif len(batch) == int(self.k * self.batch_size):
    #             yield batch
    #             batch = []
    #     if len(batch) > 0 and not self.drop_last:
    #         yield batch
    def __len__(self):
        # Note: this count assumes fixed-size batches (k is None); with the
        # growing-batch schedule the actual number of batches is smaller.
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
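

if __name__ == "__main__":
    # Quick sketch of the growing-batch schedule with hypothetical sizes:
    # with batch_size=2 and k=8 over 20 indices, batches grow by 4 each step
    # (2, then 6), are capped at k (8), and the remainder (4) comes last.
    batches = list(BatchSampler(SequentialSampler(range(20)),
                                batch_size=2, k=8, drop_last=False))
    print([len(b) for b in batches])  # expected: [2, 6, 8, 4]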