From 10de6ed36fae0554d72d9e693ec341bb84f0d0ee Mon Sep 17 00:00:00 2001
From: Rui <179625410+rpsilva-aws@users.noreply.github.com>
Date: Wed, 15 Jan 2025 12:01:54 -0800
Subject: [PATCH] Improve and refine MLP tests for extensibility and A/B testing (#8561)

---
 test/run_tests.sh | 1 +
 test/spmd/__init__.py | 0
 test/spmd/test_train_spmd_linear_model.py | 192 +++++++---------
 test/utils/__init__.py | 0
 test/utils/train_spmd_linear_model.py | 152 +++++++++++++++++
 5 files changed, 208 insertions(+), 137 deletions(-)
 create mode 100644 test/spmd/__init__.py
 create mode 100644 test/utils/__init__.py
 create mode 100644 test/utils/train_spmd_linear_model.py

diff --git a/test/run_tests.sh b/test/run_tests.sh
index da0ebc15a06..2684c585996 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -232,6 +232,7 @@ function run_xla_op_tests3 {
   run_test "$CDIR/spmd/test_xla_auto_sharding.py"
   run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$CDIR/spmd/test_mp_input_sharding.py"
+  run_test "$CDIR/spmd/test_train_spmd_linear_model.py" "$@" --skip-gradient-checkpointing
   run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"
diff --git a/test/spmd/__init__.py b/test/spmd/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py
index 88f66eab379..08637490a3c 100644
--- a/test/spmd/test_train_spmd_linear_model.py
+++ b/test/spmd/test_train_spmd_linear_model.py
@@ -1,139 +1,57 @@
-import args_parse
-import numpy as np
-import torch
-from torch import nn
-import torch_xla
-import torch_xla.core.xla_model as xm
-import torch_xla.runtime as xr
-import torch_xla.debug.profiler as xp
-import torch_xla.distributed.parallel_loader as pl
-import torch_xla.distributed.spmd as xs
-import torch_xla.utils.checkpoint as checkpoint
-import torch_xla.utils.utils as xu
-from torch_xla.distributed.spmd import Mesh
-import torch.optim as optim
-from torch import nn
-
-MODEL_OPTS = {
-    '--sharding': {
-        'choices': ['batch', 'megatron-lm', 'fsdp'],
-        'nargs': '+',
-        'default': [],
-    },
-    '--input_dim': {
-        'type': int,
-        'default': 16834,
-    },
-    '--train_dataset_len': {
-        'type': int,
-        'default': 1024 * 1024,
-    },
-    '--use_gradient_checkpointing': {
-        'action': 'store_true',
-    }
-}
-
-FLAGS = args_parse.parse_common_options(
-    batch_size=128, num_epochs=1, opts=MODEL_OPTS.items())
-
-xr.use_spmd(auto=FLAGS.auto_spmd)
-
-
-class SimpleLinear(nn.Module):
-
-  def __init__(self):
-    super(SimpleLinear, self).__init__()
-    self.fc1 = nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2)
-    self.relu = nn.ReLU()
-    self.fc2 = nn.Linear(FLAGS.input_dim // 2, 3)
-    # Add an additional 3x3 layer at the end to ensure the final layer
-    # is not sharded.
-    self.fc3 = nn.Linear(3, 3)
-
-  def forward(self, x):
-    y = self.relu(self.fc1(x))
-    z = self.fc2(y)
-    return self.fc3(z)
-
-
-device = xm.xla_device()
-
-
-def train():
-  print('===> Preparing data..')
-  lr = 0.1
-  train_loader = xu.SampleGenerator(
-      data=(torch.zeros(FLAGS.batch_size, FLAGS.input_dim),
-            torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
-      sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)
-  torch.manual_seed(42)
-  model = SimpleLinear().to(device)
+import argparse
+from contextlib import contextmanager
+import os
+import sys
+import unittest
 
-  num_devices = xr.global_runtime_device_count()
-  print(f'num_devices: {num_devices}')
-  # Define a mesh with all devices along one axis
-  mesh_shape = (num_devices, 1)
-  device_ids = np.arange(num_devices)
-  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
-
-  if 'batch' in FLAGS.sharding:
-    train_loader = pl.MpDeviceLoader(
-        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
-
-  if 'fsdp' in FLAGS.sharding:
-    train_loader = pl.MpDeviceLoader(
-        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
-    print('Sharding model weights')
-    # Shard the weights according to their 0th dim
-    xs.mark_sharding(model.fc1.weight, mesh, (0, 1))
-    xs.mark_sharding(model.fc2.weight, mesh, (0, 1))
-
-  if 'megatron-lm' in FLAGS.sharding:
-    print('Sharding model weights')
-    # Shard the first layer's weights row-wise
-    xs.mark_sharding(model.fc1.weight, mesh, (0, 1))
-    # Shard the second layer's weights column-wise
-    xs.mark_sharding(model.fc2.weight, mesh, (1, 0))
-
-  optimizer = optim.SGD(model.parameters(), lr=lr)
-
-  loss_fn = nn.CrossEntropyLoss()
-
-  def train_loop_fn(loader, epoch):
-    model.train()
-    for step, (data, target) in enumerate(loader):
-      with xp.StepTrace('train_linear_model'):
-        with xp.Trace('build_graph'):
-          x = data.to(device)
-          y = target.to(device)
-          optimizer.zero_grad()
-          if FLAGS.use_gradient_checkpointing:
-            for n_l, layer in enumerate(model):
-              # Apply gradient checkpointing for reduced memory footprint.
-              # This would result in increased computation cost.
-              if n_l > 0:
-                x = torch_xla.utils.checkpoint.checkpoint(layer, x)
-            output = x
-          else:
-            output = model(x)
-          loss = loss_fn(output, y)
-          loss.backward()
-          optimizer.step()
-      xm.mark_step()
-      if step % 10 == 0:
-        assert loss != 0, "Loss should not 0 here"
-        print(f"Epoch {epoch} step {step} loss {loss}")
-
-  for epoch in range(FLAGS.num_epochs):
-    train_loop_fn(train_loader, epoch)
-
-  return model
-
-
-if FLAGS.profile:
-  server = xp.start_server(FLAGS.profiler_port)
+import torch
 
-print('Start training loop...')
-m = train()
-t = torch.randn(10, FLAGS.input_dim).to(device)
-m(t).cpu()
+import test_xla_sharding_base
+
+parent_folder = os.path.dirname(os.path.dirname(__file__))
+sys.path.append(parent_folder)
+from utils.train_spmd_linear_model import train_and_evaluate
+
+SKIP_GRADIENT_CHECKPOINTING: bool = False
+
+
+@contextmanager
+def extended_argv(args):
+  original_argv = sys.argv[:]
+  sys.argv.extend(args)
+  try:
+    yield
+  finally:
+    sys.argv = original_argv
+
+
+class TestSPMDLinearModel(test_xla_sharding_base.XlaShardingTest):
+
+  def test_basic(self):
+    print('Training loop with baseline')
+    with extended_argv([]):
+      baseline_losses, baseline_result = train_and_evaluate()
+    # Verify that the model losses are not zero.
+    assert all(loss != 0 for loss in baseline_losses)
+    # Verify that the model produces non-zero outputs.
+    assert not torch.any(baseline_result == 0)
+
+    if not SKIP_GRADIENT_CHECKPOINTING:
+      print('Training loop with gradient checkpointing')
+      with extended_argv(['--use_gradient_checkpointing']):
+        checkpointing_losses, checkpointing_result = train_and_evaluate()
+      # Verify that the runs match with and without checkpointing.
+      assert torch.allclose(baseline_result, checkpointing_result)
+      assert all(
+          torch.allclose(baseline_loss, checkpointing_loss)
+          for baseline_loss, checkpointing_loss in zip(
+              baseline_losses, checkpointing_losses))
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--skip-gradient-checkpointing', action='store_true')
+  parsed_args, remaining_argv = parser.parse_known_args()
+  SKIP_GRADIENT_CHECKPOINTING = parsed_args.skip_gradient_checkpointing
+  test = unittest.main(argv=[sys.argv[0]] + remaining_argv, exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/utils/__init__.py b/test/utils/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/test/utils/train_spmd_linear_model.py b/test/utils/train_spmd_linear_model.py
new file mode 100644
index 00000000000..7fba86d5dab
--- /dev/null
+++ b/test/utils/train_spmd_linear_model.py
@@ -0,0 +1,152 @@
+import sys
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import nn
+import torch.optim as optim
+
+import args_parse
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.profiler as xp
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.spmd as xs
+import torch_xla.runtime as xr
+import torch_xla.utils.utils as xu
+from torch_xla.distributed.spmd import Mesh
+from torch_xla.utils.checkpoint import checkpoint
+
+MODEL_OPTS = {
+    '--sharding': {
+        'choices': ['batch', 'megatron-lm', 'fsdp'],
+        'nargs': '+',
+        'default': [],
+    },
+    '--input_dim': {
+        'type': int,
+        'default': 16834,
+    },
+    '--train_dataset_len': {
+        'type': int,
+        'default': 1024 * 8,
+    },
+    '--use_gradient_checkpointing': {
+        'action': 'store_true',
+    }
+}
+
+FLAGS = {}
+PROFILER_SERVER = None
+
+
+class SimpleLinear(nn.Module):
+  NUM_CLASSES = 3
+
+  def __init__(self):
+    super().__init__()
+    self.layers = torch.nn.Sequential(
+        nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2),
+        nn.ReLU(),
+        nn.Linear(FLAGS.input_dim // 2, 3),
+        # Add an additional 3x3 layer at the end to ensure the final layer
+        # is not sharded.
+        nn.Linear(3, self.NUM_CLASSES),
+    )
+
+  def forward(self, x):
+    if FLAGS.use_gradient_checkpointing:
+      for n_l, layer in enumerate(self.layers):
+        # Apply gradient checkpointing for reduced memory footprint.
+        # This would result in increased computation cost.
+        if n_l > 0:
+          x = checkpoint(layer, x)
+        else:
+          x = layer(x)
+    else:
+      x = self.layers(x)
+    return x
+
+
+def train():
+  device = xm.xla_device()
+  torch.manual_seed(42)
+  model = SimpleLinear().to(device)
+  print('===> Preparing data..')
+  train_loader = xu.SampleGenerator(
+      data=(torch.randn(FLAGS.batch_size, FLAGS.input_dim),
+            torch.randint(
+                0, model.NUM_CLASSES, (FLAGS.batch_size,), dtype=torch.int64)),
+      sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)
+
+  num_devices = xr.global_runtime_device_count()
+  print(f'num_devices: {num_devices}')
+  # Define a mesh with all devices along one axis
+  mesh_shape = (num_devices, 1)
+  device_ids = np.arange(num_devices)
+  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+
+  if 'batch' in FLAGS.sharding:
+    train_loader = pl.MpDeviceLoader(
+        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
+
+  if 'fsdp' in FLAGS.sharding:
+    train_loader = pl.MpDeviceLoader(
+        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
+    print('Sharding model weights')
+    # Shard the weights according to their 0th dim
+    xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
+    xs.mark_sharding(model.layers[2].weight, mesh, (0, 1))
+
+  if 'megatron-lm' in FLAGS.sharding:
+    print('Sharding model weights')
+    # Shard the first layer's weights row-wise
+    xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
+    # Shard the second layer's weights column-wise
+    xs.mark_sharding(model.layers[2].weight, mesh, (1, 0))
+
+  optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr)
+
+  loss_fn = nn.CrossEntropyLoss()
+
+  def train_loop_fn(loader, epoch):
+    model.train()
+    for step, (data, target) in enumerate(loader):
+      with xp.StepTrace('train_linear_model'):
+        with xp.Trace('build_graph'):
+          x = data.to(device)
+          y = target.to(device)
+          optimizer.zero_grad()
+          output = model(x)
+          loss = loss_fn(output, y)
+          losses.append(loss.clone().detach())
+          loss.backward()
+          optimizer.step()
+      xm.mark_step()
+      if step % FLAGS.log_steps == 0:
+        print(f"Epoch {epoch} step {step} loss {loss}")
+
+  losses = []
+  for epoch in range(FLAGS.num_epochs):
+    train_loop_fn(train_loader, epoch)
+  return losses, model
+
+
+def train_and_evaluate():
+  default_config = {
+      'batch_size': 128,
+      'num_epochs': 1,
+      'lr': 0.1,
+      'log_steps': 8,
+      'opts': MODEL_OPTS.items()
+  }
+
+  global PROFILER_SERVER, FLAGS
+  FLAGS = args_parse.parse_common_options(**default_config)
+  if FLAGS.profile:
+    PROFILER_SERVER = xp.start_server(FLAGS.profiler_port)
+  xr.use_spmd(auto=FLAGS.auto_spmd)
+  print('Start training loop...')
+  losses, m = train()
+  t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device())
+  return [loss.cpu() for loss in losses], m(t).cpu()
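
Note on extensibility (not part of the patch): further A/B configurations can be compared by adding cases next to test_basic, since the training logic now lives in test/utils/train_spmd_linear_model.py. A minimal sketch follows, assuming it is placed inside TestSPMDLinearModel in test/spmd/test_train_spmd_linear_model.py; the test name test_fsdp is hypothetical, while extended_argv, train_and_evaluate, and the --sharding option all come from the code above.

  def test_fsdp(self):
    # Hypothetical follow-up case: reuse the shared helper with FSDP-style
    # sharding ('fsdp' is an existing --sharding choice in MODEL_OPTS).
    print('Training loop with FSDP sharding')
    with extended_argv(['--sharding', 'fsdp']):
      fsdp_losses, fsdp_result = train_and_evaluate()
    # Same sanity checks as the baseline run in test_basic.
    assert all(loss != 0 for loss in fsdp_losses)
    assert not torch.any(fsdp_result == 0)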