Skip to content

Commit

Permalink
feat(remote_model): support variable remote backend for model loader
Browse files Browse the repository at this point in the history
Signed-off-by: wangyu <[email protected]>
  • Loading branch information
DellCurry committed Feb 28, 2025
1 parent b3f7aac commit 1dc1523
Show file tree
Hide file tree
Showing 20 changed files with 730 additions and 128 deletions.
12 changes: 12 additions & 0 deletions benchmarks/backend_request_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)

from vllm.connector import create_remote_connector
from vllm.transformers_utils.utils import is_remote_url

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


Expand Down Expand Up @@ -436,6 +439,15 @@ def get_model(pretrained_model_name_or_path: str) -> str:
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

return model_path

if is_remote_url(pretrained_model_name_or_path):
# BaseConnector implements __del__() to clean up the local dir.
# The config files must outlive this call, so we deliberately do NOT use a
# `with` statement: leaving it would close the client and delete them.
client = create_remote_connector(pretrained_model_name_or_path)
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
return client.get_local_dir()

return pretrained_model_name_or_path


Expand Down
53 changes: 53 additions & 0 deletions examples/offline_inference/save_remote_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.
Example usage:
python save_remote_state.py \
--model /path/to/load \
--tensor-parallel-size 8 \
--remote-model-save-url [protocol]://[host]:[port]/[model_name] \
Then, the model can be loaded with
llm = LLM(
model="/path/to/save",
--remote-model-url [protocol]://[host]:[port]/[model_name] \
tensor_parallel_size=8,
)
"""
import dataclasses
from pathlib import Path

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser

# Command-line interface: every engine option registered by EngineArgs, plus
# the destination URL that is specific to this script.
parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)

# Required; main() passes the value to save_remote_state(url=...).
parser.add_argument("--remote-model-save-url",
                    required=True,
                    type=str,
                    help="remote address to store model weights")


def main(args):
    """Load a local model and ask its executor to save state remotely.

    Args:
        args: Parsed command-line namespace. It must carry the engine
            options consumed by ``EngineArgs.from_cli_args`` and the
            ``remote_model_save_url`` attribute.

    Raises:
        ValueError: If LoRA is enabled, or if the model path is not an
            existing local directory.
    """
    engine_args = EngineArgs.from_cli_args(args)

    # Both checks run before the (expensive) engine construction.
    if engine_args.enable_lora:
        raise ValueError("Saving with enable_lora=True is not supported!")
    if not Path(engine_args.model).is_dir():
        raise ValueError("model path must be a local directory")

    # Build the engine from the parsed arguments; `llm` stays bound until the
    # save call below has returned.
    engine_kwargs = dataclasses.asdict(engine_args)
    llm = LLM(**engine_kwargs)

    # Hand the destination URL to the engine's model executor.
    executor = llm.llm_engine.model_executor
    executor.save_remote_state(url=args.remote_model_save_url)


# Script entry point: parse the command line and run the save.
if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
37 changes: 20 additions & 17 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.connector import create_remote_connector
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
Expand All @@ -34,8 +35,7 @@
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3
from vllm.transformers_utils.utils import is_remote_url
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, random_uuid, resolve_obj_by_qualname)

Expand Down Expand Up @@ -295,7 +295,7 @@ def __init__(
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)

self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
self.maybe_pull_model_tokenizer_from_remote(model, tokenizer)

if (backend := envs.VLLM_ATTENTION_BACKEND
) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
Expand Down Expand Up @@ -426,30 +426,32 @@ def registry(self):
def architectures(self) -> list[str]:
return getattr(self.hf_config, "architectures", [])

def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
def maybe_pull_model_tokenizer_from_remote(self, model: str,
tokenizer: str) -> None:
"""
Pull the model config or tokenizer to a temporary
directory in case of S3.
directory in case of remote.
Args:
model: The model name or path.
tokenizer: The tokenizer name or path.
"""
if is_s3(model) or is_s3(tokenizer):
if is_s3(model):
s3_model = S3Model()
s3_model.pull_files(
model, allow_pattern=["*.model", "*.py", "*.json"])
if is_remote_url(model) or is_remote_url(tokenizer):
logger.info("Pulling model and tokenizer from remote...")
# BaseConnector implements __del__() to clean up the local dir.
# The config files must outlive this call, so we deliberately do NOT use a
# `with` statement: leaving it would close the client and delete them.
client = create_remote_connector(model)
if is_remote_url(model):
client.pull_files(allow_pattern=["*config.json"])
self.model_weights = self.model
self.model = s3_model.dir
self.model = client.get_local_dir()

if is_s3(tokenizer):
s3_tokenizer = S3Model()
s3_tokenizer.pull_files(
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
self.tokenizer = s3_tokenizer.dir
if is_remote_url(tokenizer):
client.pull_files(
ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
self.tokenizer = client.get_local_dir()

def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
Expand Down Expand Up @@ -1245,6 +1247,7 @@ def create_config(

class LoadFormat(str, enum.Enum):
AUTO = "auto"
REMOTE = "remote"
PT = "pt"
SAFETENSORS = "safetensors"
NPCACHE = "npcache"
Expand Down
50 changes: 50 additions & 0 deletions vllm/connector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0

import enum

from vllm.connector.base_connector import (BaseConnector, BaseFileConnector,
BaseKVConnector)
from vllm.connector.redis import RedisConnector
from vllm.connector.s3 import S3Connector
from vllm.logger import init_logger
from vllm.transformers_utils.utils import parse_connector_type

logger = init_logger(__name__)


class ConnectorType(str, enum.Enum):
    """Kind of storage a connector talks to.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    # Returned by get_connector_type() for BaseFileConnector instances.
    FS = "filesystem"
    # Returned by get_connector_type() for BaseKVConnector instances.
    KV = "KV"


