Skip to content

Commit

Permalink
run precommit
Browse files Browse the repository at this point in the history
update stateful_dataloader

run precommit

local changes

update test to test the order of batches

update test

update tests

revert changes in SDL

revert changes in SDL

update tests

run precommit
  • Loading branch information
ramanishsingh committed Feb 4, 2025
1 parent 1c69775 commit a074b50
Showing 1 changed file with 51 additions and 16 deletions.
67 changes: 51 additions & 16 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ def test(self):
dataset=dataset,
num_workers=num_workers,
collate_fn=identity,
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)
it = iter(dl)
# Fetch at least one batch from each worker
Expand All @@ -1325,7 +1325,10 @@ def test(self):
if num_workers > 0:
for i in range(num_workers):
# Ensure worker state is stored only once if the dataset is also the iterator
self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
self.assertEqual(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
None,
)
self.assertTrue(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
"dataset_iter_state"
Expand Down Expand Up @@ -1440,48 +1443,80 @@ def test_fast_state_dict_request(self) -> None:
def test_fast_state_dict_request_skip_steps(self) -> None:
    """Run the shared ``_run_test`` scenario with the arguments ``(17, 19)``.

    NOTE(review): the meaning of the two integers is defined by
    ``_run_test``, which is not visible here -- confirm against that helper.
    """
    run_test_args = (17, 19)
    self._run_test(*run_test_args)

class TestMultiEpochSDL_shard0(TestCase):
    """Multi-epoch save/restore tests for ``StatefulDataLoader`` over ``DummyMapDataset``.

    Each test runs a loader for two epochs, snapshots it with ``state_dict()``,
    restores the snapshot into a fresh loader, runs two more epochs, and then
    compares the combined batch sequence against an uninterrupted four-epoch run.
    """

    def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False):
        """Build a ``StatefulDataLoader`` over a ``DummyMapDataset`` of ``data_size``.

        Args:
            data_size: Size passed to ``DummyMapDataset``.
            num_workers: Number of loader worker processes (0 = main process).
            batch_size: Batch size passed to the loader.
            shuffle: Forwarded to the loader only; the dataset itself is always
                constructed with ``shuffle=False``.

        Returns:
            The configured ``StatefulDataLoader``.
        """
        dataset = DummyMapDataset(data_size, shuffle=False)
        return StatefulDataLoader(
            dataset=dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
            # Only override the start method when workers are used on macOS.
            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
        )

    def _run(self, data_size, num_workers, batch_size, shuffle=False):
        """Check batch count and batch order across a mid-run state_dict round trip.

        Args:
            data_size: Size of the dataset behind every loader built here.
            num_workers: Number of loader worker processes.
            batch_size: Batch size for every loader built here.
            shuffle: Whether the loaders shuffle.
        """
        dl1 = self.get_map_dl(
            data_size=data_size,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
        )
        # Run through the dataloader for 2 epochs and count the number of items yielded
        num_items_yielded = 0
        dl1_items = []
        for _ in range(2):
            for batch in dl1:
                dl1_items.append(batch)
                num_items_yielded += 1
        # Save the state dict
        state_dict = dl1.state_dict()
        # Create a new StatefulDataLoader instance and load the state dict
        new_dl1 = self.get_map_dl(
            data_size=data_size,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
        )
        new_dl1.load_state_dict(state_dict)
        # Run through the new dataloader for another 2 epochs and count the number of items yielded
        additional_num_items_yielded = 0
        for _ in range(2):
            for batch in new_dl1:
                dl1_items.append(batch)
                additional_num_items_yielded += 1
        # Check that the total number of items yielded is correct.
        # NOTE: the counters advance once per batch, so this expectation
        # matches data_size * 4 only when batch_size == 1 (as in every test below).
        self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)

        # Now run a second dataloader for 4 epochs and check that the order is the same.
        dl2 = self.get_map_dl(
            data_size=data_size,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
        )
        dl2_items = []
        for _ in range(4):
            for batch in dl2:
                dl2_items.append(batch)

        self.assertEqual(dl1_items, dl2_items)

    def test_main_process(self):
        """Main-process loading, no shuffling."""
        self._run(100, 0, 1, False)

    def test_multiprocess(self):
        """Two worker processes, no shuffling."""
        self._run(100, 2, 1, False)

    def test_main_process_shuffle(self):
        """Main-process loading with shuffling."""
        self._run(100, 0, 1, True)

    def test_multiprocess_shuffle(self):
        """Two worker processes with shuffling."""
        self._run(100, 2, 1, True)


class TestMultiEpochState_shard0(TestCase):
def get_iterable_dl(self, pw, num_workers):
data_size = [25, 50, 100, 75]
Expand Down

0 comments on commit a074b50

Please sign in to comment.