Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Fix import problem with pytorch v1.9.x
Browse files Browse the repository at this point in the history
  • Loading branch information
qqaatw committed Jul 10, 2021
1 parent 5e7fad9 commit 8935127
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions torchnlp/_third_party/weighted_random_sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch

from torch.utils.data.sampler import Sampler
from torch._six import int_classes as _int_classes


class WeightedRandomSampler(Sampler):
Expand All @@ -10,7 +9,7 @@ def __init__(self, weights, num_samples=None, replacement=True):
# NOTE: Adapted `WeightedRandomSampler` to accept `num_samples=0` and `num_samples=None`.
if num_samples is None:
num_samples = len(weights)
if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
num_samples < 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
Expand Down

0 comments on commit 8935127

Please sign in to comment.