From 637246baf96f07b19b193ed101f34b65cb35cffb Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Wed, 26 Jun 2024 00:19:17 +0800 Subject: [PATCH] Fix incorrect rank value in data splitting (#6994) * Fix incorrect rank value in data splitting (#6990) * Add tests for splitting distributed datasets * make style --------- Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- src/datasets/iterable_dataset.py | 2 +- tests/test_distributed.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index d60cbb56b9b..3d0b3ce1cf3 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -3013,8 +3013,8 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s [`IterableDataset`]: The iterable dataset to be used on the node at rank `rank`. """ if dataset._distributed: - world_size = world_size * dataset._distributed.world_size rank = world_size * dataset._distributed.rank + rank + world_size = world_size * dataset._distributed.world_size distributed = DistributedConfig(rank=rank, world_size=world_size) return IterableDataset( ex_iterable=dataset._ex_iterable, diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 4cd228f2506..b8e0f56b180 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -55,6 +55,26 @@ def gen(shards): assert len({tuple(x.values()) for ds in datasets_per_rank for x in ds}) == full_size +def test_split_dataset_by_node_iterable_distributed(): + def gen(): + return ({"i": i} for i in range(100)) + + world_size = 3 + num_workers = 3 + full_ds = IterableDataset.from_generator(gen) + full_size = len(list(full_ds)) + datasets_per_rank = [ + split_dataset_by_node(full_ds, rank=rank, world_size=world_size) for rank in range(world_size) + ] + datasets_per_rank_per_worker = [ + split_dataset_by_node(ds, rank=worker, world_size=num_workers) + for ds in datasets_per_rank + for worker in range(num_workers) + ] + assert sum(len(list(ds)) for ds in datasets_per_rank_per_worker) == full_size + assert len({tuple(x.values()) for ds in datasets_per_rank_per_worker for x in ds}) == full_size + + def test_distributed_shuffle_iterable(): def gen(): return ({"i": i} for i in range(17))