POC: Allow for a data_seed (#3150)
muellerzr authored Oct 9, 2024
1 parent 21c994c commit fd9880d
Showing 4 changed files with 48 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/accelerate/accelerator.py
@@ -2098,6 +2098,7 @@ def prepare_data_loader(
             even_batches=self.even_batches,
             slice_fn_for_dispatch=slice_fn_for_dispatch,
             use_seedable_sampler=self.use_seedable_sampler,
+            data_seed=self.dataloader_config.data_seed,
             non_blocking=self.non_blocking,
             use_stateful_dataloader=self.use_stateful_dataloader,
         )
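A minimal end-to-end usage sketch of the new option, mirroring the test added in test_script.py below (the TensorDataset and batch size are illustrative stand-ins, not part of the commit):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from accelerate import Accelerator
    from accelerate.utils import DataLoaderConfiguration

    # Any map-style dataset works; TensorDataset is just for illustration.
    dataset = TensorDataset(torch.arange(10, dtype=torch.float32))
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    # data_seed only takes effect together with use_seedable_sampler=True.
    config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=42)
    accelerator = Accelerator(dataloader_config=config)
    prepared = accelerator.prepare(dataloader)

    for batch in prepared:
        pass  # shuffling order is now driven by a generator seeded with 42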
9 changes: 8 additions & 1 deletion src/accelerate/data_loader.py
@@ -77,9 +77,11 @@ class SeedableRandomSampler(RandomSampler):
     """
 
     def __init__(self, *args, **kwargs):
+        data_seed = kwargs.pop("data_seed", None)
         super().__init__(*args, **kwargs)
+
+        self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()
         self.epoch = 0
-        self.initial_seed = torch.random.initial_seed()
 
     def __iter__(self):
         if self.generator is None:
@@ -937,6 +939,7 @@ def prepare_data_loader(
     even_batches: bool = True,
     slice_fn_for_dispatch: Optional[Callable] = None,
     use_seedable_sampler: bool = False,
+    data_seed: Optional[int] = None,
     non_blocking: bool = False,
     use_stateful_dataloader: bool = False,
 ) -> DataLoader:
@@ -996,6 +999,9 @@ def prepare_data_loader(
             reproducibility. Comes at a cost of potentially different performance due to different shuffling
             algorithms, but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
             `self.set_epoch`
+        data_seed (`int`, *optional*, defaults to `None`):
+            The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
+            will use the current default seed from torch.
         non_blocking (`bool`, *optional*, defaults to `False`):
             If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
             `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
@@ -1069,6 +1075,7 @@ def prepare_data_loader(
             replacement=sampler.replacement,
             num_samples=sampler._num_samples,
             generator=getattr(sampler, "generator", torch.Generator()),
+            data_seed=data_seed,
         )
 
     if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
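Taken in isolation, the constructor change above implements a simple precedence rule: an explicit data_seed wins, otherwise the sampler reuses torch's default seed. A standalone sketch of just that logic (illustrative only; the real SeedableRandomSampler also reseeds its generator per epoch):

    import torch
    from torch.utils.data import RandomSampler

    class SeedableSamplerSketch(RandomSampler):
        """Illustrative only: mirrors the __init__ logic added in this commit."""

        def __init__(self, *args, **kwargs):
            # Pop data_seed before delegating, since RandomSampler does not accept it.
            data_seed = kwargs.pop("data_seed", None)
            super().__init__(*args, **kwargs)
            # Explicit seed wins; otherwise fall back to torch's default-generator seed.
            self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()
            self.epoch = 0

    sampler = SeedableSamplerSketch(range(10), data_seed=42)
    assert sampler.initial_seed == 42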
29 changes: 29 additions & 0 deletions src/accelerate/test_utils/scripts/test_script.py
@@ -400,6 +400,34 @@ def check_seedable_sampler_in_batch_sampler_shard():
     ), "Sampler in BatchSamplerShard is not SeedableRandomSampler."
 
 
+def check_seedable_sampler_with_data_seed():
+    # Set seed
+    set_seed(42)
+    data_seed = 42
+    train_set = RegressionDataset(length=10, seed=42)
+    train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
+
+    config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=data_seed)
+    accelerator = Accelerator(dataloader_config=config)
+    prepared_dl = accelerator.prepare(train_dl)
+    original_items = []
+    for _ in range(3):
+        for batch in prepared_dl:
+            original_items.append(batch["x"])
+    original_items = torch.cat(original_items)
+
+    # Set new data seed
+    config.data_seed = 43
+    accelerator = Accelerator(dataloader_config=config)
+    prepared_dl = accelerator.prepare(train_dl)
+    new_items = []
+    for _ in range(3):
+        for batch in prepared_dl:
+            new_items.append(batch["x"])
+    new_items = torch.cat(new_items)
+    assert not torch.allclose(original_items, new_items), "Obtained the same items with different data seed."
+
+
 def mock_training(length, batch_size, generator, use_seedable_sampler=False):
     set_seed(42)
     generator.manual_seed(42)
@@ -800,6 +828,7 @@ def main():
     central_dl_preparation_check()
     custom_sampler_check()
     check_seedable_sampler()
+    check_seedable_sampler_with_data_seed()
 
     if state.num_processes > 1:
         check_seedable_sampler_in_batch_sampler_shard()
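The final assertion in the new test holds because two generators seeded with different values produce different permutations, which is what the seedable sampler reduces to each epoch. In plain torch (seed values arbitrary):

    import torch

    g1, g2 = torch.Generator(), torch.Generator()
    g1.manual_seed(42)
    g2.manual_seed(43)

    # Different seeds give different shuffles (with overwhelming probability).
    assert not torch.equal(torch.randperm(10, generator=g1), torch.randperm(10, generator=g2))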
10 changes: 10 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -737,6 +737,9 @@ class DataLoaderConfiguration:
             training results are fully reproducible using a different sampling technique. While seed-to-seed results
             may differ, on average the differences are negligible when using multiple different seeds to compare. Should
             also be run with [`~utils.set_seed`] for the best results.
+        data_seed (`int`, defaults to `None`):
+            The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
+            will use the current default seed from torch.
         non_blocking (`bool`, defaults to `False`):
             If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device
             transfers, allowing for better overlap between dataloader communication and computation. Recommended that
@@ -781,6 +784,13 @@ class DataLoaderConfiguration:
             "multiple different seeds to compare. Should also be run with [`~utils.set_seed`] for the best results."
         },
     )
+    data_seed: int = field(
+        default=None,
+        metadata={
+            "help": "The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator"
+            " will use the current default seed from torch."
+        },
+    )
     non_blocking: bool = field(
         default=False,
         metadata={
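Because the field defaults to None, existing configurations keep their current behavior: with no explicit data_seed, the sampler falls back to torch's default-generator seed. A quick sanity check of that fallback (a sketch, using the field added in this commit):

    import torch
    from accelerate.utils import DataLoaderConfiguration

    config = DataLoaderConfiguration(use_seedable_sampler=True)
    assert config.data_seed is None  # nothing passed, so the torch default applies

    # torch.random.initial_seed() reports the seed the default generator was last given.
    torch.manual_seed(1234)
    assert torch.random.initial_seed() == 1234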
