From 5167a94d02b226e196a2c149667298d44b654108 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Mon, 10 Feb 2025 22:30:39 -0800 Subject: [PATCH] run precommit --- torchdata/stateful_dataloader/sampler.py | 27 +++++++----------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index dd2ddb6aa..0c4164976 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -77,9 +77,7 @@ def __next__(self): remainder = self.num_samples % self.n if self.chunk_index < num_full_perms: if self.perm is None or not self.perm: - self.perm = torch.randperm( - self.n, generator=self.sampler.generator - ).tolist() + self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist() self.perm_index = 0 value = self.perm[self.perm_index] self.perm_index += 1 @@ -90,9 +88,7 @@ def __next__(self): return value elif remainder > 0: if self.perm is None or not self.perm: - self.perm = torch.randperm( - self.n, generator=self.sampler.generator - ).tolist()[:remainder] + self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()[:remainder] self.perm_index = 0 value = self.perm[self.perm_index] self.perm_index += 1 @@ -137,13 +133,9 @@ def __init__( generator.manual_seed(1) self.generator = generator if not isinstance(self.replacement, bool): - raise TypeError( - f"replacement should be a boolean value, but got replacement={self.replacement}" - ) + raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") if not isinstance(self.num_samples, int) or self.num_samples <= 0: - raise ValueError( - f"num_samples should be a positive integer value, but got num_samples={self.num_samples}" - ) + raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") @property def num_samples(self) -> int: @@ -202,18 +194,15 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: assert isinstance(self.sampler_iter, Stateful) self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) - if not ( - isinstance(self.sampler, Stateful) - or isinstance(self.sampler_iter, Stateful) - ) and not isinstance(self.sampler, _InfiniteConstantSampler): + if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( + self.sampler, _InfiniteConstantSampler + ): # We skip x samples if underlying sampler is not stateful for _ in range(self.samples_yielded): next(self.sampler_iter) def update_state_dict(self) -> None: - if isinstance(self.sampler_iter, Stateful) and hasattr( - self.sampler_iter, "update_state_dict" - ): + if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"): self.sampler_iter.update_state_dict()