Run XLA container with DDP in Vertex AI #8588

Open
SteshinSS opened this issue Jan 17, 2025 · 0 comments

❓ Questions and Help

Hey there! I prepared a Docker container that trains a model using DDP, and it works fine on a TPU VM. However, when I run the same training job in Vertex AI, it fails. I suspect this is because the --privileged --net host --shm-size=16G flags are not available for containers in Vertex AI. Is there a way to run the container without these flags, or is there some other workaround for Vertex AI?

I also prepared a minimal example.
run.py:

import torch_xla

def mp_fn(index):
    print(str(index) + ' is ready.')

if __name__ == '__main__':
    torch_xla.launch(
        mp_fn,
        args=()
    )
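
This minimal mp_fn just prints; the real job replaces it with a DDP training step. For context, the wrapping follows the standard PyTorch/XLA DDP recipe, roughly like the sketch below (the model, data, and hyperparameters are placeholders, not my actual code):

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' process-group backend

def mp_fn(index):
    # One process per TPU core; init_method='xla://' lets the backend discover peers.
    dist.init_process_group('xla', init_method='xla://')
    device = xm.xla_device()
    model = torch.nn.Linear(10, 10).to(device)  # placeholder model
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, gradient_as_bucket_view=True)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    for _ in range(10):  # placeholder data and loop
        optimizer.zero_grad()
        loss = ddp_model(torch.randn(8, 10, device=device)).sum()
        loss.backward()
        optimizer.step()
        xm.mark_step()  # flush the lazily-traced graph to the TPU

if __name__ == '__main__':
    torch_xla.launch(mp_fn, args=())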

Dockerfile:

FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm

COPY run.py /app/run.py
WORKDIR /app/

# ENV persists into the running container; a RUN export only lasts for its own build layer.
ENV PJRT_DEVICE=TPU

ENTRYPOINT ["python"]
CMD ["/app/run.py"]
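
I build and push the image to Artifact Registry in the usual way (the tag matches the image URI used below):

docker build -t us-central1-docker.pkg.dev/my_registry/tpu_fail_example:latest .
docker push us-central1-docker.pkg.dev/my_registry/tpu_fail_example:latest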

I create a v5litepod-8 TPU VM according to the docs and run the container as:

sudo docker run --rm --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/my_registry/tpu_fail_example:latest

It works fine there.
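
To double-check that the TPU is actually visible inside the container, the entrypoint can be overridden to list devices; this is the same call that fails later in Vertex AI:

sudo docker run --rm --privileged --net host --shm-size=16G \
  --entrypoint python us-central1-docker.pkg.dev/my_registry/tpu_fail_example:latest \
  -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices())"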

Now I try to run the same thing in Vertex AI.
train-job-spec.yaml:

workerPoolSpecs:
  machineSpec:
    machineType: ct5lp-hightpu-8t
    tpuTopology: 2x4

  replicaCount: 1
  containerSpec:
    imageUri: us-central1-docker.pkg.dev/my_registry/tpu_fail_example:latest

And run it:

gcloud ai custom-jobs create \
  --region=us-central1 \
  --display-name=$HOSTNAME-tpu-fail \
  --config=train-job-spec.yaml
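
To follow the output, the job's logs can also be streamed from the terminal (JOB_ID stands for the ID that the create command prints):

gcloud ai custom-jobs stream-logs JOB_ID --region=us-central1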

It fails with the error below: during TPU initialization, the runtime cannot establish a gRPC channel to the local SliceBuilder service on localhost:8482.

concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 58, in _run_thread_per_device
    initializer_fn(local_rank, local_world_size)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 121, in initialize_multiprocess
    devices = xm.get_xla_supported_devices()
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 93, in get_xla_supported_devices
    devices = torch_xla._XLAC._xla_get_devices()
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder grpc channel to localhost:8482.
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/app/tpu_minimal_fail/run.py", line 11, in <module>
    torch_xla.launch(
  File "/usr/local/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 233, in launch
    xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 37, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 209, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
    replica_results = list(
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder grpc channel to localhost:8482.

The replica workerpool0-0 exited with a non-zero status of 1. To find out more about why your job exited please check the logs: https://console.cloud.google.com/logs/viewer?project=199759238457&resource=ml_job%2Fjob_id%2F427393054618419200&advancedFilter=resource.type%3D%22ml_job%22%0Aresource.labels.job_id%3D%22427393054618419200%22
Job failed.

Thanks in advance.
