From 31b2b3b7b646c01fd2438d02ddd3810ecaa47d39 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Tue, 14 Jan 2025 10:05:23 -0800
Subject: [PATCH] Cherrypick #8521 into r2.6 (#8563)

Co-authored-by: Chengji Yao
---
 docs/source/perf/ddp.md                 |  7 +++----
 test/distributed_util.py                |  9 ++++++---
 test/torch_distributed/test_ddp.py      | 13 +++++++++++--
 torch_xla/csrc/aten_xla_bridge.cpp      | 10 ++++++++++
 torch_xla/csrc/aten_xla_bridge.h        |  3 +++
 torch_xla/csrc/init_python_bindings.cpp |  3 +++
 6 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/docs/source/perf/ddp.md b/docs/source/perf/ddp.md
index 3319a47c2f2..f3d32650f4e 100644
--- a/docs/source/perf/ddp.md
+++ b/docs/source/perf/ddp.md
@@ -244,7 +244,6 @@ repo](https://github.com/pytorch/xla/).
 For those who are interested in the native xla data parallel approach,
 here is the
 [tutorial](../API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing).
-Here are some of the known issues that are under investigation: \*
-`gradient_as_bucket_view=False` needs to be enforced. \* There are some
-issues while being used with `torch.utils.data.DataLoader`.
-`test_train_mp_mnist.py` with real data crashes before exiting.
+Here are some of the known issues that are under investigation: \* There are some
+issues while being used with `torch.utils.data.DataLoader`. `test_train_mp_mnist.py`
+with real data crashes before exiting.
diff --git a/test/distributed_util.py b/test/distributed_util.py
index 4d428e202da..fcecd4b8045 100644
--- a/test/distributed_util.py
+++ b/test/distributed_util.py
@@ -89,7 +89,8 @@ def train_step(model, inputs, labels, optimizer, loss_fn):
 
 def ddp_correctness(init_method: str = 'env://',
                     use_large_net: bool = False,
-                    debug: bool = False):
+                    debug: bool = False,
+                    gradient_as_bucket_view: bool = False):
   if init_method == 'env://':
     rank = xr.global_ordinal()
     world_size = xr.world_size()
@@ -111,11 +112,13 @@ def ddp_correctness(init_method: str = 'env://',
   steps = 5  # To save test time.
 
   cpu_model = LargeNet()
-  # TODO: There're issues in the captured graph when gradient_as_bucket_view is True
   # bucket_cap_mb is set to 1 mb such that we can still have multiple all_reduces while avoiding
   # using models that are too larger (25 mb).
   # To be noted, DDP currently uses one bucket for the first iteration. See pytorch#73732.
-  ddp_model = DDP(copy.deepcopy(cpu_model).to(device), bucket_cap_mb=1)
+  ddp_model = DDP(
+      copy.deepcopy(cpu_model).to(device),
+      gradient_as_bucket_view=gradient_as_bucket_view,
+      bucket_cap_mb=1)
   # ddp_model.register_comm_hook(state=None, hook=comp_hook)
 
   cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-1)
diff --git a/test/torch_distributed/test_ddp.py b/test/torch_distributed/test_ddp.py
index 55261f4d295..61a8ef8a593 100644
--- a/test/torch_distributed/test_ddp.py
+++ b/test/torch_distributed/test_ddp.py
@@ -17,7 +17,10 @@ class TestXrtDistributedDataParallel(parameterized.TestCase):
 
   @staticmethod
-  def _ddp_correctness(rank, use_large_net: bool, debug: bool):
+  def _ddp_correctness(rank,
+                       use_large_net: bool,
+                       debug: bool,
+                       gradient_as_bucket_view: bool = False):
     # We cannot run this guard before XMP,
     # see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing.
     device = xm.xla_device()
@@ -27,11 +30,17 @@ def _ddp_correctness(rank, use_large_net: bool, debug: bool):
           file=sys.stderr)
       return
     util.ddp_correctness(
-        init_method="xla://", use_large_net=use_large_net, debug=debug)
+        init_method="xla://",
+        use_large_net=use_large_net,
+        debug=debug,
+        gradient_as_bucket_view=gradient_as_bucket_view)
 
   def test_ddp_correctness(self):
     torch_xla.launch(self._ddp_correctness, args=(False, FLAGS.debug))
 
+  def test_ddp_correctness_with_gradient_as_bucket_view(self):
+    torch_xla.launch(self._ddp_correctness, args=(False, FLAGS.debug, True))
+
   def test_ddp_correctness_large_net(self):
     torch_xla.launch(self._ddp_correctness, args=(True, FLAGS.debug))
diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp
index 6344aa5d1e5..1b084ecd82a 100644
--- a/torch_xla/csrc/aten_xla_bridge.cpp
+++ b/torch_xla/csrc/aten_xla_bridge.cpp
@@ -121,6 +121,16 @@ void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) {
   impl->set_tensor(std::move(new_xla_tensor));
 }
 
+void ReplaceXlaTensor(const std::vector<at::Tensor>& tensors,
+                      const std::vector<XLATensorPtr> new_xla_tensors) {
+  XLA_CHECK(tensors.size() == new_xla_tensors.size())
+      << "The size of tensors and new_xla_tensors are not equal: "
+      << tensors.size() << " vs. " << new_xla_tensors.size();
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    ReplaceXlaTensor(tensors[i], new_xla_tensors[i]);
+  }
+}
+
 std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors) {
   std::vector<XLATensorPtr> xla_tensors;
   xla_tensors.reserve(tensors.size());
diff --git a/torch_xla/csrc/aten_xla_bridge.h b/torch_xla/csrc/aten_xla_bridge.h
index b25a4823c3a..a862e3a72e2 100644
--- a/torch_xla/csrc/aten_xla_bridge.h
+++ b/torch_xla/csrc/aten_xla_bridge.h
@@ -29,6 +29,9 @@ XLATensorPtr GetXlaTensor(const at::Tensor& tensor);
 // version.
 void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor);
 
+void ReplaceXlaTensor(const std::vector<at::Tensor>& tensor,
+                      const std::vector<XLATensorPtr> new_xla_tensor);
+
 // Same as above, applied to a list of tensors.
 std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors);
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index b6748b858de..04dcbf526ed 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -304,6 +304,9 @@ void AllReduceInPlace(const std::string& reduce_type,
       GetXlaTensors(tensors, /*want_all=*/true);
   tensor_methods::all_reduce(xtensors, GetReduceType(reduce_type), scale,
                              replica_groups, pin_layout);
+  std::vector<XLATensorPtr> new_xtensors =
+      GetXlaTensors(tensors, /*want_all=*/true);
+  bridge::ReplaceXlaTensor(tensors, new_xtensors);
 }
 
 at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input,
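
For reference, below is a minimal usage sketch, not part of the patch itself, of how the newly supported gradient_as_bucket_view flag would be exercised from user code. It is modeled on test/distributed_util.py and test/torch_distributed/test_ddp.py above; the toy model, batch shapes, step count, and hyperparameters are illustrative assumptions, while the xla:// process-group setup follows the tests.

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" process-group backend
from torch.nn.parallel import DistributedDataParallel as DDP


def _mp_fn(index):
  # The xla:// init method derives rank and world size from the XLA runtime.
  dist.init_process_group("xla", init_method="xla://")
  device = xm.xla_device()

  # Illustrative toy model; with gradient_as_bucket_view=True, DDP exposes
  # parameter gradients as views into its communication buckets instead of
  # keeping separate copies.
  model = nn.Linear(16, 4).to(device)
  ddp_model = DDP(model, gradient_as_bucket_view=True, bucket_cap_mb=1)

  optimizer = optim.SGD(ddp_model.parameters(), lr=1e-1)
  loss_fn = nn.MSELoss()
  for _ in range(5):
    optimizer.zero_grad()
    inputs = torch.randn(8, 16, device=device)
    labels = torch.randn(8, 4, device=device)
    loss = loss_fn(ddp_model(inputs), labels)
    loss.backward()  # DDP all-reduces the bucketed gradients during backward
    optimizer.step()
    xm.mark_step()  # materialize the training step on the XLA device


if __name__ == '__main__':
  torch_xla.launch(_mp_fn)

The C++ hunks above appear to be what makes this mode work on XLA: after the in-place all-reduce, the freshly reduced XLA tensors are written back onto the original at::Tensor objects through the new vector overload of bridge::ReplaceXlaTensor, so DDP's gradient bucket views observe the reduced values.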