Set gradient_as_bucket_view=True in test_train_mp_imagenet (#8558)
tengyifei committed Jan 14, 2025
1 parent 31b2b3b commit ea399cd
Showing 7 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -96,7 +96,7 @@ If you're using `DistributedDataParallel`, make the following changes:
+ dist.init_process_group("xla", init_method='xla://')
+
+ model.to(xm.xla_device())
+ ddp_model = DDP(model)
+ ddp_model = DDP(model, gradient_as_bucket_view=True)

- model = model.to(rank)
- ddp_model = DDP(model, device_ids=[rank])
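
Putting the pieces of this README hunk together with the `xm.broadcast_master_param` call shown in the other hunks, a worker function for XLA DDP looks roughly like the sketch below; the `nn.Linear` toy model, the imports, and the SGD hyperparameters are illustrative rather than part of this diff.

``` python
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' process-group backend
from torch.nn.parallel import DistributedDataParallel as DDP


def _mp_fn(index):
    # The xla:// init method discovers rank and world size from the XLA runtime.
    dist.init_process_group("xla", init_method="xla://")

    model = nn.Linear(128, 10).to(xm.xla_device())

    # Optional for TPU v4 and GPU (see the pjrt.md hunk below).
    xm.broadcast_master_param(model)

    # Gradients become views into DDP's all-reduce buckets, saving memory.
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
```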
@@ -464,7 +464,7 @@
" # Optional for TPU v4 and GPU\n",
" xm.broadcast_master_param(model)\n",
"\n",
" model = DDP(model)\n",
" model = DDP(model, gradient_as_bucket_view=True)\n",
"\n",
" loss_fn = nn.MSELoss()\n",
" optimizer = optim.SGD(model.parameters(), lr=.001)\n",
2 changes: 1 addition & 1 deletion docs/source/learn/pjrt.md
@@ -82,7 +82,7 @@ def _mp_fn(index):

+ # Optional for TPU v4 and GPU
+ xm.broadcast_master_param(model)
model = DDP(model)
model = DDP(model, gradient_as_bucket_view=True)

loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=.001)
4 changes: 2 additions & 2 deletions docs/source/perf/ddp.md
@@ -43,7 +43,7 @@ device](../API_GUIDE.md#running-on-a-single-xla-device).
4. Wrap the model with DDP.

``` python
ddp_model = DDP(model)
ddp_model = DDP(model, gradient_as_bucket_view=True)
```

5. Finally, launch your model with the XLA-specific launcher, as sketched below.
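
A minimal sketch of that launch step, assuming the training function is the `demo_basic(rank)` shown in the next hunk; `xmp.spawn` starts one process per local XLA device and passes each process its index as the first argument.

``` python
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    # Runs demo_basic(rank) once per XLA device on this host.
    xmp.spawn(demo_basic, args=())
```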
@@ -107,7 +107,7 @@ def demo_basic(rank):
# create model and move it to XLA device
device = xm.xla_device()
model = ToyModel().to(device)
ddp_model = DDP(model)
ddp_model = DDP(model, gradient_as_bucket_view=True)

loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
2 changes: 1 addition & 1 deletion examples/data_parallel/train_resnet_ddp.py
@@ -18,7 +18,7 @@ def __init__(self):
dist.init_process_group('xla', init_method='xla://')
super().__init__()
self.model = DDP(
self.model, broadcast_buffers=False)
self.model, broadcast_buffers=False, gradient_as_bucket_view=True)
self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)


4 changes: 3 additions & 1 deletion test/test_train_mp_imagenet.py
@@ -258,7 +258,9 @@ def train_imagenet():
xm.broadcast_master_param(model)

if FLAGS.ddp:
model = DDP(model, broadcast_buffers=False)
# gradient_as_bucket_view=True saves memory and can be used so long as
# we are not calling `detach_()` on the gradients.
model = DDP(model, broadcast_buffers=False, gradient_as_bucket_view=True)

writer = None
if xm.is_master_ordinal():
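
As the new comment explains, with `gradient_as_bucket_view=True` each `param.grad` becomes a view into one of DDP's all-reduce buckets, which avoids keeping a separate copy of the gradients; the constraint is that `detach_()` must not be called on those gradients. Below is a minimal sketch of a training step that respects this; the `train_one_epoch` helper, names, and hyperparameters are illustrative, not the exact code of this script.

``` python
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def train_one_epoch(model, loader, loss_fn):
    # Gradients become views into DDP's all-reduce buckets; do not detach them.
    ddp_model = DDP(model, broadcast_buffers=False, gradient_as_bucket_view=True)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

    for data, target in loader:
        optimizer.zero_grad()    # zero through the optimizer rather than detaching grads
        loss = loss_fn(ddp_model(data), target)
        loss.backward()          # gradients land directly in the bucket views
        optimizer.step()
        # Avoid param.grad.detach_() here: it would sever the view into the bucket.
```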
2 changes: 1 addition & 1 deletion test/test_train_mp_mnist.py
@@ -138,7 +138,7 @@ def train_mnist(flags, **kwargs):
xm.broadcast_master_param(model)

if flags.ddp:
model = DDP(model)
model = DDP(model, gradient_as_bucket_view=True)
writer = None
if xm.is_master_ordinal():
writer = test_utils.get_summary_writer(flags.logdir)
