From 1dc152310e5c61e2cb27702de503944953540088 Mon Sep 17 00:00:00 2001 From: wangyu Date: Tue, 25 Feb 2025 14:56:35 +0800 Subject: [PATCH] feat(remote_model): support variable remote backend for model loader Signed-off-by: wangyu --- benchmarks/backend_request_func.py | 12 ++ .../offline_inference/save_remote_state.py | 53 +++++ vllm/config.py | 37 ++-- vllm/connector/__init__.py | 50 +++++ vllm/connector/base_connector.py | 112 +++++++++++ vllm/connector/redis.py | 85 ++++++++ .../s3_utils.py => connector/s3.py} | 103 +++------- vllm/connector/serde/__init__.py | 31 +++ vllm/connector/serde/safe_serde.py | 29 +++ vllm/connector/serde/serde.py | 43 ++++ vllm/connector/utils.py | 35 ++++ vllm/engine/arg_utils.py | 8 +- vllm/executor/executor_base.py | 3 + vllm/model_executor/model_loader/loader.py | 187 ++++++++++++++---- .../model_loader/weight_utils.py | 20 ++ vllm/transformers_utils/tokenizer.py | 11 +- vllm/transformers_utils/utils.py | 25 ++- vllm/worker/model_runner.py | 8 + vllm/worker/multi_step_model_runner.py | 3 + vllm/worker/worker.py | 3 + 20 files changed, 730 insertions(+), 128 deletions(-) create mode 100644 examples/offline_inference/save_remote_state.py create mode 100644 vllm/connector/__init__.py create mode 100644 vllm/connector/base_connector.py create mode 100644 vllm/connector/redis.py rename vllm/{transformers_utils/s3_utils.py => connector/s3.py} (54%) create mode 100644 vllm/connector/serde/__init__.py create mode 100644 vllm/connector/serde/safe_serde.py create mode 100644 vllm/connector/serde/serde.py create mode 100644 vllm/connector/utils.py diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 364b087b841d3..27d5f70f00b1b 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -14,6 +14,9 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) +from vllm.connector import create_remote_connector +from vllm.transformers_utils.utils import is_remote_url + AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) @@ -436,6 +439,15 @@ def get_model(pretrained_model_name_or_path: str) -> str: ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) return model_path + + if is_remote_url(pretrained_model_name_or_path): + # BaseConnector implements __del__() to clean up the local dir. + # Since config files need to exist all the time, so we DO NOT use + # with statement to avoid closing the client. + client = create_remote_connector(pretrained_model_name_or_path) + client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + return client.get_local_dir() + return pretrained_model_name_or_path diff --git a/examples/offline_inference/save_remote_state.py b/examples/offline_inference/save_remote_state.py new file mode 100644 index 0000000000000..0a439bdf81467 --- /dev/null +++ b/examples/offline_inference/save_remote_state.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Saves each worker's model state dict directly to a checkpoint, which enables a +fast load path for large tensor-parallel models where each worker only needs to +read its own shard rather than the entire checkpoint. 
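+
+With a KV backend such as Redis, each rank's tensors are stored under
+`<model_name>/keys/rank_<rank>/...` and the non-weight config files under
+`<model_name>/files/...` (see RemoteModelLoader.save_model).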
+
+Example usage:
+
+python save_remote_state.py \
+    --model /path/to/load \
+    --tensor-parallel-size 8 \
+    --remote-model-save-url [protocol]://[host]:[port]/[model_name]
+
+Then, the model can be loaded with
+
+llm = LLM(
+    model="[protocol]://[host]:[port]/[model_name]",
+    tensor_parallel_size=8,
+)
+"""
+import dataclasses
+from pathlib import Path
+
+from vllm import LLM, EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+parser = FlexibleArgumentParser()
+EngineArgs.add_cli_args(parser)
+
+parser.add_argument("--remote-model-save-url",
+                    required=True,
+                    type=str,
+                    help="remote address to store model weights")
+
+
+def main(args):
+    engine_args = EngineArgs.from_cli_args(args)
+    if engine_args.enable_lora:
+        raise ValueError("Saving with enable_lora=True is not supported!")
+    model_path = engine_args.model
+    if not Path(model_path).is_dir():
+        raise ValueError("model path must be a local directory")
+    # Create LLM instance from arguments
+    llm = LLM(**dataclasses.asdict(engine_args))
+    # Dump worker states to the remote destination
+    model_executor = llm.llm_engine.model_executor
+    model_executor.save_remote_state(url=args.remote_model_save_url)
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    main(args)
diff --git a/vllm/config.py b/vllm/config.py
index 78d02b0173503..01a2fbe446b95 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -22,6 +22,7 @@
 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
+from vllm.connector import create_remote_connector
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      get_quantization_config)
@@ -34,8 +35,7 @@
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
     try_get_generation_config, uses_mrope)
-from vllm.transformers_utils.s3_utils import S3Model
-from vllm.transformers_utils.utils import is_s3
+from vllm.transformers_utils.utils import is_remote_url
 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
                         get_cpu_memory, random_uuid, resolve_obj_by_qualname)
@@ -295,7 +295,7 @@ def __init__(
                 f"'Please instead use `--hf-overrides '{hf_override!r}'`")
             warnings.warn(DeprecationWarning(msg), stacklevel=2)
 
-        self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
+        self.maybe_pull_model_tokenizer_from_remote(model, tokenizer)
 
         if (backend := envs.VLLM_ATTENTION_BACKEND
             ) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
@@ -426,30 +426,32 @@ def registry(self):
     def architectures(self) -> list[str]:
         return getattr(self.hf_config, "architectures", [])
 
-    def maybe_pull_model_tokenizer_for_s3(self, model: str,
-                                          tokenizer: str) -> None:
+    def maybe_pull_model_tokenizer_from_remote(self, model: str,
+                                               tokenizer: str) -> None:
         """
         Pull the model config or tokenizer to a temporary
-        directory in case of S3.
+        directory in case of a remote URL.
 
         Args:
             model: The model name or path.
             tokenizer: The tokenizer name or path.
         """
-        if is_s3(model) or is_s3(tokenizer):
-            if is_s3(model):
-                s3_model = S3Model()
-                s3_model.pull_files(
-                    model, allow_pattern=["*.model", "*.py", "*.json"])
+        if is_remote_url(model) or is_remote_url(tokenizer):
+            logger.info("Pulling model and tokenizer from remote...")
+            # BaseConnector implements __del__() to clean up the local dir.
+            # The pulled config files must stay available for the whole
+            # process lifetime, so we do NOT use a `with` statement here,
+            # which would close the client and delete them.
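+            # The connector mirrors the needed files into a temporary local
+            # directory, while the original remote URL is kept in
+            # `model_weights` so the weight loader can still stream tensors
+            # straight from the remote backend.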
+            client = create_remote_connector(model)
+            if is_remote_url(model):
+                client.pull_files(allow_pattern=["*config.json"])
                 self.model_weights = self.model
-                self.model = s3_model.dir
+                self.model = client.get_local_dir()
 
-            if is_s3(tokenizer):
-                s3_tokenizer = S3Model()
-                s3_tokenizer.pull_files(
-                    model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
-                self.tokenizer = s3_tokenizer.dir
+            if is_remote_url(tokenizer):
+                client.pull_files(
+                    ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+                self.tokenizer = client.get_local_dir()
 
     def _init_multimodal_config(
         self, limit_mm_per_prompt: Optional[Mapping[str, int]]
@@ -1245,6 +1247,7 @@ def create_config(
 
 class LoadFormat(str, enum.Enum):
     AUTO = "auto"
+    REMOTE = "remote"
     PT = "pt"
     SAFETENSORS = "safetensors"
     NPCACHE = "npcache"
diff --git a/vllm/connector/__init__.py b/vllm/connector/__init__.py
new file mode 100644
index 0000000000000..bf8c8649a49c7
--- /dev/null
+++ b/vllm/connector/__init__.py
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import enum
+
+from vllm.connector.base_connector import (BaseConnector, BaseFileConnector,
+                                           BaseKVConnector)
+from vllm.connector.redis import RedisConnector
+from vllm.connector.s3 import S3Connector
+from vllm.logger import init_logger
+from vllm.transformers_utils.utils import parse_connector_type
+
+logger = init_logger(__name__)
+
+
+class ConnectorType(str, enum.Enum):
+    FS = "filesystem"
+    KV = "KV"
+
+
+def create_remote_connector(url, device="cpu") -> BaseConnector:
+    connector_type = parse_connector_type(url)
+    match connector_type:
+        case "redis":
+            return RedisConnector(url, device=device)
+        case "s3":
+            return S3Connector(url)
+        case _:
+            raise ValueError(f"Invalid connector type: {url}")
+
+
+def get_connector_type(client: BaseConnector) -> ConnectorType:
+    if isinstance(client, BaseKVConnector):
+        return ConnectorType.KV
+    if isinstance(client, BaseFileConnector):
+        return ConnectorType.FS
+
+    raise ValueError(f"Invalid connector type: {client}")
+
+
+__all__ = [
+    "BaseConnector",
+    "BaseFileConnector",
+    "BaseKVConnector",
+    "RedisConnector",
+    "S3Connector",
+    "ConnectorType",
+    "create_remote_connector",
+    "get_connector_type",
+]
diff --git a/vllm/connector/base_connector.py b/vllm/connector/base_connector.py
new file mode 100644
index 0000000000000..1625530c47cec
--- /dev/null
+++ b/vllm/connector/base_connector.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import shutil
+import signal
+import tempfile
+from abc import ABC, abstractmethod
+from typing import Generator, List, Optional, Tuple
+
+import torch
+
+
+class BaseConnector(ABC):
+    '''
+    For fs connector such as s3:
+        <connector_type>://<path>/<filename>
+
+    For kv connector such as redis:
+        <connector_type>://<host>:<port>/<model_name>/keys/<key>
+        <connector_type>://<host>:<port>/<model_name>/files/<filename>
+    '''
+
+    def __init__(self, url: str, device: torch.device = "cpu"):
+        self.url = url
+        self.device = device
+        self.closed = False
+        self.local_dir = tempfile.mkdtemp()
+        for sig in (signal.SIGINT, signal.SIGTERM):
+            existing_handler = signal.getsignal(sig)
+            signal.signal(sig, self._close_by_signal(existing_handler))
+
+    def get_local_dir(self):
+        return self.local_dir
+
+    @abstractmethod
+    def weight_iterator(
+            self,
+            rank: int = 0) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def pull_files(
+        self,
+        allow_pattern: Optional[List[str]] = None,
+        ignore_pattern: Optional[List[str]] = None,
+    ) -> None:
+        raise NotImplementedError()
+
+    def close(self):
+        if self.closed:
+            return
+
+        self.closed = True
+        if
os.path.exists(self.local_dir): + shutil.rmtree(self.local_dir) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + self.close() + + def _close_by_signal(self, existing_handler=None): + + def new_handler(signum, frame): + self.close() + if existing_handler: + existing_handler(signum, frame) + + return new_handler + + +class BaseKVConnector(BaseConnector): + + @abstractmethod + def get(self, key: str) -> Optional[torch.Tensor]: + raise NotImplementedError() + + @abstractmethod + def getstr(self, key: str) -> Optional[str]: + raise NotImplementedError() + + @abstractmethod + def set(self, key: str, obj: torch.Tensor) -> None: + raise NotImplementedError() + + @abstractmethod + def setstr(self, key: str, obj: str) -> None: + raise NotImplementedError() + + @abstractmethod + def list(self, prefix: str) -> List[str]: + raise NotImplementedError() + + +class BaseFileConnector(BaseConnector): + """ + List full file names from remote fs path and filter by allow pattern. + + Args: + allow_pattern: A list of patterns of which files to pull. + + Returns: + list[str]: List of full paths allowed by the pattern + """ + + @abstractmethod + def glob(self, allow_pattern: str) -> List[str]: + raise NotImplementedError() diff --git a/vllm/connector/redis.py b/vllm/connector/redis.py new file mode 100644 index 0000000000000..29cbf438aa76e --- /dev/null +++ b/vllm/connector/redis.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Generator, List, Optional, Tuple +from urllib.parse import urlparse + +import torch + +from vllm.connector import BaseKVConnector +from vllm.connector.serde import create_serde +from vllm.connector.utils import pull_files_from_db +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class RedisConnector(BaseKVConnector): + + def __init__(self, url: str, device: torch.device = "cpu"): + import redis + super().__init__(url, device) + parsed_url = urlparse(url) + self.connection = redis.Redis(host=parsed_url.hostname, + port=parsed_url.port) + self.model_name = parsed_url.path.lstrip("/") + # TODO: more serde options + self.s, self.d = create_serde("safe") + + def get(self, key: str) -> Optional[torch.Tensor]: + val = self.connection.get(key) + + if val is None: + logger.error("Key %s not found", key) + return None + + return self.d.from_bytes(val) + + def getstr(self, key: str) -> Optional[str]: + val = self.connection.get(key) + if val is None: + logger.error("Key %s not found", key) + return None + + return val.decode("utf-8") + + def set(self, key: str, tensor: torch.Tensor) -> None: + assert tensor is not None + self.connection.set(key, self.s.to_bytes(tensor)) + + def setstr(self, key: str, obj: str) -> None: + self.connection.set(key, obj) + + def list(self, prefix: str) -> List[str]: + cursor = 0 + all_keys: List[bytes] = [] + + while True: + ret: Tuple[int, List[bytes]] = self.connection.scan( + cursor=cursor, match=f"{prefix}*") # type: ignore + cursor, keys = ret + all_keys.extend(keys) + if cursor == 0: + break + + return [key.decode("utf-8") for key in all_keys] + + def weight_iterator(self, + rank: int = 0 + ) -> Generator[Tuple[str, bytes], None, None]: + keys = self.list(f"{self.model_name}/keys/rank_{rank}/") + for key in keys: + val = self.get(key) + key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/") + yield key, val + + def pull_files( + self, + allow_pattern: Optional[List[str]] = None, + ignore_pattern: Optional[List[str]] 
= None, + ) -> None: + pull_files_from_db(self, self.model_name, allow_pattern, + ignore_pattern) + + def close(self): + self.connection.close() + super().close() diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/connector/s3.py similarity index 54% rename from vllm/transformers_utils/s3_utils.py rename to vllm/connector/s3.py index 1c3520bcfb278..ed14d1e91dbf4 100644 --- a/vllm/transformers_utils/s3_utils.py +++ b/vllm/connector/s3.py @@ -2,18 +2,18 @@ import fnmatch import os -import shutil -import signal -import tempfile from pathlib import Path -from typing import Optional +from typing import Generator, Optional, Tuple +import torch + +from vllm.connector import BaseFileConnector from vllm.utils import PlaceholderModule try: import boto3 except ImportError: - boto3 = PlaceholderModule("boto3") # type: ignore[assignment] + boto3 = PlaceholderModule("boto3") def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]: @@ -30,30 +30,6 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: ] -def glob(s3=None, - path: str = "", - allow_pattern: Optional[list[str]] = None) -> list[str]: - """ - List full file names from S3 path and filter by allow pattern. - - Args: - s3: S3 client to use. - path: The S3 path to list from. - allow_pattern: A list of patterns of which files to pull. - - Returns: - list[str]: List of full S3 paths allowed by the pattern - """ - if s3 is None: - s3 = boto3.client("s3") - if not path.endswith("/"): - path = path + "/" - bucket_name, _, paths = list_files(s3, - path=path, - allow_pattern=allow_pattern) - return [f"s3://{bucket_name}/{path}" for path in paths] - - def list_files( s3, path: str, @@ -94,44 +70,19 @@ def list_files( return bucket_name, prefix, paths -class S3Model: - """ - A class representing a S3 model mirrored into a temporary directory. +class S3Connector(BaseFileConnector): - Attributes: - s3: S3 client. - dir: The temporary created directory. + def __init__(self, url: str) -> None: + super().__init__(url) + self.client = boto3.client('s3') - Methods: - pull_files(): Pull model from S3 to the temporary directory. - """ - - def __init__(self) -> None: - self.s3 = boto3.client('s3') - for sig in (signal.SIGINT, signal.SIGTERM): - existing_handler = signal.getsignal(sig) - signal.signal(sig, self._close_by_signal(existing_handler)) - - self.dir = tempfile.mkdtemp() - - def __del__(self): - self._close() - - def _close(self) -> None: - if os.path.exists(self.dir): - shutil.rmtree(self.dir) - - def _close_by_signal(self, existing_handler=None): - - def new_handler(signum, frame): - self._close() - if existing_handler: - existing_handler(signum, frame) - - return new_handler + def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]: + bucket_name, _, paths = list_files(self.client, + path=self.url, + allow_pattern=allow_pattern) + return [f"s3://{bucket_name}/{path}" for path in paths] def pull_files(self, - s3_model_path: str = "", allow_pattern: Optional[list[str]] = None, ignore_pattern: Optional[list[str]] = None) -> None: """ @@ -143,19 +94,29 @@ def pull_files(self, ignore_pattern: A list of patterns of which files not to pull. 
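         Example (illustrative):
             pull_files(allow_pattern=["*.json"]) mirrors only the JSON
             config files under the bucket URL into the connector's
             temporary local directory.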
""" - if not s3_model_path.endswith("/"): - s3_model_path = s3_model_path + "/" - - bucket_name, base_dir, files = list_files(self.s3, s3_model_path, + bucket_name, base_dir, files = list_files(self.client, self.url, allow_pattern, ignore_pattern) if len(files) == 0: return for file in files: - destination_file = os.path.join( - self.dir, - file.removeprefix(base_dir).lstrip("/")) + destination_file = os.path.join(self.local_dir, + file.removeprefix(base_dir)) local_dir = Path(destination_file).parent os.makedirs(local_dir, exist_ok=True) - self.s3.download_file(bucket_name, file, destination_file) + self.client.download_file(bucket_name, file, destination_file) + + def weight_iterator( + self, + rank: int = 0) -> Generator[Tuple[str, torch.Tensor], None, None]: + from vllm.model_executor.model_loader.weight_utils import ( + runai_safetensors_weights_iterator) + + # only support safetensor files now + hf_weights_files = self.glob(allow_pattern=["*.safetensors"]) + return runai_safetensors_weights_iterator(hf_weights_files) + + def close(self): + self.client.close() + super().close() diff --git a/vllm/connector/serde/__init__.py b/vllm/connector/serde/__init__.py new file mode 100644 index 0000000000000..12124a1d95ef3 --- /dev/null +++ b/vllm/connector/serde/__init__.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 + +# inspired by LMCache +from typing import Optional, Tuple + +import torch + +from vllm.connector.serde.safe_serde import SafeDeserializer, SafeSerializer +from vllm.connector.serde.serde import Deserializer, Serializer + + +def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]: + s: Optional[Serializer] = None + d: Optional[Deserializer] = None + + if serde_type == "safe": + s = SafeSerializer() + d = SafeDeserializer(torch.uint8) + else: + raise ValueError(f"Unknown serde type: {serde_type}") + + return s, d + + +__all__ = [ + "Serializer", + "Deserializer", + "SafeSerializer", + "SafeDeserializer", + "create_serde", +] diff --git a/vllm/connector/serde/safe_serde.py b/vllm/connector/serde/safe_serde.py new file mode 100644 index 0000000000000..1936f23aa6c93 --- /dev/null +++ b/vllm/connector/serde/safe_serde.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +import torch +from safetensors.torch import load, save + +from vllm.connector.serde.serde import Deserializer, Serializer + + +class SafeSerializer(Serializer): + + def __init__(self): + super().__init__() + + def to_bytes(self, t: torch.Tensor) -> bytes: + return save({"tensor_bytes": t.cpu().contiguous()}) + + +class SafeDeserializer(Deserializer): + + def __init__(self, dtype): + super().__init__(dtype) + + def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor: + return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype) + + def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor: + return self.from_bytes_normal(b) diff --git a/vllm/connector/serde/serde.py b/vllm/connector/serde/serde.py new file mode 100644 index 0000000000000..3d6f804d754fc --- /dev/null +++ b/vllm/connector/serde/serde.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 + +import abc +from abc import ABC, abstractmethod + +import torch + + +class Serializer(ABC): + + @abstractmethod + def to_bytes(self, t: torch.Tensor) -> bytes: + """ + Serialize a pytorch tensor to bytes. The serialized bytes should contain + both the data and the metadata (shape, dtype, etc.) of the tensor. 
+
+        Input:
+            t: the input pytorch tensor, can be on any device, in any shape,
+               with any dtype
+
+        Returns:
+            bytes: the serialized bytes
+        """
+        raise NotImplementedError
+
+
+class Deserializer(metaclass=abc.ABCMeta):
+
+    def __init__(self, dtype):
+        self.dtype = dtype
+
+    @abstractmethod
+    def from_bytes(self, bs: bytes) -> torch.Tensor:
+        """
+        Deserialize a pytorch tensor from bytes.
+
+        Input:
+            bs: a stream of bytes
+
+        Output:
+            torch.Tensor: the deserialized pytorch tensor
+        """
+        raise NotImplementedError
diff --git a/vllm/connector/utils.py b/vllm/connector/utils.py
new file mode 100644
index 0000000000000..74b8fb5ee3f25
--- /dev/null
+++ b/vllm/connector/utils.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from pathlib import Path
+from typing import Optional
+from urllib.parse import urlparse
+
+from vllm.connector import BaseConnector
+
+
+def parse_model_name(url: str) -> str:
+    """
+    Parse the model name from the url.
+    Only used for db connector.
+    """
+    parsed_url = urlparse(url)
+    return parsed_url.path.lstrip("/")
+
+
+def pull_files_from_db(
+    connector: BaseConnector,
+    model_name: str,
+    allow_pattern: Optional[list[str]] = None,
+    ignore_pattern: Optional[list[str]] = None,
+) -> None:
+    prefix = f"{model_name}/files/"
+    local_dir = connector.get_local_dir()
+    files = connector.list(prefix)
+
+    for file in files:
+        destination_file = os.path.join(local_dir, file.removeprefix(prefix))
+        # Track the parent directory separately so that `local_dir` keeps
+        # pointing at the connector root on subsequent iterations.
+        parent_dir = Path(destination_file).parent
+        os.makedirs(parent_dir, exist_ok=True)
+        with open(destination_file, "wb") as f:
+            f.write(connector.getstr(file).encode('utf-8'))
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1a2f794c9151d..92eff8b52acb3 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -23,7 +23,7 @@
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.plugins import load_general_plugins
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
-from vllm.transformers_utils.utils import check_gguf_file
+from vllm.transformers_utils.utils import check_gguf_file, is_remote_url
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, StoreBoolean
@@ -342,7 +342,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '* "runai_streamer" will load the Safetensors weights using Run:ai'
             'Model Streamer \n'
             '* "bitsandbytes" will load the weights using bitsandbytes '
-            'quantization.\n')
+            'quantization.\n'
+            '* "remote" will load the weights from a remote database.\n')
         parser.add_argument(
             '--config-format',
             default=EngineArgs.config_format,
@@ -1137,6 +1138,9 @@ def create_load_config(self) -> LoadConfig:
                 "BitsAndBytes load format and QLoRA adapter only support "
                 f"'bitsandbytes' quantization, but got {self.quantization}")
 
+        if is_remote_url(self.model):
+            self.load_format = "remote"
+
         return LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 6f5adb4f64728..fc9a051719f1c 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -230,6 +230,9 @@ def save_sharded_state(
                           pattern=pattern,
                           max_size=max_size))
 
+    def save_remote_state(self, url: str) -> None:
+        self.collective_rpc("save_remote_state", kwargs=dict(url=url))
+
     @abstractmethod
     def check_health(self) -> None:
         """Checks if the executor is healthy.
If not, it should raise an diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 46247eaf2a60c..6caf297ec96e0 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -10,6 +10,7 @@ import itertools import math import os +import time import warnings from abc import ABC, abstractmethod from contextlib import contextmanager @@ -28,6 +29,9 @@ from vllm.attention import Attention from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, VllmConfig, set_current_vllm_config) +from vllm.connector import (ConnectorType, create_remote_connector, + get_connector_type) +from vllm.connector.utils import parse_model_name from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE @@ -51,11 +55,10 @@ filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, get_gguf_extra_tensor_names, gguf_quant_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, - runai_safetensors_weights_iterator, safetensors_weights_iterator) + runai_safetensors_weights_iterator, safetensors_weights_iterator, + set_runai_streamer_env) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.transformers_utils.s3_utils import glob as s3_glob -from vllm.transformers_utils.utils import is_s3 from vllm.utils import is_pin_memory_available @@ -398,6 +401,8 @@ def download_model(self, model_config: ModelConfig) -> None: allow_patterns_overrides=None) def load_model(self, vllm_config: VllmConfig) -> nn.Module: + logger.info("Loading weights by default loader ... ") + start = time.perf_counter() device_config = vllm_config.device_config model_config = vllm_config.model_config target_device = torch.device(device_config.device) @@ -419,6 +424,9 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: _process_weights_after_loading(model, model_config, target_device) + end = time.perf_counter() + logger.info("Loaded weights from default loader in %.2f seconds.", + end - start) return model.eval() @@ -1313,58 +1321,35 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: class RunaiModelStreamerLoader(BaseModelLoader): """ Model loader that can load safetensors - files from local FS or S3 bucket. + files from local FS. """ def __init__(self, load_config: LoadConfig): super().__init__(load_config) - if load_config.model_loader_extra_config: - extra_config = load_config.model_loader_extra_config - - if ("concurrency" in extra_config - and isinstance(extra_config.get("concurrency"), int)): - os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( - extra_config.get("concurrency")) - - if ("memory_limit" in extra_config - and isinstance(extra_config.get("memory_limit"), int)): - os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( - extra_config.get("memory_limit")) - - runai_streamer_s3_endpoint = os.getenv( - 'RUNAI_STREAMER_S3_ENDPOINT') - aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') - if (runai_streamer_s3_endpoint is None - and aws_endpoint_url is not None): - os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url + set_runai_streamer_env(load_config) def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]) -> List[str]: """Prepare weights for the model. 
If the model is not local, it will be downloaded.""" - - is_s3_path = is_s3(model_name_or_path) is_local = os.path.isdir(model_name_or_path) safetensors_pattern = "*.safetensors" index_file = SAFE_WEIGHTS_INDEX_NAME - hf_folder = (model_name_or_path if - (is_local or is_s3_path) else download_weights_from_hf( + hf_folder = (model_name_or_path + if is_local else download_weights_from_hf( model_name_or_path, self.load_config.download_dir, [safetensors_pattern], revision, ignore_patterns=self.load_config.ignore_patterns, )) - if is_s3_path: - hf_weights_files = s3_glob(path=hf_folder, - allow_pattern=[safetensors_pattern]) - else: - hf_weights_files = glob.glob( - os.path.join(hf_folder, safetensors_pattern)) - if not is_local and not is_s3_path: + hf_weights_files = glob.glob( + os.path.join(hf_folder, safetensors_pattern)) + + if not is_local: download_safetensors_index_file_from_hf( model_name_or_path, index_file, self.load_config.download_dir, revision) @@ -1408,6 +1393,137 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: return model.eval() +class RemoteModelLoader(BaseModelLoader): + """Model loader that can load Tensors from remote database.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + set_runai_streamer_env(load_config) + + def _get_weights_iterator_kv( + self, + client, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights from remote storage.""" + assert get_connector_type(client) == ConnectorType.KV + rank = get_tensor_model_parallel_rank() + return client.weight_iterator(rank) + + def _get_weights_iterator_fs( + self, + client, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights from remote storage.""" + assert get_connector_type(client) == ConnectorType.FS + return client.weight_iterator() + + def download_model(self, model_config: ModelConfig) -> None: + pass + + @staticmethod + def save_model( + model: torch.nn.Module, + model_path: str, + url: str, + ) -> None: + with create_remote_connector(url) as client: + assert get_connector_type(client) == ConnectorType.KV + model_name = parse_model_name(url) + rank = get_tensor_model_parallel_rank() + state_dict = ShardedStateLoader._filter_subtensors( + model.state_dict()) + for key, tensor in state_dict.items(): + r_key = f"{model_name}/keys/rank_{rank}/{key}" + client.set(r_key, tensor) + + for root, _, files in os.walk(model_path): + for file_name in files: + # ignore hidden files + if file_name.startswith("."): + continue + if os.path.splitext(file_name)[1] not in (".bin", ".pt", + ".safetensors"): + logger.info(file_name) + file_path = os.path.join(root, file_name) + with open(file_path) as file: + file_content = file.read() + f_key = f"{model_name}/files/{file_name}" + client.setstr(f_key, file_content) + + def _load_model_from_remote_kv(self, model: nn.Module, client, + vllm_config: VllmConfig): + model_config = vllm_config.model_config + device_config = vllm_config.device_config + _process_weights_after_loading(model, model_config, + device_config.device) + weights_iterator = self._get_weights_iterator_kv(client) + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + for key, tensor in weights_iterator: + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. 
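+            # Example: a LoRA-padded embedding of shape [vocab + pad, hidden]
+            # may receive a stored tensor of shape [vocab, hidden]; each dim
+            # is narrowed to the stored size before the copy.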
+ param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", + tensor.shape, + key, + param_shape, + ) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") + + def _load_model_from_remote_fs(self, model, client, + vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + model.load_weights(self._get_weights_iterator_fs(client)) + _process_weights_after_loading(model, model_config, target_device) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + logger.info("Loading weights from remote storage ...") + start = time.perf_counter() + device_config = vllm_config.device_config + model_config = vllm_config.model_config + load_config = vllm_config.load_config + + assert load_config.load_format == LoadFormat.REMOTE, ( + f"Model loader {self.load_config.load_format} is not supported for " + f"load format {load_config.load_format}") + + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(vllm_config=vllm_config) + + with create_remote_connector(model_weights, + device_config.device) as client: + connector_type = get_connector_type(client) + if connector_type == ConnectorType.KV: + self._load_model_from_remote_kv(model, client, vllm_config) + elif connector_type == ConnectorType.FS: + self._load_model_from_remote_fs(model, client, vllm_config) + + end = time.perf_counter() + logger.info("Loaded weights from remote storage in %.2f seconds.", + end - start) + return model.eval() + + def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: """Get a model loader based on the load format.""" if isinstance(load_config.load_format, type): @@ -1431,4 +1547,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.RUNAI_STREAMER: return RunaiModelStreamerLoader(load_config) + if load_config.load_format == LoadFormat.REMOTE: + return RemoteModelLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 245c199f75b18..006549e2e08fa 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -428,6 +428,26 @@ def safetensors_weights_iterator( yield name, param +def set_runai_streamer_env(load_config: LoadConfig): + if load_config.model_loader_extra_config: + extra_config = load_config.model_loader_extra_config + + if ("concurrency" in extra_config + and isinstance(extra_config.get("concurrency"), int)): + os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( + extra_config.get("concurrency")) + + if ("memory_limit" in extra_config + and isinstance(extra_config.get("memory_limit"), int)): + os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( + extra_config.get("memory_limit")) + + runai_streamer_s3_endpoint = os.getenv('RUNAI_STREAMER_S3_ENDPOINT') + 
aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
+    if (runai_streamer_s3_endpoint is None and aws_endpoint_url is not None):
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
+
+
 def runai_safetensors_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index f0aa5fdcaa61f..55628bb7cace1 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -12,13 +12,14 @@
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
 
+from vllm.connector import create_remote_connector
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
                                                     TokenizerRegistry)
 from vllm.transformers_utils.tokenizers import MistralTokenizer
-from vllm.transformers_utils.utils import check_gguf_file
+from vllm.transformers_utils.utils import check_gguf_file, is_remote_url
 from vllm.utils import make_async
 
 if TYPE_CHECKING:
@@ -161,6 +162,14 @@ def get_tokenizer(
             ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
         tokenizer_name = tokenizer_path
 
+    if is_remote_url(tokenizer_name):
+        # BaseConnector implements __del__() to clean up the local dir.
+        # The pulled files must stay available for the whole process lifetime,
+        # so we do NOT use a `with` statement that would close the client.
+        client = create_remote_connector(tokenizer_name)
+        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+        tokenizer_name = client.get_local_dir()
+
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError(
diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py
index 87e446f894384..da0bef1d26263 100644
--- a/vllm/transformers_utils/utils.py
+++ b/vllm/transformers_utils/utils.py
@@ -1,12 +1,31 @@
 # SPDX-License-Identifier: Apache-2.0
-
+import re
 from os import PathLike
 from pathlib import Path
 from typing import List, Optional, Union
 
 
-def is_s3(model_or_path: str) -> bool:
-    return model_or_path.lower().startswith('s3://')
+def is_remote_url(url: str) -> bool:
+    """
+    Check if the URL is a remote URL of the format:
+    <connector_type>://<host>:<port>/<model_name>
+    """
+    pattern = r"(.+)://(.*)"
+    m = re.match(pattern, url)
+    return m is not None
+
+
+def parse_connector_type(url: str) -> str:
+    """
+    Parse the connector type from the URL of the format:
+    <connector_type>://<path>
+    """
+    pattern = r"(.+)://(.*)"
+    m = re.match(pattern, url)
+    if m is None:
+        return ""
+
+    return m.group(1)
 
 
 def check_gguf_file(model: Union[str, PathLike]) -> bool:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index bb2228165b528..4cdde1f60dbe3 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1191,6 +1191,14 @@ def save_tensorized_model(
             tensorizer_config=tensorizer_config,
         )
 
+    def save_remote_model(self, url: str) -> None:
+        from vllm.model_executor.model_loader.loader import RemoteModelLoader
+        RemoteModelLoader.save_model(
+            self.model,
+            self.model_config.model,
+            url,
+        )
+
     def get_max_block_per_batch(self) -> int:
         block_size = self.block_size
         return (self.max_seq_len_to_capture + block_size - 1) // block_size
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index 7ddf382079c62..10f242a59c478 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -665,6 +665,9 @@ def
save_tensorized_model(self, tensorizer_config: TensorizerConfig) -> None: return self._base_model_runner.save_tensorized_model(tensorizer_config) + def save_remote_model(self, url): + return self._base_model_runner.save_remote_model(url) + def profile_run(self) -> None: return self._base_model_runner.profile_run() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ad94a6a4db7a3..90534f508dee9 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -201,6 +201,9 @@ def save_tensorized_model( self.model_runner.save_tensorized_model( tensorizer_config=tensorizer_config, ) + def save_remote_state(self, url: str) -> None: + self.model_runner.save_remote_model(url=url) + @torch.inference_mode() def determine_num_available_blocks(self) -> Tuple[int, int]: """Profiles the peak memory usage of the model to determine how many