Skip to content

Commit

Permalink
run precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Feb 11, 2025
1 parent 1ac45db commit 5167a94
Showing 1 changed file with 8 additions and 19 deletions.
27 changes: 8 additions & 19 deletions torchdata/stateful_dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def __next__(self):
remainder = self.num_samples % self.n
if self.chunk_index < num_full_perms:
if self.perm is None or not self.perm:
self.perm = torch.randperm(
self.n, generator=self.sampler.generator
).tolist()
self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()
self.perm_index = 0
value = self.perm[self.perm_index]
self.perm_index += 1
Expand All @@ -90,9 +88,7 @@ def __next__(self):
return value
elif remainder > 0:
if self.perm is None or not self.perm:
self.perm = torch.randperm(
self.n, generator=self.sampler.generator
).tolist()[:remainder]
self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()[:remainder]
self.perm_index = 0
value = self.perm[self.perm_index]
self.perm_index += 1
Expand Down Expand Up @@ -137,13 +133,9 @@ def __init__(
generator.manual_seed(1)
self.generator = generator
if not isinstance(self.replacement, bool):
raise TypeError(
f"replacement should be a boolean value, but got replacement={self.replacement}"
)
raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError(
f"num_samples should be a positive integer value, but got num_samples={self.num_samples}"
)
raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

@property
def num_samples(self) -> int:
Expand Down Expand Up @@ -202,18 +194,15 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
assert isinstance(self.sampler_iter, Stateful)
self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE])

if not (
isinstance(self.sampler, Stateful)
or isinstance(self.sampler_iter, Stateful)
) and not isinstance(self.sampler, _InfiniteConstantSampler):
if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
self.sampler, _InfiniteConstantSampler
):
# We skip x samples if underlying sampler is not stateful
for _ in range(self.samples_yielded):
next(self.sampler_iter)

def update_state_dict(self) -> None:
if isinstance(self.sampler_iter, Stateful) and hasattr(
self.sampler_iter, "update_state_dict"
):
if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
self.sampler_iter.update_state_dict()


Expand Down

0 comments on commit 5167a94

Please sign in to comment.