Skip to content

Commit

Permalink
[Core] RayWorkerVllm --> WorkerWrapper to reduce duplication (vllm-pr…
Browse files Browse the repository at this point in the history
…oject#4024)

[Core] replace narrow-usage RayWorkerVllm to general WorkerWrapper to reduce code duplication (vllm-project#4024)
  • Loading branch information
youkaichao authored Apr 17, 2024
1 parent 11d652b commit 8438e05
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 116 deletions.
7 changes: 3 additions & 4 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import multiprocessing
import os

import pytest
import torch

from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
from vllm.utils import update_environment_variables


def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
env = os.environ.copy()
env = {}
env['RANK'] = str(i)
env['LOCAL_RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes)
Expand All @@ -32,8 +32,7 @@ def update_env(fn):
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapper(env):
import os
os.environ.update(env)
update_environment_variables(env)
fn()

return wrapper
Expand Down
44 changes: 7 additions & 37 deletions vllm/engine/ray_utils.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,28 @@
import pickle
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
from vllm.worker.worker import Worker
from vllm.utils import get_ip, is_hip
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)

try:
import ray

class RayWorkerVllm:
class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""

def __init__(self, init_cached_hf_modules=False) -> None:
if init_cached_hf_modules:
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self._worker: Optional[Worker] = None
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on
# that thread.
self.compiled_dag_cuda_device_set = False

def init_worker(self, worker_init_fn: Callable[[], Worker]):
self._worker = worker_init_fn()

@property
def worker(self) -> Worker:
assert self._worker is not None
return self._worker

def __getattr__(self, name):
return getattr(self.worker, name)

def execute_method(self, method, *args, **kwargs):
try:
executor = getattr(self, method)
return executor(*args, **kwargs)
except Exception as e:
# exceptions in ray worker may cause deadlock
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (f"Error executing method {method}. "
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e

def get_node_ip(self) -> str:
return get_ip()

Expand All @@ -58,9 +31,6 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids

def set_cuda_visible_devices(self, device_ids) -> None:
set_cuda_visible_devices(device_ids)

def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled."""
import torch
Expand All @@ -77,7 +47,7 @@ def execute_model_compiled_dag_remote(self, ignored):
"For distributed inference, please install Ray with "
"`pip install ray`.")
ray = None # type: ignore
RayWorkerVllm = None # type: ignore
RayWorkerWrapper = None # type: ignore


def initialize_ray_cluster(
Expand Down
156 changes: 84 additions & 72 deletions vllm/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import asyncio
import copy
import os
import pickle
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.engine.ray_utils import RayWorkerWrapper, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async, set_cuda_visible_devices)
make_async)

if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
Expand Down Expand Up @@ -74,9 +73,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",

# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: RayWorkerVllm = None
self.driver_dummy_worker: RayWorkerWrapper = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerVllm] = []
self.workers: List[RayWorkerWrapper] = []

if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
Expand All @@ -97,13 +96,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
)(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
)

worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
)
else:
# Else, added to the list of workers.
self.workers.append(worker)
Expand All @@ -115,82 +121,56 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
"GPU node.")

# Get the set of GPU IDs used on each node.
driver_node_id, driver_gpu_ids = ray.get(
self.driver_dummy_worker.get_node_and_gpu_ids.remote())
worker_node_and_gpu_ids = ray.get(
[worker.get_node_and_gpu_ids.remote() for worker in self.workers])
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)

node_workers = defaultdict(list)
node_gpus = defaultdict(list)

node_workers[driver_node_id].append(0)
node_gpus[driver_node_id].extend(driver_gpu_ids)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
start=1):
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)

# Set CUDA_VISIBLE_DEVICES for the driver and workers.
set_cuda_visible_devices(node_gpus[driver_node_id])
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
all_args_to_update_environment_variables = []
for (node_id, _) in worker_node_and_gpu_ids:
all_args_to_update_environment_variables.append([{
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id]))
}])
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)

distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())

# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker

model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
load_config = copy.deepcopy(self.load_config)
device_config = copy.deepcopy(self.device_config)
lora_config = copy.deepcopy(self.lora_config)
cache_config = copy.deepcopy(self.cache_config)
vision_language_config = copy.deepcopy(self.vision_language_config)

# Initialize the actual workers with the Worker class.
for rank, (worker, (node_id, _)) in enumerate(
zip(self.workers, worker_node_and_gpu_ids),
start=1,
):
def collect_arg_helper_func(**kwargs):
# avoid writing `{"name": value}` manually
return kwargs

init_worker_all_kwargs = []

# Initialize the actual workers inside worker wrapper.
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
local_rank = node_workers[node_id].index(rank)
worker.init_worker.remote(
lambda rank=rank, local_rank=local_rank: Worker(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
init_worker_all_kwargs.append(
collect_arg_helper_func(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=lora_config,
vision_language_config=vision_language_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=rank == 0,
))

# Initialize the driver worker with the Worker class.
driver_rank = 0
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
self.driver_worker = Worker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=driver_local_rank,
rank=driver_rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
load_config=self.load_config,
is_driver_worker=True,
)
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

self._run_workers("init_device")
self._run_workers(
Expand Down Expand Up @@ -279,13 +259,35 @@ def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_args: Optional[Tuple[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
all_args: Optional[List[List[Any]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
"""Runs the given method on all workers.
all_args and all_kwargs are used to pass heterogeneous arguments,
i.e. different arguments for each worker.
"""
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs

# for mypy type checking
assert driver_args is not None
assert driver_kwargs is not None
if all_args is None:
all_args = [driver_args] + [args] * len(self.workers)
if all_kwargs is None:
all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)

# for mypy type checking
assert all_args is not None
assert all_kwargs is not None

if max_concurrent_workers:
raise NotImplementedError(
Expand All @@ -299,8 +301,10 @@ def _run_workers(
else:
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
for worker in self.workers
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_args[1:], all_kwargs[1:])
]

if driver_args is None:
Expand All @@ -309,9 +313,13 @@ def _run_workers(
driver_kwargs = kwargs

# Start the driver worker after all the ray workers.
driver_worker_output = getattr(self.driver_worker,
method)(*driver_args, **driver_kwargs)

if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
method, *all_args[0], **all_kwargs[0])
else:
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *all_args[0], **all_kwargs[0]))
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
Expand Down Expand Up @@ -386,8 +394,12 @@ async def _run_workers_async(
driver_kwargs = kwargs

# Run the driver worker asynchronously.
driver_executor = make_async(getattr(self.driver_worker, method))
coros.append(driver_executor(*driver_args, **driver_kwargs))
def helper():
return self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)

driver_executor = make_async(helper)
coros.append(driver_executor())

# Run the ray workers asynchronously.
for worker in self.workers:
Expand Down
16 changes: 14 additions & 2 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,12 @@ def get_open_port() -> int:
return s.getsockname()[1]


def set_cuda_visible_devices(device_ids: List[int]) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
def update_environment_variables(envs: Dict[str, str]):
for k, v in envs.items():
if k in os.environ:
logger.warning(f"Overwriting environment variable {k} "
f"from '{os.environ[k]}' to '{v}'")
os.environ[k] = v


def chunk_list(lst, chunk_size):
Expand Down Expand Up @@ -505,3 +509,11 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
merged_dict[key].extend(value)

return dict(merged_dict)


def init_cached_hf_modules():
"""
Lazy initialization of the Hugging Face modules.
"""
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
Loading

0 comments on commit 8438e05

Please sign in to comment.