Fix end of epoch StatefulDataLoader restart #1439
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1439
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 093e5f2 with merge base fe6b405.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This does not solve the problem, as it just restarts the dataloader and produces the same batches again.
If it's any help, while creating the issue I noticed that after loading the state dict, the resulting state dict in the dataloader is different from the one that was loaded: for example, "samples_yielded" is 0 when the loaded one had 100 (see the prints in #1437), and possibly more differences that I haven't checked. Looking at the code in this commit, it seems that samples_yielded is being set manually; maybe that is the root of the problem?
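For what it's worth, one quick way to surface such mismatches is to diff the loaded state dict against the one the dataloader reports right after loading. This is only an illustrative sketch (the dataset, sizes, and the diff_keys helper are hypothetical, not from the issue or this PR):

from torchdata.stateful_dataloader import StatefulDataLoader

def diff_keys(a, b, prefix=""):
    # Recursively print keys whose values differ between two nested dicts.
    for k in a:
        if isinstance(a[k], dict) and isinstance(b.get(k), dict):
            diff_keys(a[k], b[k], prefix + k + ".")
        elif repr(a.get(k)) != repr(b.get(k)):
            print(f"{prefix + k}: loaded={a.get(k)!r} reported={b.get(k)!r}")

dl = StatefulDataLoader(list(range(200)), batch_size=1, shuffle=True)
it = iter(dl)
for _ in range(100):
    next(it)
loaded = dl.state_dict()  # snapshot after 100 samples

dl2 = StatefulDataLoader(list(range(200)), batch_size=1, shuffle=True)
dl2.load_state_dict(loaded)
# If loading restores everything, this prints nothing; the comment above
# reports samples_yielded being reset from 100 to 0 at this point.
diff_keys(loaded, dl2.state_dict())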
Force-pushed from 8136e63 to a074b50:
update stateful_dataloader; run precommit; local changes; update test to test the order of batches; update test; update tests; revert changes in SDL; revert changes in SDL; update tests; run precommit
if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler):
    for _ in self.sampler_iter:
        pass
Hi, sorry to nitpick, but wouldn't we actually not want to skip an epoch when reloading a state dict from an unfinished epoch? I.e., IMO it makes sense to recognise the difference between a state_dict obtained via this process:
for b in dl:
    sd_in = dl.state_dict()  # when the for ends, sd_in will describe an "about-to-finish" state
vs a state_dict obtained via this process:
for b in dl:
    pass
sd_out = dl.state_dict()  # sd_out will describe a "just finished" state
I think it makes sense to have an empty epoch immediately after loading sd_in, but a full one immediately after loading sd_out.
In particular, is it possible that issue #1437 is solved just by the new line self.samples_yielded = state_dict[self._SAMPLES_YIELDED]?
Hi @gailweiss
I think the code changes take care of that. In the state_dict, '_iterator_finished' is False or True depending on whether the iterator is about-to-finish or actually finished. When we load sd_in, the next epoch is still empty because state_dict['_iterator_finished'] is False.
Is this the behavior you expect:
from torchdata.stateful_dataloader import StatefulDataLoader


def get_dl():
    d = list(range(10))
    return StatefulDataLoader(d, batch_size=1, shuffle=True)


dl = get_dl()
for i, b in enumerate(dl):
    if i == 0:
        print(i, b)
    sd_in = dl.state_dict()  # when the for ends, sd_in will describe an "about-to-finish" state
print("sd_in", sd_in)

dl = get_dl()
dl.load_state_dict(sd_in)  # load the "about-to-finish" state
batches_after_sdin_load = []
for i, b in enumerate(dl):
    batches_after_sdin_load.append(b)
    if i == 0:
        print(i, b)
print("batches_after_sdin_load", batches_after_sdin_load)

dl = get_dl()
for i, b in enumerate(dl):
    if i == 0:
        print(i, b)
sd_out = dl.state_dict()  # when the for ends, sd_out will describe a "just-finished" state
print("sd_out", sd_out)

dl = get_dl()
dl.load_state_dict(sd_out)  # load the "just-finished" state
batches_after_sdout_load = []
for i, b in enumerate(dl):
    batches_after_sdout_load.append(b)
    if i == 0:
        print(i, b)
print("batches_after_sdout_load", batches_after_sdout_load)
Output:
0 tensor([5])
sd_in {'_index_sampler_state': None, '_sampler_iter_state': {'samples_yielded': 10, 'sampler_iter_state': {'generator': tensor([1, 0, 0, ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 10}}, '_sampler_iter_yielded': 10, '_num_yielded': 10, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': False}
batches_after_sdin_load []
0 tensor([5])
sd_out {'_index_sampler_state': None, '_sampler_iter_state': {'samples_yielded': 10, 'sampler_iter_state': {'generator': tensor([1, 0, 0, ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 10}}, '_sampler_iter_yielded': 10, '_num_yielded': 10, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
0 tensor([2])
batches_after_sdout_load [tensor([2]), tensor([8]), tensor([1]), tensor([5]), tensor([6]), tensor([9]), tensor([3]), tensor([7]), tensor([0]), tensor([4])]
Hi, yes, that's great! Thank you and sorry for the false complaint!
@@ -1441,6 +1444,154 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
        self._run_test(17, 19)


class TestMultiEpochSDL_shard0(TestCase):
    def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False):
Can you test this for num_workers=0 and, say, num_workers=2?
Removed the default values for num_workers. The tests below now run for both 0 and 2.
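For readers following along, the shape of such a parametrized test might look roughly like this (an illustrative sketch only, not the PR's actual test code; the class name, helper name, and use of subTest are assumptions):

import unittest

from torchdata.stateful_dataloader import StatefulDataLoader


class MultiEpochResumeSketch(unittest.TestCase):
    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
        return StatefulDataLoader(
            list(range(data_size)),
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
        )

    def test_resume_after_finished_epoch(self):
        for num_workers in (0, 2):
            with self.subTest(num_workers=num_workers):
                dl = self.get_map_dl(100, num_workers, 1, True)
                list(dl)                           # run epoch 1 to completion
                sd = dl.state_dict()               # "just finished" snapshot
                expected = [b.item() for b in dl]  # epoch 2 without any reload

                dl = self.get_map_dl(100, num_workers, 1, True)
                dl.load_state_dict(sd)             # resume from the end of epoch 1
                resumed = [b.item() for b in dl]
                # The epoch after reload should match epoch 2, not repeat epoch 1.
                self.assertEqual(expected, resumed)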
next(self.sampler_iter)

# Skip one epoch if we were at the end of the last epoch
if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler):
Do all samplers have __len__? This feels brittle to me.
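To make the concern concrete: nothing requires a Sampler to define __len__, e.g. an infinite sampler, in which case the guard above is simply skipped. This is an illustrative example, not code from this PR:

import random

from torch.utils.data import Sampler


class InfiniteRandomSampler(Sampler):
    # No __len__ is defined, so the hasattr(self.sampler, "__len__") check
    # above never fires for this sampler and the end-of-epoch handling is
    # silently skipped.
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        while True:
            yield random.randrange(len(self.data_source))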
for _ in self.sampler_iter:
    pass
If state_dict is saved at the end of epoch 1 / start of epoch 2, why do we need to do this? What if the sampler is not stateful: would it happen twice due to line 114? Something seems a bit off to me.
@andrewkho
Force-pushed from 4de1bb4 to 6d49b4f:
update state dict if the iterator has finished; add comment about why we're updating state dict; run precommit
TL;DR: After refactoring BatchSampler, the same batch sequence is repeated in the epoch following a reload because _iterator_finished is True. The fix updates the generator in the state_dict after each iteration to cache the latest state, ensuring the RNG resumes correctly even if next_yielded is reset to 0.
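For intuition, the idea of caching the generator state on every step can be sketched as follows (a simplified illustration of the approach described above, not the actual code in this PR; the class and attribute names are invented):

import torch


class _SamplerIterSketch:
    """Sketch of a sampler-iterator wrapper that snapshots the RNG state after
    every yield, so a state_dict taken at (or right after) the end of an epoch
    still lets the next epoch's shuffle order be reproduced on reload."""

    def __init__(self, sampler, generator: torch.Generator):
        self.sampler = sampler
        self.generator = generator
        self.samples_yielded = 0
        self._iter = iter(sampler)
        self._cached_generator_state = generator.get_state()

    def __iter__(self):
        return self

    def __next__(self):
        idx = next(self._iter)
        self.samples_yielded += 1
        # Snapshot the generator *after* each step: even if samples_yielded is
        # later reset to 0 for a fresh epoch, the cached RNG state lets the
        # reloaded dataloader continue with new batches instead of repeating.
        self._cached_generator_state = self.generator.get_state()
        return idx

    def state_dict(self):
        return {
            "samples_yielded": self.samples_yielded,
            "generator": self._cached_generator_state,
        }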
Add tests to reproduce and fix #1437