
Slow XLA training performance. #8541

Open
tzstoyanov opened this issue Jan 7, 2025 · 2 comments

tzstoyanov commented Jan 7, 2025

❓ Questions and Help

I'm evaluating PyTorch-XLA for training, but noticed a big performance degradation compared to the native PyTorch device. Is this a known problem, or is there a problem with the way I use PyTorch-XLA? I tested a simple MNIST training example, comparing the performance of the PyTorch CUDA device and the XLA CUDA device. The native CUDA device is twice as fast.
Appreciate any thoughts, suggestions, or links to known performance issues, thanks!

Environment

Note: there is no difference in the performance measurements with the latest 2.5.0 release.

  • torch 2.4.0
  • torch-xla 2.4.0
  • torch_xla_cuda_plugin 2.4.0.dev20240902
  • torchvision 0.19.0

How To Reproduce

Run the test program below with xla = True and with xla = False:

import os
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch_xla.core.xla_model as xm

def get_device(xla):
    if xla:
        os.environ["PJRT_DEVICE"] = "CUDA"
        os.environ["GPU_NUM_DEVICES"] = "1"
        import torch_xla_cuda_plugin
        from torch_xla.experimental import plugins
        import torch_xla.runtime as xr
        # Register the CUDA PJRT plugin and use the XLA:CUDA device
        plugins.use_dynamic_plugins()
        plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin())
        xr.set_device_type('CUDA')
        device = xm.xla_device(devkind="CUDA")
    else:
        # Native PyTorch CUDA device
        device = torch.device('cuda:0')
        os.environ["PJRT_DEVICE"] = "CUDA"
        os.environ["GPU_NUM_DEVICES"] = "1"
    return device

xla = True
device = get_device(xla)
print(f"Using device: {device}")

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512)  # number of neurons
        self.fc2 = nn.Linear(512, 256)      # number of neurons
        self.fc3 = nn.Linear(256, 10)       # Output layer (10 classes for digits 0-9)
    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten the image
        x = torch.relu(self.fc1(x))  # Apply ReLU activation
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Load the MNIST dataset and apply transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize the model and move it to the device
model = SimpleNN().to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop, 20 epochs
for epoch in tqdm(range(20)):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.to(device)  # Move data to the device

        optimizer.zero_grad()  # Zero the gradients
        output = model(data)  # Get model predictions
        loss = criterion(output, target)  # Compute the loss
        loss.backward()  # Backpropagate the gradients
        optimizer.step()  # Update model parameters

        running_loss += loss.item()
        if xla:
            xm.mark_step()  # Mark a step boundary so the pending XLA graph is compiled and executed
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

# Test the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)  # Move data to CUDA device
        output = model(data)
        _, predicted = torch.max(output, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Accuracy: {100 * correct / total}%')

qihqi commented Jan 15, 2025

Hi,

In your script I didn't see anything that measures time. If you are measuring the time of the entire script, then in XLA's case it would also include the time spent on tracing & compilation.
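
A quick way to separate the two is to time epochs individually and compare the first epoch, which pays the tracing and compilation cost, against the later steady-state epochs. The sketch below assumes the model, train_loader, criterion, optimizer and device from the script above with xla = True; torch_xla.debug.metrics gives a report whose CompileTime counter shows whether recompilation keeps happening:

import time
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met  # built-in PyTorch/XLA counters and timers

epoch_times = []
for epoch in range(5):
    start = time.perf_counter()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        xm.mark_step()          # cut the lazy graph, as in the original script
    xm.wait_device_ops()        # wait for queued device work before stopping the timer
    epoch_times.append(time.perf_counter() - start)

print(f"first epoch (includes tracing/compilation): {epoch_times[0]:.1f}s")
print(f"later epochs (steady state, mean):          {sum(epoch_times[1:]) / len(epoch_times[1:]):.1f}s")
print(met.short_metrics_report())  # CompileTime counter reveals recompilations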


tzstoyanov commented Jan 16, 2025

Hi,
Thank you for the feedback. I use tqdm to measure the time of each loop iteration. If I understand your comment correctly, the extra time in the XLA case comes from tracing & compilation? Is there a way to mitigate and optimize these steps?
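
For reference, two adjustments that are commonly suggested for this kind of workload, sketched against the script above (reusing model, criterion, optimizer, device and train_dataset) and assuming the overhead comes from recompilation and per-step synchronization: keep the batch shape static so the graph is compiled once, and avoid calling loss.item() on every step, since that forces the pending XLA graph to execute synchronously:

from torch.utils.data import DataLoader
import torch
import torch_xla.core.xla_model as xm

# Static shapes: drop the last partial batch, whose different size would
# otherwise force XLA to compile a second graph.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)

# Accumulate the loss on the device; loss.item() per step forces a sync.
running_loss = torch.zeros((), device=device)
for data, target in train_loader:
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()
    running_loss += loss.detach()
    xm.mark_step()
print(f'Loss: {running_loss.item() / len(train_loader)}')  # single sync per epoch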
