❓ Questions and Help
I'm evaluating PyTorch-XLA for training, but I noticed a big performance degradation compared to the native PyTorch device. Is this a known problem, or is there a problem with the way I use PyTorch-XLA? I tested a simple MNIST training example, comparing the performance of the PyTorch CUDA device and the XLA CUDA device. The native CUDA device is twice as fast.
I'd appreciate any thoughts, suggestions, or links to known performance issues, thanks!
Environment
Note: there is no difference in the performance measurements with the latest 2.5.0.
torch 2.4.0
torch-xla 2.4.0
torch_xla_cuda_plugin 2.4.0.dev20240902
torchvision 0.19.0
How To Reproduce
Run the test program with xla = True and xla = False
import os
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch_xla.core.xla_model as xm

def get_device(xla):
    if xla:
        os.environ["PJRT_DEVICE"] = "CUDA"
        os.environ["GPU_NUM_DEVICES"] = "1"
        import torch_xla_cuda_plugin
        from torch_xla.experimental import plugins
        import torch_xla.runtime as xr
        plugins.use_dynamic_plugins()
        plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin())
        xr.set_device_type('CUDA')
        device = xm.xla_device(devkind="CUDA")
    else:
        device = torch.device('cuda:0')
        os.environ["PJRT_DEVICE"] = "CUDA"
        os.environ["GPU_NUM_DEVICES"] = "1"
    return device

xla = True
device = get_device(xla)
print(f"Using device: {device}")

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512)  # number of neurons
        self.fc2 = nn.Linear(512, 256)      # number of neurons
        self.fc3 = nn.Linear(256, 10)       # Output layer (10 classes for digits 0-9)

    def forward(self, x):
        x = x.view(-1, 28 * 28)      # Flatten the image
        x = torch.relu(self.fc1(x))  # Apply ReLU activation
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Load the MNIST dataset and apply transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize the model and move it to the device
model = SimpleNN().to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop, 20 epochs
for epoch in tqdm(range(20)):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.to(device)  # Move data to the device
        optimizer.zero_grad()             # Zero the gradients
        output = model(data)              # Get model predictions
        loss = criterion(output, target)  # Compute the loss
        loss.backward()                   # Backpropagate the gradients
        optimizer.step()                  # Update model parameters
        running_loss += loss.item()
        if xla:
            xm.mark_step()
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

# Test the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)  # Move data to the device
        output = model(data)
        _, predicted = torch.max(output, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
print(f'Accuracy: {100*correct/total}%')
I didn't see anything in your script that measures time. If you are measuring the time of the entire script, then in XLA's case it would include the time of tracing and compilation.
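For a fairer comparison, the measurement can exclude the first pass (where XLA traces and compiles the graph) and synchronize the device before stopping the clock. Below is a minimal sketch, assuming the model, criterion, optimizer, train_loader, device, and xla flag from the script above; the helper name time_one_epoch is just for illustration:

import time
import torch
import torch_xla.core.xla_model as xm

def time_one_epoch(model, loader, criterion, optimizer, device, xla):
    # Assumes at least one warm-up epoch already ran, so XLA's compiled graph is cached.
    start = time.perf_counter()
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        if xla:
            xm.mark_step()
    # Wait for the device to finish all queued work before stopping the clock.
    if xla:
        xm.wait_device_ops()
    else:
        torch.cuda.synchronize()
    return time.perf_counter() - start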
Hi,
Thank you for the feedback. I use tqdm to measure the time of each loop iteration. If I understand your comment correctly, the extra time in the XLA case is for tracing and compilation? Is there a way to mitigate and optimize these steps?
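One way to check how much of the wall-clock time goes to compilation versus execution (a sketch, assuming the torch_xla.debug.metrics module; the exact counter names in the report may vary between releases) is to dump PyTorch/XLA's internal metrics after a few epochs:

import torch_xla.debug.metrics as met

# Print PyTorch/XLA's internal counters and timers after training.
# Entries such as CompileTime and ExecuteTime help separate graph
# compilation overhead from actual execution time on the device.
print(met.metrics_report())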