-
-
Notifications
You must be signed in to change notification settings - Fork 6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(remote_model): support variable remote backend for model loader
Signed-off-by: wangyu <[email protected]>
- Loading branch information
Showing
20 changed files
with
726 additions
and
128 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
""" | ||
Saves each worker's model state dict directly to a checkpoint, which enables a | ||
fast load path for large tensor-parallel models where each worker only needs to | ||
read its own shard rather than the entire checkpoint. | ||
Example usage: | ||
python save_remote_state.py \ | ||
--model /path/to/load \ | ||
--tensor-parallel-size 8 \ | ||
--remote-model-save-url [protocol]://[host]:[port]/[model_name] \ | ||
Then, the model can be loaded with | ||
llm = LLM( | ||
model="/path/to/save", | ||
--remote-model-url [protocol]://[host]:[port]/[model_name] \ | ||
tensor_parallel_size=8, | ||
) | ||
""" | ||
import dataclasses | ||
from pathlib import Path | ||
|
||
from vllm import LLM, EngineArgs | ||
from vllm.utils import FlexibleArgumentParser | ||
|
||
parser = FlexibleArgumentParser() | ||
EngineArgs.add_cli_args(parser) | ||
|
||
parser.add_argument("--remote-model-save-url", | ||
required=True, | ||
type=str, | ||
help="remote address to store model weights") | ||
|
||
|
||
def main(args): | ||
engine_args = EngineArgs.from_cli_args(args) | ||
if engine_args.enable_lora: | ||
raise ValueError("Saving with enable_lora=True is not supported!") | ||
model_path = engine_args.model | ||
if not Path(model_path).is_dir(): | ||
raise ValueError("model path must be a local directory") | ||
# Create LLM instance from arguments | ||
llm = LLM(**dataclasses.asdict(engine_args)) | ||
# Dump worker states to output directory | ||
model_executor = llm.llm_engine.model_executor | ||
model_executor.save_remote_state(url=args.remote_model_save_url) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parser.parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import enum | ||
|
||
from vllm.connector.base_connector import (BaseConnector, BaseFileConnector, | ||
BaseKVConnector) | ||
from vllm.connector.redis import RedisConnector | ||
from vllm.connector.s3 import S3Connector | ||
from vllm.logger import init_logger | ||
from vllm.transformers_utils.utils import parse_connector_type | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class ConnectorType(str, enum.Enum): | ||
FS = "filesystem" | ||
KV = "KV" | ||
|
||
|
||
def create_remote_connector(url, device="cpu") -> BaseConnector: | ||
connector_type = parse_connector_type(url) | ||
match connector_type: | ||
case "redis": | ||
return RedisConnector(url) | ||
case "s3": | ||
return S3Connector(url) | ||
case _: | ||
raise ValueError(f"Invalid connector type: {url}") | ||
|
||
|
||
def get_connector_type(client: BaseConnector) -> ConnectorType: | ||
Check failure on line 31 in vllm/connector/__init__.py
|
||
if isinstance(client, BaseKVConnector): | ||
return ConnectorType.KV | ||
if isinstance(client, BaseFileConnector): | ||
return ConnectorType.FS | ||
|
||
raise ValueError(f"Invalid connector type: {client}") | ||
|
||
|
||
__all__ = [ | ||
"BaseConnector", | ||
"BaseFileConnector", | ||
"BaseKVConnector", | ||
"RedisConnector", | ||
"HPKVConnector", | ||
"S3Connector", | ||
"ConnectorType", | ||
"create_remote_connector", | ||
"get_connector_type", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import shutil | ||
import signal | ||
import tempfile | ||
from abc import ABC, abstractmethod | ||
from typing import Generator, List, Optional, Tuple | ||
|
||
import torch | ||
|
||
|
||
class BaseConnector(ABC): | ||
''' | ||
For fs connector such as s3: | ||
<connector_type>://<path>/<filename> | ||
For kv connector such as redis: | ||
<connector_type>://<host>:<port>/<model_name>/keys/<key> | ||
<connector_type://<host>:<port>/<model_name>/files/<filename> | ||
''' | ||
|
||
def __init__(self, url: str, device: torch.device = "cpu"): | ||
self.url = url | ||
self.device = device | ||
self.local_dir = tempfile.mkdtemp() | ||
for sig in (signal.SIGINT, signal.SIGTERM): | ||
existing_handler = signal.getsignal(sig) | ||
signal.signal(sig, self._close_by_signal(existing_handler)) | ||
|
||
def get_local_dir(self): | ||
return self.local_dir | ||
|
||
@abstractmethod | ||
def weight_iterator( | ||
self, | ||
rank: int = 0) -> Generator[Tuple[str, torch.Tensor], None, None]: | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def pull_files( | ||
self, | ||
allow_pattern: Optional[List[str]] = None, | ||
ignore_pattern: Optional[List[str]] = None, | ||
) -> None: | ||
raise NotImplementedError() | ||
|
||
def close(self): | ||
if os.path.exists(self.local_dir): | ||
shutil.rmtree(self.local_dir) | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
self.close() | ||
|
||
def __del__(self): | ||
self.close() | ||
|
||
def _close_by_signal(self, existing_handler=None): | ||
|
||
def new_handler(signum, frame): | ||
self.close() | ||
if existing_handler: | ||
existing_handler(signum, frame) | ||
|
||
return new_handler | ||
|
||
|
||
class BaseKVConnector(BaseConnector): | ||
|
||
@abstractmethod | ||
def get(self, key: str) -> Optional[torch.Tensor]: | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def getstr(self, key: str) -> Optional[str]: | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def set(self, key: str, obj: torch.Tensor) -> None: | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def setstr(self, key: str, obj: str) -> None: | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def list(self, prefix: str) -> List[str]: | ||
raise NotImplementedError() | ||
|
||
|
||
class BaseFileConnector(BaseConnector): | ||
""" | ||
List full file names from remote fs path and filter by allow pattern. | ||
Args: | ||
allow_pattern: A list of patterns of which files to pull. | ||
Returns: | ||
list[str]: List of full paths allowed by the pattern | ||
""" | ||
|
||
@abstractmethod | ||
def glob(self, allow_pattern: str) -> List[str]: | ||
raise NotImplementedError() |
Oops, something went wrong.