diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py
index c71c7d42311..b5ee438c3b3 100644
--- a/ignite/distributed/comp_models/native.py
+++ b/ignite/distributed/comp_models/native.py
@@ -418,6 +418,9 @@ def spawn(
         }
 
     def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[Any] = None) -> torch.Tensor:
+        if group == dist.GroupMember.NON_GROUP_MEMBER:
+            return tensor
+
         if op not in self._reduce_op_map:
             raise ValueError(f"Unsupported reduction operation: '{op}'")
         if group is not None and not isinstance(group, dist.ProcessGroup):
diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py
index 1f3ad55dd84..d94f002d0a3 100644
--- a/tests/ignite/distributed/utils/__init__.py
+++ b/tests/ignite/distributed/utils/__init__.py
@@ -120,26 +120,21 @@ def _test_distrib_all_reduce(device):
 
 def _test_distrib_all_reduce_group(device):
     if idist.get_world_size() > 1 and idist.backend() is not None:
-        ranks = [0, 1]
+        ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
         rank = idist.get_rank()
-        t = torch.tensor([rank], device=device)
         bnd = idist.backend()
 
-        group = idist.new_group(ranks)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+        for group in [idist.new_group(ranks), ranks]:
+            t = torch.tensor([rank], device=device)
+            if bnd in ("horovod"):
+                with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+                    res = idist.all_reduce(t, group=group)
+            else:
                 res = idist.all_reduce(t, group=group)
-        else:
-            res = idist.all_reduce(t, group=group)
-            assert res == torch.tensor([sum(ranks)], device=device)
-
-        t = torch.tensor([rank], device=device)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
-                res = idist.all_reduce(t, group=ranks)
-        else:
-            res = idist.all_reduce(t, group=ranks)
-            assert res == torch.tensor([sum(ranks)], device=device)
+                if rank in ranks:
+                    assert res == torch.tensor([sum(ranks)], device=device)
+                else:
+                    assert res == t
 
     ranks = "abc"
 
@@ -218,33 +213,23 @@ def _test_distrib_all_gather(device):
 
 
 def _test_distrib_all_gather_group(device):
-    if idist.get_world_size() > 1:
+    if idist.get_world_size() > 1 and idist.backend() is not None:
         ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
+        sorted_ranks = sorted(ranks)
         rank = idist.get_rank()
         bnd = idist.backend()
 
         t = torch.tensor([rank], device=device)
-        group = idist.new_group(ranks)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group=group)
-        else:
-            res = idist.all_gather(t, group=group)
-            if rank in ranks:
-                assert torch.equal(res, torch.tensor(ranks, device=device))
-            else:
-                assert res == t
-
-        t = torch.tensor([rank], device=device)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group=ranks)
-        else:
-            res = idist.all_gather(t, group=ranks)
-            if rank in ranks:
-                assert torch.equal(res, torch.tensor(ranks, device=device))
+        for group in [idist.new_group(ranks), ranks]:
+            if bnd in ("horovod"):
+                with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                    res = idist.all_gather(t, group=group)
             else:
-                assert res == t
+                res = idist.all_gather(t, group=group)
+                if rank in ranks:
+                    assert (res == torch.tensor(sorted_ranks, device=device)).all(), (res, ranks)
+                else:
+                    assert res == t
 
     t = {
         "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
@@ -261,7 +246,7 @@ def _test_distrib_all_gather_group(device):
        res = idist.all_gather(t, group=ranks)
         if rank in ranks:
             assert isinstance(res, list) and len(res) == len(ranks)
-            for i, obj in zip(ranks, res):
+            for i, obj in zip(sorted_ranks, res):
                 assert isinstance(obj, dict)
                 assert list(obj.keys()) == ["a", "b", "c"], obj
                 expected_device = (
@@ -295,20 +280,44 @@ def _test_idist_all_gather_tensors_with_shapes(device):
     torch.manual_seed(41)
     rank = idist.get_rank()
     ws = idist.get_world_size()
-    reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+    reference = torch.randn(ws + 6, ws + 6, ws + 6, device=device)
+
+    ref_indices_per_rank = {
+        # rank: (start_index, end_index, size)
+        r: (r + 1, 2 * r + 2, r + 1)
+        for r in range(ws)
+    }
+    start_index, end_index, _ = ref_indices_per_rank[rank]
     rank_tensor = reference[
-        rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
-        rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
-        rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+        start_index : end_index + 1,
+        start_index : end_index + 2,
+        start_index : end_index + 3,
     ]
-    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)])
+    tensors = all_gather_tensors_with_shapes(
+        rank_tensor,
+        [
+            [
+                ref_indices_per_rank[r][2] + 1,
+                ref_indices_per_rank[r][2] + 2,
+                ref_indices_per_rank[r][2] + 3,
+            ]
+            for r in range(ws)
+        ],
+    )
     for r in range(ws):
-        r_tensor = reference[
-            r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-            r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-            r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+        start_index, end_index, _ = ref_indices_per_rank[r]
+        ref_tensor = reference[
+            start_index : end_index + 1,
+            start_index : end_index + 2,
+            start_index : end_index + 3,
         ]
-        assert (r_tensor == tensors[r]).all()
+        assert torch.allclose(ref_tensor, tensors[r]), (
+            r,
+            ref_tensor.shape,
+            ref_tensor.mean(),
+            tensors[r].shape,
+            tensors[r].mean(),
+        )
 
 
 def _test_idist_all_gather_tensors_with_shapes_group(device):
@@ -320,27 +329,29 @@ def _test_idist_all_gather_tensors_with_shapes_group(device):
     ws = idist.get_world_size()
     bnd = idist.backend()
     if rank in ranks:
-        reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+        reference = torch.randn(
+            ws * (ws + 1) // 2 + 1, ws * (ws + 3) // 2 + 1, ws * (ws + 5) // 2 + 1, device=device
+        )
         rank_tensor = reference[
-            rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
-            rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
-            rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+            rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 2,
+            rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 3,
+            rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 4,
         ]
     else:
         rank_tensor = torch.tensor([rank], device=device)
     if bnd in ("horovod"):
         with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 2, r + 3, r + 4] for r in ranks], ranks)
     else:
-        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 2, r + 3, r + 4] for r in ranks], ranks)
         if rank in ranks:
             for r in ranks:
                 r_tensor = reference[
-                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 2,
+                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 3,
+                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 4,
                 ]
-                assert (r_tensor == tensors[r - 1]).all()
+                assert r_tensor.allclose(tensors[r - 1])
         else:
             assert [rank_tensor] == tensors
 
diff --git a/tests/ignite/distributed/utils/test_native.py b/tests/ignite/distributed/utils/test_native.py
index ee828ef5d9f..965bc5ccb31 100644
--- a/tests/ignite/distributed/utils/test_native.py
+++ b/tests/ignite/distributed/utils/test_native.py
@@ -244,6 +244,8 @@ def test_idist_all_gather_nccl(distributed_context_single_node_nccl):
     device = idist.device()
     _test_distrib_all_gather(device)
     _test_distrib_all_gather_group(device)
+    _test_idist_all_gather_tensors_with_shapes(device)
+    _test_idist_all_gather_tensors_with_shapes_group(device)
 
 
 @pytest.mark.distributed
@@ -253,21 +255,6 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
     device = idist.device()
     _test_distrib_all_gather(device)
     _test_distrib_all_gather_group(device)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_idist_all_gather_tensors_with_shapes_nccl(distributed_context_single_node_nccl):
-    device = idist.device()
-    _test_idist_all_gather_tensors_with_shapes(device)
-    _test_idist_all_gather_tensors_with_shapes_group(device)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
-def test_idist_all_gather_tensors_with_shapes_gloo(distributed_context_single_node_gloo):
-    device = idist.device()
+    _test_idist_all_gather_tensors_with_shapes(device)
     _test_idist_all_gather_tensors_with_shapes_group(device)
diff --git a/tests/run_cpu_tests.sh b/tests/run_cpu_tests.sh
index f52988a6818..65e8617082b 100644
--- a/tests/run_cpu_tests.sh
+++ b/tests/run_cpu_tests.sh
@@ -22,7 +22,7 @@ fi
 # Run 2 processes with --dist=each
 CUDA_VISIBLE_DEVICES="" run_tests \
     --core_args "-m distributed -vvv tests/ignite" \
-    --world_size 2 \
+    --world_size 4 \
     --cache_dir ".cpu-distrib" \
     --skip_distrib_tests 0 \
    --use_coverage 1 \
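
Note on the NON_GROUP_MEMBER guard (illustration, not part of the patch): torch.distributed.new_group(ranks) returns dist.GroupMember.NON_GROUP_MEMBER on every process whose rank is not in ranks, which is why _do_all_reduce can simply hand the input tensor back on those processes instead of joining the collective. Below is a minimal standalone sketch of that behaviour; the gloo backend, world size of 4, subgroup ranks [1, 2, 3], and the master address/port are illustrative assumptions, not values taken from this PR.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # Assumed rendezvous settings for a single-node, CPU-only run.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29512"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    ranks = [1, 2, 3]
    # new_group() is collective: every process calls it, but processes whose
    # rank is not in `ranks` receive the NON_GROUP_MEMBER sentinel.
    group = dist.new_group(ranks)

    t = torch.tensor([rank])
    if group == dist.GroupMember.NON_GROUP_MEMBER:
        # Mirrors the early return added to _do_all_reduce: non-members keep
        # their tensor untouched instead of participating in the collective.
        pass
    else:
        dist.all_reduce(t, group=group)  # in-place SUM over ranks 1..3

    expected = [sum(ranks)] if rank in ranks else [rank]
    assert t.tolist() == expected, (rank, t.tolist())
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(4,), nprocs=4)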