diff --git a/README.md b/README.md
index a473b840010..b86893e1ebd 100644
--- a/README.md
+++ b/README.md
@@ -96,7 +96,7 @@ If you're using `DistributedDataParallel`, make the following changes:
 + dist.init_process_group("xla", init_method='xla://')
 +
 + model.to(xm.xla_device())
-+ ddp_model = DDP(model)
++ ddp_model = DDP(model, gradient_as_bucket_view=True)
 
 - model = model.to(rank)
 - ddp_model = DDP(model, device_ids=[rank])
diff --git a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb
index b96bb346643..dcc426bc080 100644
--- a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb
+++ b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb
@@ -464,7 +464,7 @@
     "    # Optional for TPU v4 and GPU\n",
     "    xm.broadcast_master_param(model)\n",
     "\n",
-    "    model = DDP(model)\n",
+    "    model = DDP(model, gradient_as_bucket_view=True)\n",
     "\n",
     "    loss_fn = nn.MSELoss()\n",
     "    optimizer = optim.SGD(model.parameters(), lr=.001)\n",
diff --git a/docs/source/learn/pjrt.md b/docs/source/learn/pjrt.md
index 97f358f29e5..6fc84bf9de3 100644
--- a/docs/source/learn/pjrt.md
+++ b/docs/source/learn/pjrt.md
@@ -82,7 +82,7 @@ def _mp_fn(index):
 + # Optional for TPU v4 and GPU
 + xm.broadcast_master_param(model)
 
-  model = DDP(model)
+  model = DDP(model, gradient_as_bucket_view=True)
 
   loss_fn = nn.MSELoss()
   optimizer = optim.SGD(model.parameters(), lr=.001)
diff --git a/docs/source/perf/ddp.md b/docs/source/perf/ddp.md
index f3d32650f4e..dabe61cdc13 100644
--- a/docs/source/perf/ddp.md
+++ b/docs/source/perf/ddp.md
@@ -43,7 +43,7 @@ device](../API_GUIDE.md#running-on-a-single-xla-device).
 4. Wrap the model with DDP.
 
    ``` python
-   ddp_model = DDP(model)
+   ddp_model = DDP(model, gradient_as_bucket_view=True)
    ```
 
 5. Finally launch your model with xla specific launcher.
@@ -107,7 +107,7 @@ def demo_basic(rank):
   # create model and move it to XLA device
   device = xm.xla_device()
   model = ToyModel().to(device)
-  ddp_model = DDP(model)
+  ddp_model = DDP(model, gradient_as_bucket_view=True)
 
   loss_fn = nn.MSELoss()
   optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
diff --git a/examples/data_parallel/train_resnet_ddp.py b/examples/data_parallel/train_resnet_ddp.py
index 327b3f8cbbc..5d295e12e34 100644
--- a/examples/data_parallel/train_resnet_ddp.py
+++ b/examples/data_parallel/train_resnet_ddp.py
@@ -18,7 +18,7 @@ def __init__(self):
     dist.init_process_group('xla', init_method='xla://')
     super().__init__()
     self.model = DDP(
-        self.model, broadcast_buffers=False)
+        self.model, broadcast_buffers=False, gradient_as_bucket_view=True)
     self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
 
 
diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py
index cc761c875e7..bec580c3831 100644
--- a/test/test_train_mp_imagenet.py
+++ b/test/test_train_mp_imagenet.py
@@ -258,7 +258,9 @@ def train_imagenet():
     xm.broadcast_master_param(model)
 
   if FLAGS.ddp:
-    model = DDP(model, broadcast_buffers=False)
+    # gradient_as_bucket_view=True saves memory and can be used so long as
+    # we are not calling `detach_()` on the gradients.
+    model = DDP(model, broadcast_buffers=False, gradient_as_bucket_view=True)
 
   writer = None
   if xm.is_master_ordinal():
diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py
index 315f4b200f6..9e470719f27 100644
--- a/test/test_train_mp_mnist.py
+++ b/test/test_train_mp_mnist.py
@@ -138,7 +138,7 @@ def train_mnist(flags, **kwargs):
     xm.broadcast_master_param(model)
 
   if flags.ddp:
-    model = DDP(model)
+    model = DDP(model, gradient_as_bucket_view=True)
   writer = None
   if xm.is_master_ordinal():
     writer = test_utils.get_summary_writer(flags.logdir)
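For reference, here is a minimal sketch of the end-to-end pattern these patches converge on: wrap the model with DDP over the `xla` process group and pass `gradient_as_bucket_view=True`. The toy model, tensor shapes, and training loop are illustrative assumptions and are not part of the patch; only the torch / torch_xla calls mirror the patched files.

``` python
# Illustrative sketch -- model and shapes are made up; the torch / torch_xla
# calls follow the pattern used in the patched files.
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.nn.parallel import DistributedDataParallel as DDP


def _mp_fn(index):
    # The xla:// init method picks up rank and world size from the XLA runtime.
    dist.init_process_group('xla', init_method='xla://')

    device = xm.xla_device()
    model = nn.Linear(128, 10).to(device)
    # Optional for TPU v4 and GPU: start every replica from the same weights.
    xm.broadcast_master_param(model)

    # Gradients alias DDP's communication buckets, which saves memory as long
    # as the training loop never calls `detach_()` on them.
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for _ in range(10):
        optimizer.zero_grad()
        output = ddp_model(torch.randn(16, 128).to(device))
        loss = loss_fn(output, torch.randn(16, 10).to(device))
        loss.backward()
        optimizer.step()
        xm.mark_step()


if __name__ == '__main__':
    xmp.spawn(_mp_fn)
```

Since DDP already all-reduces the gradients, the step is a plain `optimizer.step()` followed by `xm.mark_step()`, rather than `xm.optimizer_step()`.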