From ad4f388c98c855549b73e88a3bfab8cbac2b6838 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Mon, 20 Nov 2023 20:37:15 -0500 Subject: [PATCH] refactor: update runner helpers and add max_model_len (#712) * chore(runner): cleanup unecessary checks for runnable backend Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: saving llm reference to runner Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: correct inject item Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: update support for max_seq_len Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * fix: correct max_model_len Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: update and warning backward compatibility Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: remove unused sets Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- .../src/openllm_core/_configuration.py | 76 +++++++- openllm-python/src/openllm/_llm.py | 93 +--------- openllm-python/src/openllm/_runners.py | 174 +++++++++++------- openllm-python/src/openllm/_runners.pyi | 57 ++---- openllm-python/src/openllm_cli/entrypoint.py | 20 +- openllm-python/tests/configuration_test.py | 8 - 6 files changed, 222 insertions(+), 206 deletions(-) diff --git a/openllm-core/src/openllm_core/_configuration.py b/openllm-core/src/openllm_core/_configuration.py index c01ac2843..d7a286ef6 100644 --- a/openllm-core/src/openllm_core/_configuration.py +++ b/openllm-core/src/openllm_core/_configuration.py @@ -7,6 +7,7 @@ import sys import types import typing as t +import warnings import attr import click_option_group as cog @@ -27,9 +28,11 @@ LiteralBackend, LiteralSerialisation, LiteralString, + M, NotRequired, Required, Self, + T, overload, ) from .exceptions import ForbiddenAttributeError, MissingDependencyError @@ -42,10 +45,16 @@ import vllm from attrs import AttrsInstance + import openllm from openllm.protocol.cohere import CohereChatRequest, CohereGenerateRequest from openllm.protocol.openai import ChatCompletionRequest, CompletionRequest else: - vllm = LazyLoader('vllm', globals(), 'vllm') + vllm = LazyLoader( + 'vllm', + globals(), + 'vllm', + exc_msg='vLLM is not installed. Make sure to install it with `pip install "openllm[vllm]"`', + ) transformers = LazyLoader('transformers', globals(), 'transformers') peft = LazyLoader('peft', globals(), 'peft') @@ -268,6 +277,9 @@ class SamplingParams(ReprMixin): top_k: int top_p: float logprobs: int + repetition_penalty: float + length_penalty: float + early_stopping: bool def __init__(self, *, _internal: bool = False, **attrs: t.Any): if not _internal: @@ -352,7 +364,7 @@ def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t. 
), ) -_SamplingParamsT = t.TypeVar('_SamplingParams', bound=SamplingParams) +_SamplingParamsT = t.TypeVar('_SamplingParamsT', bound=SamplingParams) # cached it here to save one lookup per assignment _object_getattribute = object.__getattribute__ @@ -841,8 +853,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): - Automatic environment conversion: Each fields will automatically be provisioned with an environment variable, make it easy to work with ahead-of-time or during serving time - - Familiar API: It is compatible with cattrs as well as providing a few Pydantic-2 like API, - i.e: ``model_construct_env``, ``to_generation_config``, ``to_click_options`` + - Familiar API: It is compatible with cattrs as well as providing a few Pydantic-2 like API, i.e: ``model_construct_env`` - Automatic CLI generation: It can identify each fields and convert it to compatible Click options. This means developers can use any of the LLMConfig to create CLI with compatible-Python CLI library (click, typer, ...) @@ -1447,18 +1458,63 @@ def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, DictStrAny]: def make_fine_tune_config(self, adapter_type: AdapterType, **attrs: t.Any) -> FineTuneConfig: return FineTuneConfig(adapter_type=adapter_type, llm_config_class=self.__class__).with_config(**attrs) - @overload - def to_generation_config(self, return_as_dict: t.Literal[False] = False) -> transformers.GenerationConfig: ... + def inference_options(self, llm: openllm.LLM[M, T], backend: str | None = None) -> tuple[Self, t.Any]: + backend = backend if backend is not None else llm.__llm_backend__ + cls = getattr(self, backend, None) + if cls is None: + raise ValueError(f'Unknown backend {backend}') + try: + return self, cls.build(self) + except AttributeError: + raise RuntimeError(f'Unknown backend {llm.__llm_backend__}') from None + + class vllm: + @staticmethod + def build(config: LLMConfig) -> vllm.SamplingParams: + if config['temperature'] <= 1e-5: + top_p = 1.0 + else: + top_p = config['top_p'] + _object_setattr(config.sampling_config, 'top_p', top_p) + return config.sampling_config.build() + + class ctranslate: + @staticmethod + def build(config: LLMConfig) -> dict[str, t.Any]: + return dict( + max_length=config['max_new_tokens'], + min_length=config['min_length'], + sampling_topk=config['top_k'], + sampling_topp=config['top_p'], + sampling_temperature=config['temperature'], + return_log_prob=config['logprobs'] > 0, + repetition_penalty=config['repetition_penalty'], + no_repeat_ngram_size=config['no_repeat_ngram_size'], + end_token=config['stop'], + ) - @overload - def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny: ... 
+ class pt: + @staticmethod + def build(config: LLMConfig) -> LLMConfig: + return config + + class hf: + @staticmethod + def build(config: LLMConfig) -> transformers.GenerationConfig: + return transformers.GenerationConfig(**converter.unstructure(config.generation_config)) def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny: - config = transformers.GenerationConfig(**converter.unstructure(self.generation_config)) + warnings.warn( + "'to_generation_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3 + ) + _, config = self.inference_options(None, 'hf') return config.to_dict() if return_as_dict else config def to_sampling_config(self) -> vllm.SamplingParams: - return self.sampling_config.build() + warnings.warn( + "'to_sampling_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3 + ) + return self.inference_options(None, 'vllm')[-1] @overload def with_request(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: ... diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index a11a6d0c5..90de744a2 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -2,7 +2,6 @@ import functools import logging import os -import types import typing as t import attr @@ -11,7 +10,6 @@ import bentoml import openllm -from bentoml._internal.models.model import ModelSignature from openllm_core._schemas import GenerationOutput from openllm_core._typing_compat import ( AdapterMap, @@ -34,7 +32,6 @@ ReprMixin, apply, codegen, - converter, first_not_none, flatten_attrs, gen_random_uuid, @@ -131,6 +128,7 @@ class LLM(t.Generic[M, T], ReprMixin): _adapter_map: AdapterMap | None _serialisation: LiteralSerialisation _local: bool + _max_model_len: int | None _prompt_template: PromptTemplate | None _system_message: str | None @@ -163,6 +161,7 @@ def __init__( embedded=False, dtype='auto', low_cpu_mem_usage=True, + max_model_len=None, _eager=True, **attrs, ): @@ -192,6 +191,7 @@ def __init__( adapter_map=_resolve_peft_config_type(adapter_map) if adapter_map is not None else None, serialisation=serialisation, local=_local, + max_model_len=max_model_len, prompt_template=PromptTemplate(prompt_template) if isinstance(prompt_template, str) else prompt_template, system_message=system_message, LLM__model_attrs=model_attrs, @@ -327,7 +327,8 @@ def tokenizer(self): return self.__llm_tokenizer__ @property def runner(self): - if self.__llm_runner__ is None:self.__llm_runner__=_RunnerFactory(self) + from ._runners import runner + if self.__llm_runner__ is None:self.__llm_runner__=runner(self) return self.__llm_runner__ def prepare(self,adapter_type='lora',use_gradient_checking=True,**attrs): if self.__llm_backend__!='pt':raise RuntimeError('Fine tuning is only supported for PyTorch backend.') @@ -421,6 +422,8 @@ def config(self): async def generate( self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs ) -> GenerationOutput: + if adapter_name is not None and self.__llm_backend__ != 'pt': + raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.') config = self.config.model_construct_env(**attrs) texts, token_ids = [[]] * config['n'], [[]] * config['n'] final_result = None @@ -446,6 +449,9 @@ async def generate_iterator( ) -> t.AsyncGenerator[GenerationOutput, None]: from bentoml._internal.runner.runner_handle import 
DummyRunnerHandle + if adapter_name is not None and self.__llm_backend__ != 'pt': + raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.') + if isinstance(self.runner._runner_handle, DummyRunnerHandle): if os.getenv('BENTO_PATH') is not None: raise RuntimeError('Runner client failed to set up correctly.') @@ -487,82 +493,3 @@ async def generate_iterator( previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids) delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens) yield generated.with_options(outputs=delta_outputs) - - -def _RunnerFactory( - llm, /, models=None, max_batch_size=None, max_latency_ms=None, scheduling_strategy=None, *, backend=None -): - from ._runners import runnable - - if scheduling_strategy is None: - from ._strategies import CascadingResourceStrategy - - scheduling_strategy = CascadingResourceStrategy - - backend = first_not_none(getenv('backend', default=backend), default=llm.__llm_backend__) - - models = models if models is not None else [] - try: - models.append(llm.bentomodel) - except bentoml.exceptions.NotFound as err: - raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err - - if llm._prompt_template: - prompt_template = llm._prompt_template.to_string() - elif hasattr(llm.config, 'default_prompt_template'): - prompt_template = llm.config.default_prompt_template - else: - prompt_template = None - if llm._system_message: - system_message = llm._system_message - elif hasattr(llm.config, 'default_system_message'): - system_message = llm.config.default_system_message - else: - system_message = None - return types.new_class( - llm.config.__class__.__name__[:-6] + 'Runner', - (bentoml.Runner,), - exec_body=lambda ns: ns.update( - { - 'llm_type': llm.llm_type, - 'identifying_params': llm.identifying_params, - 'llm_tag': llm.tag, - 'llm': llm, - 'config': llm.config, - 'backend': backend, - '__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}', - '__module__': llm.__module__, - '__repr__': ReprMixin.__repr__, - '__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}), - '__repr_args__': lambda _: ( - ( - 'runner_methods', - { - method.name: { - 'batchable': method.config.batchable, - 'batch_dim': method.config.batch_dim if method.config.batchable else None, - } - for method in _.runner_methods - }, - ), - ('config', llm.config.model_dump(flatten=True)), - ('llm_type', llm.llm_type), - ('backend', backend), - ('llm_tag', llm.tag), - ), - 'has_adapters': llm.has_adapters, - 'prompt_template': prompt_template, - 'system_message': system_message, - } - ), - )( - runnable(llm, backend), - name=llm.runner_name, - embedded=False, - models=models, - max_batch_size=max_batch_size, - max_latency_ms=max_latency_ms, - scheduling_strategy=scheduling_strategy, - runnable_init_params={'llm': llm}, - method_configs=converter.unstructure({'generate_iterator': ModelSignature(batchable=False)}), - ) diff --git a/openllm-python/src/openllm/_runners.py b/openllm-python/src/openllm/_runners.py index c428919f3..c9de4ceb2 100644 --- a/openllm-python/src/openllm/_runners.py +++ b/openllm-python/src/openllm/_runners.py @@ -1,6 +1,7 @@ from __future__ import annotations import gc import traceback +import types import typing as t import torch @@ -9,23 +10,90 @@ import openllm from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs from openllm_core.exceptions import OpenLLMException -from 
openllm_core.utils import first_not_none, getenv, is_ctranslate_available +from openllm_core.utils import ReprMixin, is_ctranslate_available, is_vllm_available -__all__ = ['runnable'] +__all__ = ['runner'] +_registry = {} + + +def registry(cls=None, *, alias=None): + def decorator(_cls): + _registry[_cls.__name__[:-8].lower() if alias is None else alias] = _cls + return _cls + + if cls is None: + return decorator + return decorator(cls) -def runnable(llm, backend=None): - backend = first_not_none(getenv('backend', default=backend), default=llm._cascade_backend()) - if backend == 'vllm': - return vLLMRunnable - elif backend == 'pt': - return PyTorchRunnable - elif backend == 'ctranslate': - return CTranslateRunnable - else: - raise OpenLLMException(f'Unsupported backend: {backend}') +def runner(llm): + from ._strategies import CascadingResourceStrategy + try: + models = [llm.bentomodel] + except bentoml.exceptions.NotFound as err: + raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err + + if llm._prompt_template: + prompt_template = llm._prompt_template.to_string() + elif hasattr(llm.config, 'default_prompt_template'): + prompt_template = llm.config.default_prompt_template + else: + prompt_template = None + if llm._system_message: + system_message = llm._system_message + elif hasattr(llm.config, 'default_system_message'): + system_message = llm.config.default_system_message + else: + system_message = None + + return types.new_class( + llm.config.__class__.__name__[:-6] + 'Runner', + (bentoml.Runner,), + exec_body=lambda ns: ns.update( + { + 'llm_type': llm.llm_type, + 'identifying_params': llm.identifying_params, + 'llm_tag': llm.tag, + 'llm': llm, + 'config': llm.config, + 'backend': llm.__llm_backend__, + '__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}', + '__module__': llm.__module__, + '__repr__': ReprMixin.__repr__, + '__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}), + '__repr_args__': lambda _: ( + ( + 'runner_methods', + { + method.name: { + 'batchable': method.config.batchable, + 'batch_dim': method.config.batch_dim if method.config.batchable else None, + } + for method in _.runner_methods + }, + ), + ('config', llm.config.model_dump(flatten=True)), + ('llm_type', llm.llm_type), + ('backend', llm.__llm_backend__), + ('llm_tag', llm.tag), + ), + 'has_adapters': llm.has_adapters, + 'prompt_template': prompt_template, + 'system_message': system_message, + } + ), + )( + _registry[llm.__llm_backend__], + name=llm.runner_name, + models=models, + scheduling_strategy=CascadingResourceStrategy, + runnable_init_params={'llm': llm}, + ) + + +@registry class CTranslateRunnable(bentoml.Runnable): SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu') SUPPORTS_CPU_MULTI_THREADING = True @@ -33,40 +101,23 @@ class CTranslateRunnable(bentoml.Runnable): def __init__(self, llm): if not is_ctranslate_available(): raise OpenLLMException('ctranslate is not installed. 
Please install it with `pip install "openllm[ctranslate]"`') - self.config = llm.config - self.model = llm.model - self.tokenizer = llm.tokenizer + self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer @bentoml.Runnable.method(batchable=False) async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): - if adapter_name is not None: - raise NotImplementedError('Adapter is not supported with CTranslate.') - stop_ = set() if isinstance(stop, str) and stop != '': stop_.add(stop) elif isinstance(stop, t.Iterable): stop_.update(stop) - config = self.config.model_construct_env(stop=list(stop_), **attrs) - sampling_params = dict( - max_length=config['max_new_tokens'], - min_length=config['min_length'], - sampling_topk=config['top_k'], - sampling_topp=config['top_p'], - sampling_temperature=config['temperature'], - return_log_prob=config['logprobs'] > 0, - repetition_penalty=config['repetition_penalty'], - no_repeat_ngram_size=config['no_repeat_ngram_size'], - end_token=config['stop'], - ) - cumulative_logprob = 0.0 - output_token_ids = list(prompt_token_ids) - input_len = len(prompt_token_ids) - async for request_output in self.model.async_generate_tokens( - self.tokenizer.convert_ids_to_tokens(prompt_token_ids), **sampling_params - ): - cumulative_logprob += request_output.log_prob if config['logprobs'] else 0.0 + config, sampling_params = self.config.model_construct_env(stop=list(stop_), **attrs).inference_options(self.llm) + cumulative_logprob, output_token_ids, input_len = 0.0, list(prompt_token_ids), len(prompt_token_ids) + tokens = self.tokenizer.convert_ids_to_tokens(prompt_token_ids) + + async for request_output in self.model.async_generate_tokens(tokens, **sampling_params): + if config['logprobs']: + cumulative_logprob += request_output.log_prob output_token_ids.append(request_output.token_id) text = self.tokenizer.decode( output_token_ids[input_len:], @@ -92,34 +143,34 @@ async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapt ).model_dump_json() +@registry class vLLMRunnable(bentoml.Runnable): SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu') SUPPORTS_CPU_MULTI_THREADING = True def __init__(self, llm): - try: - import vllm - except ImportError: - raise OpenLLMException('vLLM is not installed. Please install it via `pip install "openllm[vllm]"`.') from None - self.config = llm.config + if not is_vllm_available(): + raise OpenLLMException('vLLM is not installed. 
Please install it via `pip install "openllm[vllm]"`.') + import vllm + + self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer num_gpus, dev = 1, openllm.utils.device_count() if dev >= 2: num_gpus = min(dev // 2 * 2, dev) - quantization = None - if llm.quantise and llm.quantise in {'awq', 'squeezellm'}: - quantization = llm.quantise + try: self.model = vllm.AsyncLLMEngine.from_engine_args( vllm.AsyncEngineArgs( - model=llm.bentomodel.path, - tokenizer=llm.bentomodel.path, - trust_remote_code=llm.trust_remote_code, tokenizer_mode='auto', tensor_parallel_size=num_gpus, - dtype=llm._torch_dtype, - quantization=quantization, worker_use_ray=False, engine_use_ray=False, + model=llm.bentomodel.path, + tokenizer=llm.bentomodel.path, + trust_remote_code=llm.trust_remote_code, + dtype=llm._torch_dtype, + max_model_len=llm._max_model_len, + quantization=llm.quantise if llm.quantise and llm.quantise in {'awq', 'squeezellm'} else None, ) ) except Exception as err: @@ -128,21 +179,13 @@ def __init__(self, llm): @bentoml.Runnable.method(batchable=False) async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): - if adapter_name is not None: - raise NotImplementedError('Adapter is not supported with vLLM.') stop_ = set() if isinstance(stop, str) and stop != '': stop_.add(stop) elif isinstance(stop, t.Iterable): stop_.update(stop) - temperature = attrs.pop('temperature', self.config['temperature']) - top_p = attrs.pop('top_p', self.config['top_p']) - if temperature <= 1e-5: - top_p = 1.0 - sampling_params = self.config.model_construct_env( - stop=list(stop_), temperature=temperature, top_p=top_p, **attrs - ).to_sampling_config() + config, sampling_params = self.config.model_construct_env(stop=list(stop_), **attrs).inference_options(self.llm) async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids): # XXX: Need to write a hook for serialisation None correctly @@ -151,32 +194,31 @@ async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapt yield GenerationOutput.from_vllm(request_output).model_dump_json() +@registry(alias='pt') class PyTorchRunnable(bentoml.Runnable): SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu') SUPPORTS_CPU_MULTI_THREADING = True def __init__(self, llm): - self.model = llm.model - self.tokenizer = llm.tokenizer - self.config = llm.config + self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer + self.is_encoder_decoder = llm.model.config.is_encoder_decoder if hasattr(llm.model, 'device'): self.device = llm.model.device else: self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - self.is_encoder_decoder = llm.model.config.is_encoder_decoder @bentoml.Runnable.method(batchable=False) async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): if adapter_name is not None: self.model.set_adapter(adapter_name) + stop_ = set() if isinstance(stop, str) and stop != '': stop_.add(stop) elif isinstance(stop, t.Iterable): stop_.update(stop) - async for generation_output in self.forward( - prompt_token_ids=prompt_token_ids, request_id=request_id, stop=list(stop_), **attrs - ): + + async for generation_output in self.forward(prompt_token_ids, request_id, list(stop_), **attrs): yield generation_output.model_dump_json() async def forward(self, prompt_token_ids, request_id, stop, **attrs): diff --git a/openllm-python/src/openllm/_runners.pyi 
b/openllm-python/src/openllm/_runners.pyi index ef685af58..1621fd8b4 100644 --- a/openllm-python/src/openllm/_runners.pyi +++ b/openllm-python/src/openllm/_runners.pyi @@ -15,10 +15,13 @@ from typing import ( final, ) +import torch +from transformers import PreTrainedModel, PreTrainedTokenizer + from bentoml import Model, Strategy, Tag from bentoml._internal.runner.runner_handle import RunnerHandle from openllm_core import LLMConfig -from openllm_core._typing_compat import LiteralBackend, M, T, overload +from openllm_core._typing_compat import LiteralBackend, M, T from ._llm import LLM @@ -27,24 +30,21 @@ try: except ImportError: AsyncLLMEngine = Any -try: - from transformers import PreTrainedModel -except ImportError: - PreTrainedModel = Any - try: from ctranslate2 import Generator, Translator except ImportError: - Translator = Any - Generator = Any + Translator = Generator = Any Mo = TypeVar('Mo') +To = TypeVar('To') -class _Runnable(Protocol[Mo]): +class _Runnable(Protocol[Mo, To]): SUPPORTED_RESOURCES: Tuple[Literal['nvidia.com/gpu', 'amd.com/gpu', 'cpu'], ...] = ... SUPPORTS_CPU_MULTI_THREADING: bool = ... + llm: LLM[Mo, To] = ... config: LLMConfig = ... model: Mo = ... + tokenizer: To = ... def __init__(self, llm: LLM[Mo, T]) -> None: ... async def generate_iterator( self, @@ -61,42 +61,25 @@ Ret = TypeVar('Ret') class RunnerMethod(Generic[In, Ret]): ... @final -class vLLMRunnable(_Runnable[AsyncLLMEngine]): ... +class vLLMRunnable(_Runnable[AsyncLLMEngine, PreTrainedTokenizer]): ... @final -class CTranslateRunnable(_Runnable[Union[Translator, Generator]]): - tokenizer: Any +class CTranslateRunnable(_Runnable[Union[Translator, Generator], PreTrainedTokenizer]): ... @final -class PyTorchRunnable(_Runnable[PreTrainedModel]): - tokenizer: Any - async def forward( - self, - prompt_token_ids: List[int], - request_id: str, - stop: Iterable[str], - adapter_name: Optional[str] = ..., - **attrs: Any, - ) -> AsyncGenerator[str, None]: ... +class PyTorchRunnable(_Runnable[PreTrainedModel, PreTrainedTokenizer]): + is_encoder_decoder: bool = ... + device: torch.device = ... + +def runner(llm: LLM[M, T]) -> Runner[M, T]: ... -@overload -def runnable(llm: LLM[M, T], backend: Literal['vllm']) -> Type[vLLMRunnable]: ... -@overload -def runnable(llm: LLM[M, T], backend: Literal['pt']) -> Type[PyTorchRunnable]: ... -@overload -def runnable(llm: LLM[M, T], backend: Literal['ctranslate']) -> Type[CTranslateRunnable]: ... -@overload -def runnable( - llm: LLM[M, T], backend: Optional[str] = ... -) -> Type[Union[vLLMRunnable, PyTorchRunnable, CTranslateRunnable]]: ... - -class Runner(Protocol[Mo, T]): +class Runner(Protocol[Mo, To]): __doc__: str = ... __module__: str = ... llm_type: str = ... llm_tag: Tag = ... identifying_params: Dict[str, Any] = ... - llm: LLM[Mo, T] = ... + llm: LLM[Mo, To] = ... config: LLMConfig = ... backend: LiteralBackend = ... has_adapters: bool = ... @@ -115,7 +98,7 @@ class Runner(Protocol[Mo, T]): def __init__( self, - runnable_class: Type[_Runnable[Mo]], + runnable_class: Type[_Runnable[Mo, To]], *, runnable_init_params: Optional[Dict[str, Any]] = ..., name: Optional[str] = ..., @@ -130,7 +113,7 @@ class Runner(Protocol[Mo, T]): name: str = ... models: List[Model] = ... 
resource_config: Dict[str, Any] - runnable_class: Type[_Runnable[Mo]] + runnable_class: Type[_Runnable[Mo, To]] embedded: bool runner_methods: List[RunnerMethod[Any, Any]] scheduling_strategy: Type[Strategy] diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py index 9a1a65943..8cfadcd83 100644 --- a/openllm-python/src/openllm_cli/entrypoint.py +++ b/openllm-python/src/openllm_cli/entrypoint.py @@ -392,6 +392,13 @@ def cli() -> None: metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', help='Deprecated. Use positional argument instead.', ) +@click.option( + '--max-model-len', + '--max_model_len', + 'max_model_len', + default=None, + help='Maximum sequence length for the model. If not specified, we will use the default value from the model config.', +) @start_decorator(serve_grpc=False) def start_command( model_id: str, @@ -409,6 +416,7 @@ def start_command( return_process: bool, dtype: LiteralDtype, deprecated_model_id: str | None, + max_model_len: int | None, **attrs: t.Any, ) -> LLMConfig | subprocess.Popen[bytes]: """Start any LLM as a REST server. @@ -466,6 +474,7 @@ def start_command( quantize=quantize, serialisation=serialisation, dtype=dtype, + max_model_len=max_model_len, ) backend_warning(llm.__llm_backend__) @@ -521,9 +530,14 @@ def start_command( help='Deprecated. Use positional argument instead.', ) @start_decorator(serve_grpc=True) -@click.pass_context +@click.option( + '--max-model-len', + '--max_model_len', + 'max_model_len', + default=None, + help='Maximum sequence length for the model. If not specified, we will use the default value from the model config.', +) def start_grpc_command( - ctx: click.Context, model_id: str, server_timeout: int, model_version: str | None, @@ -539,6 +553,7 @@ def start_grpc_command( adapter_id: str | None, return_process: bool, deprecated_model_id: str | None, + max_model_len: int | None, **attrs: t.Any, ) -> LLMConfig | subprocess.Popen[bytes]: """Start any LLM as a gRPC server. @@ -596,6 +611,7 @@ def start_grpc_command( quantize=quantize, serialisation=serialisation, dtype=dtype, + max_model_len=max_model_len, trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'), ) backend_warning(llm.__llm_backend__) diff --git a/openllm-python/tests/configuration_test.py b/openllm-python/tests/configuration_test.py index 747d6f61c..851a53220 100644 --- a/openllm-python/tests/configuration_test.py +++ b/openllm-python/tests/configuration_test.py @@ -7,7 +7,6 @@ import attr import pytest -import transformers from hypothesis import assume, given, strategies as st import openllm @@ -165,13 +164,6 @@ def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPat assert sent.field1 == 20.0 -@given(model_settings()) -@pytest.mark.parametrize(('return_dict', 'typ'), [(True, dict), (False, transformers.GenerationConfig)]) -def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings): - cl_ = make_llm_config('ConversionLLM', gen_settings) - assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ) - - @given(model_settings()) def test_click_conversion(gen_settings: ModelSettings): # currently our conversion omit Union type.
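
Usage sketch for the refactored configuration path, assuming an environment running this branch of OpenLLM; the model id, the `backend=` keyword and the sampling overrides below are illustrative, while `max_model_len`, `model_construct_env` and `inference_options` come from this patch:

    import openllm

    # max_model_len is the new knob added here; for the vLLM backend it is forwarded
    # to vllm.AsyncEngineArgs(max_model_len=...).
    llm = openllm.LLM('facebook/opt-1.3b', backend='vllm', max_model_len=2048)  # model id is illustrative

    # model_construct_env() merges per-request overrides into the LLMConfig; inference_options()
    # then dispatches on the backend name to the nested vllm/ctranslate/pt/hf helpers and returns
    # (config, backend-native options) -- for vLLM the second value is a vllm.SamplingParams.
    config, sampling_params = llm.config.model_construct_env(temperature=0.2, top_p=0.9).inference_options(llm)

    # to_generation_config()/to_sampling_config() still work, but now emit a
    # DeprecationWarning and delegate to inference_options().

The same limit is exposed on the CLI as `--max-model-len`/`--max_model_len` for both `openllm start` and `openllm start-grpc`.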
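
Runner construction in _runners.py now goes through a small backend registry; below is a self-contained sketch of that dispatch, with stub classes standing in for the real Runnables:

    _registry = {}

    def registry(cls=None, *, alias=None):
        # Key each Runnable by its class name minus the 'Runnable' suffix, lower-cased,
        # unless an explicit alias is given (PyTorchRunnable registers as 'pt').
        def decorator(_cls):
            _registry[_cls.__name__[:-8].lower() if alias is None else alias] = _cls
            return _cls
        if cls is None:
            return decorator
        return decorator(cls)

    @registry
    class vLLMRunnable: ...        # registered under 'vllm'

    @registry
    class CTranslateRunnable: ...  # registered under 'ctranslate'

    @registry(alias='pt')
    class PyTorchRunnable: ...     # registered under 'pt'

    assert _registry == {'vllm': vLLMRunnable, 'ctranslate': CTranslateRunnable, 'pt': PyTorchRunnable}

runner(llm) then resolves _registry[llm.__llm_backend__] and wraps it in a dynamically created bentoml.Runner subclass via types.new_class, replacing the removed _RunnerFactory helper in _llm.py.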