Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Simplify balanced_sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
PetrochukM committed Nov 27, 2019
1 parent e143aa0 commit 46406d3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 4 additions & 2 deletions torchnlp/_third_party/weighted_random_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

class WeightedRandomSampler(Sampler):

def __init__(self, weights, num_samples, replacement=True):
# NOTE: Adapted `WeightedRandomSampler` to accept `num_samples=0`.
def __init__(self, weights, num_samples=None, replacement=True):
# NOTE: Adapted `WeightedRandomSampler` to accept `num_samples=0` and `num_samples=None`.
if num_samples is None:
num_samples = len(weights)
if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
num_samples < 0:
raise ValueError("num_samples should be a positive integer "
Expand Down
2 changes: 0 additions & 2 deletions torchnlp/samplers/balanced_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,4 @@ def __init__(self, data_source, get_class=identity, get_weight=lambda x: 1, **kw
k: sum([w for c, w in zip(classified, weighted) if k == c]) for k in set(classified)
}
weights = [w / class_totals[c] if w > 0 else 0.0 for c, w in zip(classified, weighted)]
if 'num_samples' not in kwargs:
kwargs['num_samples'] = len(data_source)
super().__init__(weights=weights, **kwargs)

0 comments on commit 46406d3

Please sign in to comment.