Skip to content

Commit

Permalink
Use a random generator for proper convergence validation
Browse files Browse the repository at this point in the history
  • Loading branch information
rpsilva-aws committed Jan 14, 2025
1 parent c23865a commit bb1dbaf
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions test/spmd/test_train_spmd_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@


class SimpleLinear(nn.Module):
NUM_CLASSES = 3

def __init__(self):
super().__init__()
Expand All @@ -54,7 +55,7 @@ def __init__(self):
nn.Linear(FLAGS.input_dim // 2, 3),
# # Add an additional 3x3 layer at the end to ensure the final layer
# # is not sharded.
nn.Linear(3, 3),
nn.Linear(3, self.NUM_CLASSES),
)

def forward(self, x):
Expand All @@ -73,13 +74,14 @@ def forward(self, x):

def train():
device = xm.xla_device()
torch.manual_seed(42)
model = SimpleLinear().to(device)
print('===> Preparing data..')
train_loader = xu.SampleGenerator(
data=(torch.zeros(FLAGS.batch_size, FLAGS.input_dim),
torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
data=(torch.randn(FLAGS.batch_size, FLAGS.input_dim),
torch.randint(
0, model.NUM_CLASSES, (FLAGS.batch_size,), dtype=torch.int64)),
sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)
torch.manual_seed(42)
model = SimpleLinear().to(device)

num_devices = xr.global_runtime_device_count()
print(f'num_devices: {num_devices}')
Expand Down

0 comments on commit bb1dbaf

Please sign in to comment.