diff --git a/torchrun_main.py b/torchrun_main.py index d1569cc..a4263d2 100644 --- a/torchrun_main.py +++ b/torchrun_main.py @@ -15,7 +15,7 @@ import torch import torch.nn as nn import torch.utils.data -import torch.distributed as dist as dist +import torch.distributed as dist from torch.distributed.optim import ZeroRedundancyOptimizer from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP,