
Multigpu training hangs using single and multiple nodes #8549

Open
Patataman opened this issue Jan 10, 2025 · 6 comments

Patataman commented Jan 10, 2025

🐛 Bug

I am trying to run the example code for distributed training, but all my attempts either hang or return an error.

To Reproduce

I have tried test_train_mp_mnist and the example in the PyTorch/XLA docs: https://pytorch.org/xla/master/learn/pjrt.html#tl-dr

Right now I am trying the latter because it is simpler.

import os

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr


def _mp_fn(index):
  device = xm.xla_device()
  dist.init_process_group('xla', init_method='xla://')

  torch.manual_seed(42)
  model = nn.Linear(128, 10).to(device)

  # Optional for TPU v4 and GPU
  xm.broadcast_master_param(model)
  model = DDP(model)

  loss_fn = nn.MSELoss()
  optimizer = optim.SGD(model.parameters(), lr=.001)

  for i in range(10):
    data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)

    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()

    optimizer.step()
    xm.mark_step()

  # Print mean parameters so we can confirm they're the same across replicas
  print([p.mean() for p in model.parameters()])

if __name__ == '__main__':
  torch_xla.launch(_mp_fn)

Steps to reproduce the behavior:

For this example I am using a single machine with 2 GPUs.

When I try to run with 1 GPU to check that the command is correct, it gives me the following error:

(venv) ~$ PJRT_DEVICE=CUDA GPU_NUM_DEVICES=1 torchrun --nnodes=1 --nproc-per-node=1 example-xla.py --epochs 1
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1736507062.028835 2365508 service.cc:148] XLA service 0x5570d36760c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1736507062.028872 2365508 service.cc:156]   StreamExecutor device (0): NVIDIA A40, Compute Capability 8.6
I0000 00:00:1736507062.028877 2365508 service.cc:156]   StreamExecutor device (1): NVIDIA A40, Compute Capability 8.6
I0000 00:00:1736507062.030735 2365508 se_gpu_pjrt_client.cc:943] Using BFC allocator.
I0000 00:00:1736507062.030773 2365508 gpu_helpers.cc:114] XLA backend allocating 35802464256 bytes on device 0 for BFCAllocator.
I0000 00:00:1736507062.030796 2365508 gpu_helpers.cc:114] XLA backend allocating 35802464256 bytes on device 1 for BFCAllocator.
I0000 00:00:1736507062.030817 2365508 gpu_helpers.cc:154] XLA backend will use up to 11934154752 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1736507062.030828 2365508 gpu_helpers.cc:154] XLA backend will use up to 11934154752 bytes on device 1 for CollectiveBFCAllocator.
2025-01-10 12:04:22.505221: E external/xla/xla/status_macros.cc:56] INTERNAL: RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:477) subgroup_size == 1 || shard_count == subgroup_size shard_count = 1, subgroup_size = 2, %all-gather.6 = s64[1]{0} all-gather(s64[1]{0} %add.5), replica_groups={}, dimensions={0}
*** Begin stack trace ***
        tsl::CurrentStackTrace()

        xla::status_macros::MakeErrorStream::Impl::GetStatus()

        xla::ShapeVerifier::HandleAllGather(xla::HloInstruction*)
        absl::lts_20230802::Status xla::HloInstruction::Visit<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*)

        absl::lts_20230802::Status xla::HloInstruction::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*, bool, bool, bool)
        absl::lts_20230802::Status xla::HloComputation::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*) const
        xla::HloVerifier::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&)
        xla::CreateModuleFromProto(xla::HloModuleProto const&, xla::HloModuleConfig const&, bool)
        xla::Service::BuildExecutable(xla::HloModuleProto const&, std::unique_ptr<xla::HloModuleConfig, std::default_delete<xla::HloModuleConfig> >, xla::Backend*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, bool)
        xla::LocalService::CompileExecutables(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&)
        xla::LocalClient::Compile(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&)
        xla::PjRtStreamExecutorClient::CompileInternal(xla::XlaComputation const&, std::vector<xla::Shape const*, std::allocator<xla::Shape const*> > const&, std::function<absl::lts_20230802::StatusOr<std::pair<std::vector<xla::Shape, std::allocator<xla::Shape> >, xla::Shape> > (xla::HloModule const&)>, xla::CompileOptions)
        xla::PjRtStreamExecutorClient::Compile(xla::XlaComputation const&, xla::CompileOptions)
        xla::StreamExecutorGpuClient::Compile(xla::XlaComputation const&, xla::CompileOptions)
        torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
        torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >&, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
        torch_xla::XLAGraphExecutor::GetTensors(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*)
        torch_xla::bridge::XlaCreateTensorList(c10::IListRef<at::Tensor> const&)
        torch_xla::XLANativeFunctions::_to_cpu(c10::ArrayRef<at::Tensor>)

        at::_ops::_to_cpu::call(c10::ArrayRef<at::Tensor>)

        at::native::cpu_fallback(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*, bool, c10::DispatchKey)
        torch_xla::xla_fallback(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*)
        at::native::_call_fallback_fn<&torch_xla::xla_fallback, at::_ops::_local_scalar_dense, false, c10::Scalar (at::Tensor const&)>::call(at::Tensor const&)
        torch_xla::XLANativeFunctions::_local_scalar_dense(at::Tensor const&)

        c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const


        at::_ops::_local_scalar_dense::redispatch(c10::DispatchKeySet, at::Tensor const&)


        at::_ops::_local_scalar_dense::call(at::Tensor const&)
        at::native::item(at::Tensor const&)

        at::_ops::item::call(at::Tensor const&)
        int at::Tensor::item<int>() const
        c10d::verify_params_across_processes(c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::optional<std::weak_ptr<c10d::Logger> > const&)



        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyObject_FastCallDictTstate

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault

        PyEval_EvalCode



        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain

        __libc_start_main
        _start
*** End stack trace ***

[rank0]: Traceback (most recent call last):
[rank0]:   File "example-xla.py", line 44, in <module>
[rank0]:     torch_xla.launch(_mp_fn)
[rank0]:   File "xla/venv/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 231, in launch
[rank0]:     fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
[rank0]:   File "xla/unet/ddp/example-xla.py", line 24, in _mp_fn
[rank0]:     model = DDP(model)
[rank0]:   File "xla/venv/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 825, in __init__
[rank0]:     _verify_param_shape_across_processes(self.process_group, parameters)
[rank0]:   File "xla/venv/lib/python3.10/site-packages/torch/distributed/utils.py", line 288, in _verify_param_shape_across_processes
[rank0]:     return dist._verify_params_across_processes(process_group, tensors, logger)
[rank0]: RuntimeError: Bad StatusOr access: INTERNAL: during context [Unknown]: RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:477) subgroup_size == 1 || shard_count == subgroup_size shard_count = 1, subgroup_size = 2, %all-gather.6 = s64[1]{0} all-gather(s64[1]{0} %add.5), replica_groups={}, dimensions={0}
E0110 12:04:23.262000 2365422 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 0 (pid: 2365508) of binary: xla/venv/bin/python
Traceback (most recent call last):
  File "xla/venv/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
    run(args)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

On the other hand, if I try to run with both GPUs for parallelism, it just hangs. When I cancel the execution, it appears to hang on torch_xla.launch.

  • torchrun --nnodes=1 --nproc-per-node=2 example-xla.py --epochs 1

Environment

$> pip freeze | grep torch
torch==2.5.1
torch-xla @ https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=86d0a9af00fb678f903e5c4968e30dca3c50d6ce64aa33da9314b2134418ace3
torchvision==0.20.1

CUDA 12.5
Driver 555.42.06

Additional context

I can successfully run other non-parallel XLA scripts.
When I try to use multiple nodes with torchrun, it also hangs, while the same command with non-XLA scripts works perfectly.
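
For reference, a multi-node launch of this script looks roughly like the following on each node (this is a sketch: MASTER_HOST, the port, and the node rank are placeholders, not my actual values):

  • PJRT_DEVICE=CUDA torchrun --nnodes=2 --nproc-per-node=2 --node-rank=<rank> --rdzv-backend=c10d --rdzv-endpoint=MASTER_HOST:29500 example-xla.py --epochs 1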

yaochengji (Collaborator) commented Jan 13, 2025

Hi,

Thanks for reporting this issue.

Could you try model = DDP(model, gradient_as_bucket_view=True)? There's a bug in the previous version of torch/xla when gradient_as_bucket_view is not set to True.
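
Concretely, only the DDP wrapping line in your example would change; everything else stays as you posted it:

  # Optional for TPU v4 and GPU
  xm.broadcast_master_param(model)
  # gradient_as_bucket_view=True works around the DDP bucket issue mentioned above
  model = DDP(model, gradient_as_bucket_view=True)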

Patataman (Author) commented:

Hello, thanks for the response.

I am afraid that model = DDP(model, gradient_as_bucket_view=True) didn't do anything. Same results...

yaochengji (Collaborator) commented:

Could you try changing your code to the following?

import os

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr


def _mp_fn(index):
  device = xm.xla_device()
  dist.init_process_group('xla', init_method='xla://')

  torch.manual_seed(42)
  model = nn.Linear(10, 10).to(device)

  # Optional for TPU v4 and GPU
  xm.broadcast_master_param(model)
  model = DDP(model)

  loss_fn = nn.MSELoss()
  optimizer = optim.SGD(model.parameters(), lr=.001)

  for i in range(10):
    data, target = torch.randn((128, 10), device=device), torch.randn((128, 10), device=device)

    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()

    optimizer.step()
    xm.mark_step()

  # Print mean parameters so we can confirm they're the same across replicas
  print([p.mean() for p in model.parameters()])

if __name__ == '__main__':
  torch_xla.launch(_mp_fn)

I only changed the shape of the Linear layer (nn.Linear(128, 10) → nn.Linear(10, 10), with the random input adjusted to match). Then it passed on my side. It seems the XLA compiler failed for the previous shape.

Patataman (Author) commented:

Thanks, now the script runs successfully on a single GPU, but it keeps hanging with more than 1 GPU (either on the same machine or across different machines using torchrun). I tried running with PT_XLA_DEBUG_LEVEL=2 but I didn't get any output.
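
For reference, I invoked the debug run roughly like this, with the debug variable prepended to the same command as before:

  • PT_XLA_DEBUG_LEVEL=2 torchrun --nnodes=1 --nproc-per-node=2 example-xla.py --epochs 1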

I see that I forgot to upload the traceback from when I cancelled the multi-GPU command, so I am including it here:

W0117 10:14:46.171000 2686085 torch/distributed/elastic/agent/server/api.py:704] Received Signals.SIGINT death signal, shutting down workers
W0117 10:14:46.181000 2686085 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2686151 closing signal SIGINT
^CW0117 10:14:46.328000 2686085 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2686151 closing signal SIGTERM
^CTraceback (most recent call last):
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 696, in run
    result = self._invoke_run(role)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 855, in _invoke_run
    time.sleep(monitor_interval)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
torch.distributed.elastic.multiprocessing.api.SignalException: Process 2686085 got signal: 2

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 705, in run
    self._shutdown(e.sigval)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 365, in _shutdown
    self._pcontext.close(death_sig)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 572, in close
    self._close(death_sig=death_sig, timeout=timeout)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 909, in _close
    handler.proc.wait(time_to_wait)
  File "/usr/lib/python3.10/subprocess.py", line 1209, in wait
    return self._wait(timeout=timeout)
  File "/usr/lib/python3.10/subprocess.py", line 1953, in _wait
    time.sleep(delay)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
torch.distributed.elastic.multiprocessing.api.SignalException: Process 2686085 got signal: 2

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "xla/venv/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
    run(args)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 260, in launch_agent
    result = agent.run()
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
    result = f(*args, **kwargs)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 710, in run
    self._shutdown()
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 365, in _shutdown
    self._pcontext.close(death_sig)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 572, in close
    self._close(death_sig=death_sig, timeout=timeout)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 909, in _close
    handler.proc.wait(time_to_wait)
  File "/usr/lib/python3.10/subprocess.py", line 1209, in wait
    return self._wait(timeout=timeout)
  File "/usr/lib/python3.10/subprocess.py", line 1953, in _wait
    time.sleep(delay)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
torch.distributed.elastic.multiprocessing.api.SignalException: Process 2686085 got signal: 2

As mentioned before, it seems it simply hangs on torch_xla.launch.

yaochengji (Collaborator) commented:

Then it should be a GPU-specific issue. If PJRT_DEVICE=CPU is set, it will work.
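
For example, something like this should run to completion with the CPU device (same command as before, only the PJRT_DEVICE value changes):

  • PJRT_DEVICE=CPU torchrun --nnodes=1 --nproc-per-node=2 example-xla.py --epochs 1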

Patataman (Author) commented:

Yeah, it seems to be a GPU-related problem. Any idea how to approach this?