def create_remote_connector(url, device="cpu") -> BaseConnector:
    """Build the connector that matches the connector type of ``url``.

    Args:
        url: Remote model URL; ``parse_connector_type`` extracts the
            connector type from it.
        device: Accepted for interface compatibility. It is currently not
            forwarded to the connector that is created.

    Returns:
        A ``RedisConnector`` when the type is ``"redis"``, an
        ``S3Connector`` when it is ``"s3"``.

    Raises:
        ValueError: If the connector type is neither ``"redis"`` nor
            ``"s3"``.
    """
    connector_type = parse_connector_type(url)
    # Plain comparisons instead of ``match``/``case``: structural pattern
    # matching is a syntax error before Python 3.10, which made the
    # pre-commit check fail on this file.
    if connector_type == "redis":
        return RedisConnector(url)
    if connector_type == "s3":
        return S3Connector(url)
    raise ValueError(f"Invalid connector type: {url}")


def get_connector_type(client: BaseConnector) -> ConnectorType:
    """Classify ``client`` as a key/value or a file-system connector.

    The key/value check runs first, so an object that derives from both
    base classes is reported as ``ConnectorType.KV``.

    Args:
        client: The connector instance to classify.

    Returns:
        ``ConnectorType.KV`` for ``BaseKVConnector`` instances,
        ``ConnectorType.FS`` for ``BaseFileConnector`` instances.

    Raises:
        ValueError: If ``client`` is an instance of neither base class.
    """
    if isinstance(client, BaseKVConnector):
        return ConnectorType.KV
    if isinstance(client, BaseFileConnector):
        return ConnectorType.FS

    raise ValueError(f"Invalid connector type: {client}")


# Public names of the package. Every entry must be bound in this module,
# otherwise ``from vllm.connector import *`` raises AttributeError; the
# former "HPKVConnector" entry was never imported or defined here.
__all__ = [
    "BaseConnector",
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",
    "get_connector_type",
]
112 changes: 112 additions & 0 deletions vllm/connector/base_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0

import os
import shutil
import signal
import tempfile
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple

import torch


