Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix end of epoch StatefulDataLoader restart #1439

Open
wants to merge 15 commits into
base: main
Choose a base branch
from

Conversation

ramanishsingh
Copy link
Contributor

@ramanishsingh ramanishsingh commented Feb 3, 2025

Add tests to reproduce and fix #1437

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 3, 2025
@ramanishsingh ramanishsingh marked this pull request as draft February 3, 2025 23:36
Copy link

pytorch-bot bot commented Feb 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1439

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 093e5f2 with merge base fe6b405 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ramanishsingh
Copy link
Contributor Author

This does not solve the problem as it just restarts the dataloader and produces the same batches again.

@gailweiss
Copy link

if its any help, while creating the issue i noticed that after loading the state dict, the resulting state dict in the dataloader is different from the one that was loaded - for example, by setting "samples_yielded" to 0 when the loaded one had 100 (see the prints in #1437 ), (and possibly more differences - I haven't checked). looking at the code in this commit, it seems that samples_yielded is being set manually - maybe that is the root of the problem?

update stateful_dataloader

run precommit

local changes

update test to test the order of batches

update test

update tests

revert changes in SDL

revert changes in SDL

update tests

run precommit
@ramanishsingh ramanishsingh force-pushed the fix_EndOfEpoch_sdl_restart branch from 8136e63 to a074b50 Compare February 4, 2025 22:48
@ramanishsingh ramanishsingh marked this pull request as ready for review February 5, 2025 06:37
if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler):
for _ in self.sampler_iter:
pass

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi, sorry to nitpick, but wouldn't we actually not want to skip an epoch when reloading a state dict from an unfinished epoch? i.e. IMO it makes sense to recognise the difference between a state_dict obtained via this process:

for b in dl:
    sd_in = dl.state_dict()  # when the for ends, sd_in will describe an "about-to-finish" state

vs a state_dict obtained via this process:

for b in dl:
    pass
sd_out = dl.state_dict()  # sd_out will describe a "just finished" state

I think it makes sense to have an empty epoch immediately after loading sd_in, but a full one immediately after loading sd_out.

In particular, is it possible that issue #1437 is solved just by the new line self.samples_yielded = state_dict[self._SAMPLES_YIELDED]?

Copy link
Contributor Author

@ramanishsingh ramanishsingh Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @gailweiss
I think the code changes take care of that. In the state_dict ['_iterator_finished']: False or True depending on if it is just- finishing or actually-finished.
When we load sd_in, the next epoch is still empty because state_dict ['_iterator_finished']: False

Is this the behavior you expect :

from torchdata.stateful_dataloader import StatefulDataLoader


def get_dl():
    d = list(range(10))
    return StatefulDataLoader(d, batch_size=1, shuffle=True)


dl = get_dl()
for i, b in enumerate(dl):
    if i == 0:
        print(i, b)
    sd_in = dl.state_dict()
print("sd_in", sd_in)


dl = get_dl()
dl.load_state_dict(sd_in)  # load the "about-to-finish" state
batches_after_sdin_load = []
for i, b in enumerate(dl):
    batches_after_sdin_load.append(b)
    if i == 0:
        print(i, b)

print("batches_after_sdin_load", batches_after_sdin_load)
dl = get_dl()
for i, b in enumerate(dl):
    if i == 0:
        print(i, b)

sd_out = (
    dl.state_dict()
)  # when the for ends, sd_out will describe a "just-finished" state
print("sd_out", sd_out)
dl = get_dl()
dl.load_state_dict(sd_out)  # load the "about-to-finish" state
batches_after_sdout_load = []
for i, b in enumerate(dl):
    batches_after_sdout_load.append(b)
    if i == 0:
        print(i, b)

print("batches_after_sdout_load", batches_after_sdout_load)

Output:

