Decouple and run MLP runs for comparisons
rpsilva-aws committed Jan 13, 2025
1 parent e75078b commit e9e35ab
Showing 7 changed files with 214 additions and 154 deletions.
3 changes: 1 addition & 2 deletions test/neuron/run_tests.sh
@@ -213,8 +213,7 @@ function run_xla_op_tests3 {
   #run_test "$CDIR/spmd/test_dtensor_integration2.py"
   run_test "$CDIR/spmd/test_xla_auto_sharding.py"
   #run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
-  run_test "$CDIR/spmd/test_train_spmd_linear_model.py"
-  run_test "$CDIR/spmd/test_train_spmd_linear_model.py" "$@" --use_gradient_checkpointing
+  run_test "$CDIR/spmd/test_train_spmd_linear_model_grad_checkpointing.py"
   run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
   run_test "$CDIR/spmd/test_xla_auto_sharding.py"
   run_test "$CDIR/spmd/test_fsdp_v2.py"
Empty file added test/spmd/__init__.py
Empty file.
163 changes: 13 additions & 150 deletions test/spmd/test_train_spmd_linear_model.py
@@ -1,160 +1,23 @@
 import sys
-from typing import Optional
+import unittest
 
-import numpy as np
 import torch
-from torch import nn
-import torch.optim as optim
 
-import args_parse
-import torch_xla
-import torch_xla.core.xla_model as xm
-import torch_xla.debug.profiler as xp
-import torch_xla.distributed.parallel_loader as pl
-import torch_xla.distributed.spmd as xs
-import torch_xla.runtime as xr
-import torch_xla.utils.utils as xu
-from torch_xla.distributed.spmd import Mesh
-from torch_xla.utils.checkpoint import checkpoint
+import test_xla_sharding_base
+from ..utils.train_spmd_linear_model import train_and_evaluate
 
-MODEL_OPTS = {
-    '--sharding': {
-        'choices': ['batch', 'megatron-lm', 'fsdp'],
-        'nargs': '+',
-        'default': [],
-    },
-    '--input_dim': {
-        'type': int,
-        'default': 16834,
-    },
-    '--train_dataset_len': {
-        'type': int,
-        'default': 1024 * 8,
-    },
-    '--use_gradient_checkpointing': {
-        'action': 'store_true',
-    }
-}
 
-FLAGS = {}
-PROFILER_SERVER = None
+class TestSPMDLinearModel(test_xla_sharding_base.XlaShardingTest):
 
 
-class SimpleLinear(nn.Module):
-  NUM_CLASSES = 3
-
-  def __init__(self):
-    super().__init__()
-    self.layers = torch.nn.Sequential(
-        nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2),
-        nn.ReLU(),
-        nn.Linear(FLAGS.input_dim // 2, 3),
-        # # Add an additional 3x3 layer at the end to ensure the final layer
-        # # is not sharded.
-        nn.Linear(3, self.NUM_CLASSES),
-    )
-
-  def forward(self, x):
-    if FLAGS.use_gradient_checkpointing:
-      for n_l, layer in enumerate(self.layers):
-        # Apply gradient checkpointing for reduced memory footprint.
-        # This would result in increased computation cost.
-        if n_l > 0:
-          x = checkpoint(layer, x)
-        else:
-          x = layer(x)
-    else:
-      x = self.layers(x)
-    return x
-
-
-def train():
-  device = xm.xla_device()
-  torch.manual_seed(42)
-  model = SimpleLinear().to(device)
-  print('===> Preparing data..')
-  train_loader = xu.SampleGenerator(
-      data=(torch.randn(FLAGS.batch_size, FLAGS.input_dim),
-            torch.randint(
-                0, model.NUM_CLASSES, (FLAGS.batch_size,), dtype=torch.int64)),
-      sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)
-
-  num_devices = xr.global_runtime_device_count()
-  print(f'num_devices: {num_devices}')
-  # Define a mesh with all devices along one axis
-  mesh_shape = (num_devices, 1)
-  device_ids = np.arange(num_devices)
-  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
-
-  if 'batch' in FLAGS.sharding:
-    train_loader = pl.MpDeviceLoader(
-        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
-
-  if 'fsdp' in FLAGS.sharding:
-    train_loader = pl.MpDeviceLoader(
-        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
-    print('Sharding model weights')
-    # Shard the weights according to their 0th dim
-    xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
-    xs.mark_sharding(model.layers[2].weight, mesh, (0, 1))
-
-  if 'megatron-lm' in FLAGS.sharding:
-    print('Sharding model weights')
-    # Shard the first layer's weights row-wise
-    xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
-    # Shard the second layer's weights column-wise
-    xs.mark_sharding(model.layers[2].weight, mesh, (1, 0))
-
-  optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr)
-
-  loss_fn = nn.CrossEntropyLoss()
-
-  def train_loop_fn(loader, epoch):
-    model.train()
-    for step, (data, target) in enumerate(loader):
-      with xp.StepTrace('train_linear_model'):
-        with xp.Trace('build_graph'):
-          x = data.to(device)
-          y = target.to(device)
-          optimizer.zero_grad()
-          output = model(x)
-          loss = loss_fn(output, y)
-          losses.append(loss.clone().detach())
-          loss.backward()
-          optimizer.step()
-        xm.mark_step()
-        if step % FLAGS.log_steps == 0:
-          print(f"Epoch {epoch} step {step} loss {loss}")
-
-  losses = []
-  for epoch in range(FLAGS.num_epochs):
-    train_loop_fn(train_loader, epoch)
-  return losses, model
-
-
-def train_and_evaluate():
-  default_config = {
-      'batch_size': 128,
-      'num_epochs': 1,
-      'lr': 0.1,
-      'log_steps': 8,
-      'opts': MODEL_OPTS.items()
-  }
-
-  global PROFILER_SERVER, FLAGS
-  FLAGS = args_parse.parse_common_options(**default_config)
-  if FLAGS.profile:
-    PROFILER_SERVER = xp.start_server(FLAGS.profiler_port)
-  xr.use_spmd(auto=FLAGS.auto_spmd)
-  print('Start training loop...')
-  losses, m = train()
-  t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device())
-  return [loss.cpu().item() for loss in losses], m(t).cpu()
+  def test_basic(self):
+    print('Training loop with baseline')
+    losses, result = train_and_evaluate()
+    # Verify that the model losses are not zero.
+    assert all(loss != 0 for loss in losses)
+    # Verify that the model produces non-zero outputs.
+    assert not torch.any(result == 0)
 
 
 if __name__ == '__main__':
-  losses, result = train_and_evaluate()
-  # Verify that the model losses are not zero.
-  assert all(loss != 0 for loss in losses)
-  # Verify that the model produces non-zero outputs.
-  assert torch.all(result != 0)
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
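
The mesh construction and xs.mark_sharding calls in the train() function above (now moved to test/utils/train_spmd_linear_model.py) follow the usual torch_xla SPMD idiom. Below is a minimal standalone sketch of that idiom, assuming torch_xla is installed and an XLA device is available; the weight shape is illustrative and not taken from the test.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # enable the SPMD partitioner for this process

# Build a 2-D device mesh with all devices along the 'x' axis, as in train().
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('x', 'y'))

# Shard a weight row-wise: dim 0 is split along mesh axis 0 ('x'), while
# dim 1 maps to mesh axis 1 ('y'), which has size 1 (effectively replicated).
weight = torch.randn(1024, 1024).to(xm.xla_device())
xs.mark_sharding(weight, mesh, (0, 1))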
47 changes: 47 additions & 0 deletions test/spmd/test_train_spmd_linear_model_grad_checkpointing.py
@@ -0,0 +1,47 @@
from contextlib import contextmanager
import sys
import unittest

import torch

import test_xla_sharding_base
from ..utils.train_spmd_linear_model import train_and_evaluate


@contextmanager
def extended_argv(args):
  original_argv = sys.argv[:]
  sys.argv.extend(args)
  try:
    yield
  finally:
    sys.argv = original_argv


class TestSPMDLinearModelGradientCheckpointing(
    test_xla_sharding_base.XlaShardingTest):

  def test_gradient_checkpoint_matches(self):
    """Verify that gradient checkpointing produces the same results and losses as the baseline."""

    print('Training loop with baseline')
    with extended_argv([]):
      baseline_losses, baseline_result = train_and_evaluate()

    print('Training loop with gradient checkpointing')
    with extended_argv(['--use_gradient_checkpointing']):
      checkpointing_losses, checkpointing_result = train_and_evaluate()

    # Verify that the model losses are not zero, and that the runs match.
    assert all(loss != 0 for loss in baseline_losses)
    assert all(
        torch.allclose(baseline_loss, checkpointing_loss) for baseline_loss,
        checkpointing_loss in zip(baseline_losses, checkpointing_losses))
    # Verify that the model produces non-zero outputs, and that the runs match.
    assert not torch.any(baseline_result == 0)
    assert torch.allclose(baseline_result, checkpointing_result)


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
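
The extended_argv helper works because train_and_evaluate parses its options from sys.argv (via the args_parse module); temporarily extending and then restoring the argument vector lets one test process run the baseline and the checkpointed configuration back to back. Here is a small self-contained sketch of the same pattern, using a plain argparse parser as a stand-in for the repository's args_parse.parse_common_options:

import argparse
import sys
from contextlib import contextmanager


@contextmanager
def extended_argv(args):
  # Temporarily append extra CLI flags, then restore the original argv.
  original_argv = sys.argv[:]
  sys.argv.extend(args)
  try:
    yield
  finally:
    sys.argv = original_argv


def parse_flags():
  # Stand-in for args_parse.parse_common_options(); it reads from sys.argv.
  parser = argparse.ArgumentParser()
  parser.add_argument('--use_gradient_checkpointing', action='store_true')
  return parser.parse_known_args()[0]


with extended_argv([]):
  print(parse_flags().use_gradient_checkpointing)  # False
with extended_argv(['--use_gradient_checkpointing']):
  print(parse_flags().use_gradient_checkpointing)  # True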
3 changes: 1 addition & 2 deletions test/tpu/run_tests.sh
@@ -14,8 +14,7 @@ run_save_tensor_hlo python3 "$TEST_CDIR/spmd/test_spmd_lowering_context.py"
 python3 "$TEST_CDIR/spmd/test_xla_sharding.py"
 python3 "$TEST_CDIR/spmd/test_xla_virtual_device.py"
 python3 "$TEST_CDIR/spmd/test_xla_distributed_checkpoint.py"
-python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model.py"
-python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model.py" "$@" --use_gradient_checkpointing
+python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model_grad_checkpointing.py"
 python3 "$TEST_CDIR/spmd/test_xla_spmd_python_api_interaction.py"
 python3 "$TEST_CDIR/spmd/test_xla_auto_sharding.py"
 python3 "$TEST_CDIR/spmd/test_fsdp_v2.py"
Empty file added test/utils/__init__.py
Empty file.
152 changes: 152 additions & 0 deletions test/utils/train_spmd_linear_model.py
@@ -0,0 +1,152 @@
import sys
from typing import Optional

import numpy as np
import torch
from torch import nn
import torch.optim as optim

import args_parse
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
from torch_xla.distributed.spmd import Mesh
from torch_xla.utils.checkpoint import checkpoint

MODEL_OPTS = {
    '--sharding': {
        'choices': ['batch', 'megatron-lm', 'fsdp'],
        'nargs': '+',
        'default': [],
    },
    '--input_dim': {
        'type': int,
        'default': 16834,
    },
    '--train_dataset_len': {
        'type': int,
        'default': 1024 * 8,
    },
    '--use_gradient_checkpointing': {
        'action': 'store_true',
    }
}

FLAGS = {}
PROFILER_SERVER = None


class SimpleLinear(nn.Module):
  NUM_CLASSES = 3

  def __init__(self):
    super().__init__()
    self.layers = torch.nn.Sequential(
        nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2),
        nn.ReLU(),
        nn.Linear(FLAGS.input_dim // 2, 3),
        # # Add an additional 3x3 layer at the end to ensure the final layer
        # # is not sharded.
        nn.Linear(3, self.NUM_CLASSES),
    )

  def forward(self, x):
    if FLAGS.use_gradient_checkpointing:
      for n_l, layer in enumerate(self.layers):
        # Apply gradient checkpointing for reduced memory footprint.
        # This would result in increased computation cost.
        if n_l > 0:
          x = checkpoint(layer, x)
        else:
          x = layer(x)
    else:
      x = self.layers(x)
    return x


def train():
  device = xm.xla_device()
  torch.manual_seed(42)
  model = SimpleLinear().to(device)
  print('===> Preparing data..')
  train_loader = xu.SampleGenerator(
      data=(torch.randn(FLAGS.batch_size, FLAGS.input_dim),
            torch.randint(
                0, model.NUM_CLASSES, (FLAGS.batch_size,), dtype=torch.int64)),
      sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)

  num_devices = xr.global_runtime_device_count()
  print(f'num_devices: {num_devices}')
  # Define a mesh with all devices along one axis
  mesh_shape = (num_devices, 1)
  device_ids = np.arange(num_devices)
  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

  if 'batch' in FLAGS.sharding:
    train_loader = pl.MpDeviceLoader(
        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))

  if 'fsdp' in FLAGS.sharding:
    train_loader = pl.MpDeviceLoader(
        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
    print('Sharding model weights')
    # Shard the weights according to their 0th dim
    xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
    xs.mark_sharding(model.layers[2].weight, mesh, (0, 1))

  if 'megatron-lm' in FLAGS.sharding:
    print('Sharding model weights')
    # Shard the first layer's weights row-wise
    xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
    # Shard the second layer's weights column-wise
    xs.mark_sharding(model.layers[2].weight, mesh, (1, 0))

  optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr)

  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader, epoch):
    model.train()
    for step, (data, target) in enumerate(loader):
      with xp.StepTrace('train_linear_model'):
        with xp.Trace('build_graph'):
          x = data.to(device)
          y = target.to(device)
          optimizer.zero_grad()
          output = model(x)
          loss = loss_fn(output, y)
          losses.append(loss.clone().detach())
          loss.backward()
          optimizer.step()
        xm.mark_step()
        if step % FLAGS.log_steps == 0:
          print(f"Epoch {epoch} step {step} loss {loss}")

  losses = []
  for epoch in range(FLAGS.num_epochs):
    train_loop_fn(train_loader, epoch)
  return losses, model


def train_and_evaluate():
  default_config = {
      'batch_size': 128,
      'num_epochs': 1,
      'lr': 0.1,
      'log_steps': 8,
      'opts': MODEL_OPTS.items()
  }

  global PROFILER_SERVER, FLAGS
  FLAGS = args_parse.parse_common_options(**default_config)
  if FLAGS.profile:
    PROFILER_SERVER = xp.start_server(FLAGS.profiler_port)
  xr.use_spmd(auto=FLAGS.auto_spmd)
  print('Start training loop...')
  losses, m = train()
  t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device())
  return [loss.cpu() for loss in losses], m(t).cpu()

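For reference, the memory/compute trade-off that SimpleLinear.forward makes with checkpoint(layer, x) can be reproduced outside XLA. Below is a minimal CPU-only sketch; the test itself uses torch_xla.utils.checkpoint.checkpoint, while this sketch swaps in the upstream torch.utils.checkpoint equivalent so it runs without an XLA device, and the layer sizes are illustrative only.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# A tiny stack of layers, mirroring SimpleLinear.forward above.
layers = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 3))
x = torch.randn(4, 8, requires_grad=True)

out = x
for i, layer in enumerate(layers):
  if i > 0:
    # Checkpointed layers do not keep their activations; they are recomputed
    # during backward, trading extra compute for a smaller memory footprint.
    out = checkpoint(layer, out, use_reentrant=False)
  else:
    out = layer(out)

out.sum().backward()
print(x.grad.shape)  # torch.Size([4, 8])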