class BaseConnector(ABC):
    """
    Common base for remote connectors.

    Every instance owns a temporary local directory (created in
    ``__init__``) that ``close()`` removes again. ``close()`` runs on
    context-manager exit, from ``__del__``, and on SIGINT/SIGTERM when the
    signal handlers could be installed.

    URL layout:

    For fs connector such as s3:
        <connector_type>://<path>/<filename>
    For kv connector such as redis:
        <connector_type>://<host>:<port>/<model_name>/keys/<key>
        <connector_type>://<host>:<port>/<model_name>/files/<filename>
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        """Store the arguments and create the temporary local directory.

        Args:
            url: Remote URL of the model; stored as-is.
            device: Target device; stored as-is, not interpreted here.
        """
        self.url = url
        self.device = device
        self.closed = False
        self.local_dir = tempfile.mkdtemp()
        self._install_signal_handlers()

    def _install_signal_handlers(self):
        """Chain ``close()`` in front of the SIGINT/SIGTERM handlers.

        ``signal.signal`` may only be called from the main thread of the
        main interpreter and raises ValueError elsewhere. In that case no
        handler is installed and cleanup relies on ``close()``, the context
        manager protocol, or ``__del__``.

        NOTE: the installed handler closes over ``self``, so it keeps the
        connector referenced for as long as it stays installed.
        """
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            try:
                signal.signal(sig, self._close_by_signal(existing_handler))
            except ValueError:
                # Not on the main thread: constructing the connector must
                # still succeed, just without signal-driven cleanup.
                return

    def get_local_dir(self):
        """Return the path of the temporary local directory."""
        return self.local_dir

    @abstractmethod
    def weight_iterator(
            self,
            rank: int = 0) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield ``(name, tensor)`` pairs; must be provided by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Pull files selected by the patterns; provided by subclasses."""
        raise NotImplementedError()

    def close(self):
        """Remove the temporary local directory; safe to call repeatedly."""
        # A missing attribute means __init__ never got that far (e.g. a
        # subclass failed before calling it); there is nothing to clean up.
        if getattr(self, "closed", True):
            return

        self.closed = True
        local_dir = getattr(self, "local_dir", None)
        if local_dir is not None and os.path.exists(local_dir):
            shutil.rmtree(local_dir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def _close_by_signal(self, existing_handler=None):
        """Return a signal handler that closes this connector first.

        Args:
            existing_handler: The previous handler as returned by
                ``signal.getsignal``: a callable, ``signal.SIG_DFL``,
                ``signal.SIG_IGN`` or ``None``.
        """

        def new_handler(signum, frame):
            self.close()
            if callable(existing_handler):
                existing_handler(signum, frame)
            elif existing_handler == signal.SIG_DFL:
                # The default action (e.g. terminating on SIGTERM) must
                # still happen: restore it and deliver the signal again.
                signal.signal(signum, signal.SIG_DFL)
                os.kill(os.getpid(), signum)
            # SIG_IGN / None: the signal was ignored (or handled outside
            # Python) before, so cleaning up is all that is left to do.

        return new_handler


class BaseKVConnector(BaseConnector):
    """Connector interface for key/value style backends.

    All methods are abstract; the bodies here only raise
    ``NotImplementedError`` and concrete connectors must override them.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        """Look up ``key``; annotated to return a tensor or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def getstr(self, key: str) -> Optional[str]:
        """Look up ``key``; annotated to return a string or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def set(self, key: str, obj: torch.Tensor) -> None:
        """Store the tensor ``obj`` under ``key``."""
        raise NotImplementedError

    @abstractmethod
    def setstr(self, key: str, obj: str) -> None:
        """Store the string ``obj`` under ``key``."""
        raise NotImplementedError

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        """Return a list of strings for ``prefix`` (presumably the keys)."""
        raise NotImplementedError


class BaseFileConnector(BaseConnector):
    """Connector interface for file-system style backends."""

    @abstractmethod
    def glob(self, allow_pattern: str) -> List[str]:
        """
        List full file names from remote fs path and filter by allow pattern.
        Args:
            allow_pattern: Pattern selecting which files to list (annotated
                as a single ``str``).
        Returns:
            list[str]: List of full paths allowed by the pattern
        """
        raise NotImplementedError
Loading

0 comments on commit 1dc1523

Please sign in to comment.