0 tensor([5])
sd_in {'_index_sampler_state': None, '_sampler_iter_state': {'samples_yielded': 10, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 10}}, '_sampler_iter_yielded': 10, '_num_yielded': 10, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': False}
batches_after_sdin_load []
0 tensor([5])
sd_out {'_index_sampler_state': None, '_sampler_iter_state': {'samples_yielded': 10, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 10}}, '_sampler_iter_yielded': 10, '_num_yielded': 10, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
0 tensor([2])
batches_after_sdout_load [tensor([2]), tensor([8]), tensor([1]), tensor([5]), tensor([6]), tensor([9]), tensor([3]), tensor([7]), tensor([0]), tensor([4])]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, yes, that's great! Thank you and sorry for the false complaint!

@ramanishsingh ramanishsingh marked this pull request as draft February 5, 2025 14:02
@@ -1441,6 +1444,154 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
self._run_test(17, 19)


class TestMultiEpochSDL_shard0(TestCase):
def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you test this for num_workers=0 and say num_workers=2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed defaults values for num_workers.
Tests below run for both 0 and 2.

next(self.sampler_iter)

# Skip one epoch if we were at the end of the last epoch
if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do all samplers have __len__? This feels brittle to me

Comment on lines 119 to 120
for _ in self.sampler_iter:
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If state_dict is saved at the end of epoch 1/start of epoch 2, why do we need to do this? What if the sampler is not stateful, would it happen twice due to lines 114? Something seems a bit off to me

@ramanishsingh ramanishsingh marked this pull request as ready for review February 5, 2025 22:56
@ramanishsingh
Copy link
Contributor Author

ramanishsingh commented Feb 5, 2025

@andrewkho
Thanks.
I took your implementation of BatchSamplerIterator from here.
I find that during the loading of the state dict, if the _StatefulRandomSamplerIterator is at its end, its self.next_yielded value is becoming None due to iter re-init from somewhere.
To tackle that, I am artificially making it 0 by checking if we are at the end of an epoch and exhausting the iterator (Line 534 stateful_dataloader.py) .
I think it is less brittle than checking the length of the sampler and skipping one whole epoch. Please lmk your thoughts.

update state dict if the iterator has finished

add comment about why were updating state dict

run precommit
@ramanishsingh ramanishsingh force-pushed the fix_EndOfEpoch_sdl_restart branch from 4de1bb4 to 6d49b4f Compare February 7, 2025 07:05
@ramanishsingh
Copy link
Contributor Author

TLDR: After refactoring BatchSampler, the same batch sequence is repeated in the epoch following a reload due to _iterator_finished being True. Update the generator in the state_dict after each iteration to cache the latest state, ensuring RNG resumes correctly even if next_yielded is reset to 0.

Problem:
After breaking the BatchSampler into BatchSampler and _BatchSamplerIterator, we encountered an issue where the same sequence of batches is produced in the epoch immediately following a reload, mirroring the last epoch before saving the state_dict.

Root Cause:
This issue arises because the dl state_dict is saved after the epoch completes, resulting in _iterator_finished being set to True. To illustrate, consider the epoch after reloading as epoch 3. In the state_dict of the RandomSampler (a subset of the dl state_dict), key items include self.next_yielded and the state of the generator. When a StatefulDataLoader (SDL) is instantiated with num_workers = 0 and batches are retrieved, the iter method in SDL is invoked. This method utilizes next_iter_state (or the loaded_state_dict) to obtain an iterator. During this process, the generator, sampler_iter, etc., are reloaded. However, since _iterator_finished is True, the _StatefulSingleProcessDataLoaderIter that was generated is discarded, and a new one is created with state_dict=None. Consequently, we lose the RandomSampler state information because next_yielded is reset to 0, and the generator state remains at the start of epoch 2.

Proposed Solution:
While there may be more efficient solutions, one potential approach (that I have implemented) is to update the generator in the state_dict upon completing an iteration. By doing so, we cache the latest generator state, allowing us to resume RNG production from the correct point even when the RandomSampler is reset with next_yielded = 0.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
4 participants