Improve and refine MLP tests for extensibility and A/B testing #8561

Merged · 10 commits · Jan 15, 2025
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -231,6 +231,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/spmd/test_train_spmd_linear_model.py" "$@" --skip-gradient-checkpointing
run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
Empty file added test/spmd/__init__.py
Empty file.
192 changes: 55 additions & 137 deletions test/spmd/test_train_spmd_linear_model.py
@@ -1,139 +1,57 @@
import args_parse
import numpy as np
import torch
from torch import nn
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
import torch_xla.utils.checkpoint as checkpoint
import torch_xla.utils.utils as xu
from torch_xla.distributed.spmd import Mesh
import torch.optim as optim
from torch import nn

MODEL_OPTS = {
'--sharding': {
'choices': ['batch', 'megatron-lm', 'fsdp'],
'nargs': '+',
'default': [],
},
'--input_dim': {
'type': int,
'default': 16834,
},
'--train_dataset_len': {
'type': int,
'default': 1024 * 1024,
},
'--use_gradient_checkpointing': {
'action': 'store_true',
}
}

FLAGS = args_parse.parse_common_options(
batch_size=128, num_epochs=1, opts=MODEL_OPTS.items())

xr.use_spmd(auto=FLAGS.auto_spmd)


class SimpleLinear(nn.Module):

def __init__(self):
super(SimpleLinear, self).__init__()
self.fc1 = nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(FLAGS.input_dim // 2, 3)
# Add an additional 3x3 layer at the end to ensure the final layer
# is not sharded.
self.fc3 = nn.Linear(3, 3)

def forward(self, x):
y = self.relu(self.fc1(x))
z = self.fc2(y)
return self.fc3(z)


device = xm.xla_device()


def train():
print('===> Preparing data..')
lr = 0.1
train_loader = xu.SampleGenerator(
data=(torch.zeros(FLAGS.batch_size, FLAGS.input_dim),
torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)
torch.manual_seed(42)
model = SimpleLinear().to(device)
import argparse
from contextlib import contextmanager
import os
import sys
import unittest

num_devices = xr.global_runtime_device_count()
print(f'num_devices: {num_devices}')
# Define a mesh with all devices along one axis
mesh_shape = (num_devices, 1)
device_ids = np.arange(num_devices)
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

if 'batch' in FLAGS.sharding:
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))

if 'fsdp' in FLAGS.sharding:
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
print('Sharding model weights')
# Shard the weights according to their 0th dim
xs.mark_sharding(model.fc1.weight, mesh, (0, 1))
xs.mark_sharding(model.fc2.weight, mesh, (0, 1))

if 'megatron-lm' in FLAGS.sharding:
print('Sharding model weights')
# Shard the first layer's weights row-wise
xs.mark_sharding(model.fc1.weight, mesh, (0, 1))
# Shard the second layer's weights column-wise
xs.mark_sharding(model.fc2.weight, mesh, (1, 0))

optimizer = optim.SGD(model.parameters(), lr=lr)

loss_fn = nn.CrossEntropyLoss()

def train_loop_fn(loader, epoch):
model.train()
for step, (data, target) in enumerate(loader):
with xp.StepTrace('train_linear_model'):
with xp.Trace('build_graph'):
x = data.to(device)
y = target.to(device)
optimizer.zero_grad()
if FLAGS.use_gradient_checkpointing:
for n_l, layer in enumerate(model):
# Apply gradient checkpointing for reduced memory footprint.
# This would result in increased computation cost.
if n_l > 0:
x = torch_xla.utils.checkpoint.checkpoint(layer, x)
output = x
else:
output = model(x)
loss = loss_fn(output, y)
loss.backward()
optimizer.step()
xm.mark_step()
if step % 10 == 0:
assert loss != 0, "Loss should not be 0 here"
print(f"Epoch {epoch} step {step} loss {loss}")

for epoch in range(FLAGS.num_epochs):
train_loop_fn(train_loader, epoch)

return model


if FLAGS.profile:
server = xp.start_server(FLAGS.profiler_port)
import torch

print('Start training loop...')
m = train()
t = torch.randn(10, FLAGS.input_dim).to(device)
m(t).cpu()
import test_xla_sharding_base

parent_folder = os.path.dirname(os.path.dirname(__file__))
sys.path.append(parent_folder)
from utils.train_spmd_linear_model import train_and_evaluate

SKIP_GRADIENT_CHECKPOINTING: bool = False
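# Overridden in __main__ below when --skip-gradient-checkpointing is passed on
# the command line (as run_tests.sh now does), disabling the checkpointing A/B run.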


@contextmanager
def extended_argv(args):
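  """Temporarily appends `args` to sys.argv so that the options parsed inside
  train_and_evaluate (via args_parse) see them, restoring the original argv on exit."""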
original_argv = sys.argv[:]
sys.argv.extend(args)
try:
yield
finally:
sys.argv = original_argv


class TestSPMDLinearModel(test_xla_sharding_base.XlaShardingTest):

def test_basic(self):
print('Training loop with baseline')
with extended_argv([]):
baseline_losses, baseline_result = train_and_evaluate()
# Verify that the model losses are not zero.
assert all(loss != 0 for loss in baseline_losses)
# Verify that the model produces non-zero outputs.
assert not torch.any(baseline_result == 0)

if not SKIP_GRADIENT_CHECKPOINTING:
print('Training loop with gradient checkpointing')
with extended_argv(['--use_gradient_checkpointing']):
checkpointing_losses, checkpointing_result = train_and_evaluate()
# Verify that the runs match with and without checkpointing.
assert torch.allclose(baseline_result, checkpointing_result)
assert all(
torch.allclose(baseline_loss, checkpointing_loss)
for baseline_loss, checkpointing_loss in zip(
baseline_losses, checkpointing_losses))


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--skip-gradient-checkpointing', action='store_true')
parsed_args, remaining_argv = parser.parse_known_args()
SKIP_GRADIENT_CHECKPOINTING = parsed_args.skip_gradient_checkpointing
test = unittest.main(argv=[sys.argv[0]] + remaining_argv)
sys.exit(0 if test.result.wasSuccessful() else 1)
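The refactor is meant to make further A/B comparisons cheap to add on top of train_and_evaluate. Below is a hypothetical sketch of such an extension, not part of this PR: the class and test names are invented, and it would live alongside TestSPMDLinearModel in the same file. Whether a sharded run matches the baseline bit for bit depends on device count and reduction order, so the comparison is hedged with allclose.

class TestSPMDLinearModelSharding(test_xla_sharding_base.XlaShardingTest):

  def test_batch_sharding_matches_baseline(self):
    # Baseline run with no extra flags.
    with extended_argv([]):
      baseline_losses, baseline_result = train_and_evaluate()
    # A/B variant: shard the input batch across the device mesh.
    with extended_argv(['--sharding', 'batch']):
      sharded_losses, sharded_result = train_and_evaluate()
    # Sharding changes tensor placement, not the math, so outputs and
    # per-step losses should agree up to floating-point tolerance.
    assert torch.allclose(baseline_result, sharded_result)
    assert all(
        torch.allclose(b, s)
        for b, s in zip(baseline_losses, sharded_losses))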
Empty file added test/utils/__init__.py
Empty file.
152 changes: 152 additions & 0 deletions test/utils/train_spmd_linear_model.py
@@ -0,0 +1,152 @@
import sys
from typing import Optional

import numpy as np
import torch
from torch import nn
import torch.optim as optim

import args_parse
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
from torch_xla.distributed.spmd import Mesh
from torch_xla.utils.checkpoint import checkpoint

MODEL_OPTS = {
'--sharding': {
'choices': ['batch', 'megatron-lm', 'fsdp'],
'nargs': '+',
'default': [],
},
'--input_dim': {
'type': int,
'default': 16834,
},
'--train_dataset_len': {
'type': int,
'default': 1024 * 8,
},
'--use_gradient_checkpointing': {
'action': 'store_true',
}
}

FLAGS = {}
PROFILER_SERVER = None
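# FLAGS is replaced with the parsed options in train_and_evaluate() before
# train() runs; PROFILER_SERVER keeps the profiler server alive for the
# duration of the run when --profile is set.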


class SimpleLinear(nn.Module):
NUM_CLASSES = 3

def __init__(self):
super().__init__()
self.layers = torch.nn.Sequential(
nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2),
nn.ReLU(),
nn.Linear(FLAGS.input_dim // 2, 3),
# Add an additional 3x3 layer at the end to ensure the final layer
# is not sharded.
nn.Linear(3, self.NUM_CLASSES),
)

def forward(self, x):
if FLAGS.use_gradient_checkpointing:
for n_l, layer in enumerate(self.layers):
# Apply gradient checkpointing for reduced memory footprint.
# This would result in increased computation cost.
if n_l > 0:
x = checkpoint(layer, x)
else:
x = layer(x)
else:
x = self.layers(x)
return x


def train():
device = xm.xla_device()
torch.manual_seed(42)
model = SimpleLinear().to(device)
print('===> Preparing data..')
train_loader = xu.SampleGenerator(
data=(torch.randn(FLAGS.batch_size, FLAGS.input_dim),
torch.randint(
0, model.NUM_CLASSES, (FLAGS.batch_size,), dtype=torch.int64)),
sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)

num_devices = xr.global_runtime_device_count()
print(f'num_devices: {num_devices}')
# Define a mesh with all devices along one axis
mesh_shape = (num_devices, 1)
device_ids = np.arange(num_devices)
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

if 'batch' in FLAGS.sharding:
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))

if 'fsdp' in FLAGS.sharding:
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
print('Sharding model weights')
# Shard the weights according to their 0th dim
xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
xs.mark_sharding(model.layers[2].weight, mesh, (0, 1))

if 'megatron-lm' in FLAGS.sharding:
print('Sharding model weights')
# Shard the first layer's weights row-wise
xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
# Shard the second layer's weights column-wise
xs.mark_sharding(model.layers[2].weight, mesh, (1, 0))

optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr)

loss_fn = nn.CrossEntropyLoss()

def train_loop_fn(loader, epoch):
model.train()
for step, (data, target) in enumerate(loader):
with xp.StepTrace('train_linear_model'):
with xp.Trace('build_graph'):
x = data.to(device)
y = target.to(device)
optimizer.zero_grad()
output = model(x)
loss = loss_fn(output, y)
losses.append(loss.clone().detach())
loss.backward()
optimizer.step()
xm.mark_step()
if step % FLAGS.log_steps == 0:
print(f"Epoch {epoch} step {step} loss {loss}")

losses = []
for epoch in range(FLAGS.num_epochs):
train_loop_fn(train_loader, epoch)
return losses, model


def train_and_evaluate():
default_config = {
'batch_size': 128,
'num_epochs': 1,
'lr': 0.1,
'log_steps': 8,
'opts': MODEL_OPTS.items()
}

global PROFILER_SERVER, FLAGS
FLAGS = args_parse.parse_common_options(**default_config)
if FLAGS.profile:
PROFILER_SERVER = xp.start_server(FLAGS.profiler_port)
xr.use_spmd(auto=FLAGS.auto_spmd)
print('Start training loop...')
losses, m = train()
t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device())
return [loss.cpu() for loss in losses], m(t).cpu()
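For reference, a minimal usage sketch of the new helper, under a few assumptions: it runs from the test/ directory so that args_parse and utils.train_spmd_linear_model are importable, and an XLA device is available (for example PJRT_DEVICE=CPU). Since train_and_evaluate() reads its configuration from sys.argv via args_parse.parse_common_options, options are injected by extending argv (the test above does this through extended_argv).

import sys

from utils.train_spmd_linear_model import train_and_evaluate

# Inject flags for this run; parse_common_options picks them up from sys.argv.
sys.argv += ['--sharding', 'batch', '--use_gradient_checkpointing']
losses, eval_output = train_and_evaluate()
print(f'steps: {len(losses)}, final loss: {losses[-1].item():.4f}, '
      f'eval output shape: {tuple(eval_output.shape)}')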