refactor: update runner helpers and add max_model_len (#712)
* chore(runner): cleanup unnecessary checks for runnable backend

Signed-off-by: Aaron <[email protected]>

* chore: saving llm reference to runner

Signed-off-by: Aaron <[email protected]>

* chore: correct inject item

Signed-off-by: Aaron <[email protected]>

* chore: update support for max_seq_len

Signed-off-by: Aaron <[email protected]>

* fix: correct max_model_len

Signed-off-by: Aaron <[email protected]>

* chore: update and warn about backward compatibility

Signed-off-by: Aaron <[email protected]>

* chore: remove unused sets

Signed-off-by: Aaron <[email protected]>

---------

Signed-off-by: Aaron <[email protected]>
aarnphm authored Nov 21, 2023
1 parent 8fc5f1f commit ad4f388
Showing 6 changed files with 222 additions and 206 deletions.
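
The headline change is a new max_model_len option on the LLM constructor, wired through to the runner (see the openllm-python/src/openllm/_llm.py diff below). A minimal usage sketch, assuming the public openllm.LLM constructor; the model id and value are illustrative only:

# Sketch only: model id and value are illustrative, not taken from this commit.
import openllm

# max_model_len caps the context length handed to the backend (e.g. vLLM);
# leaving it as None (the default) keeps the backend's own limit.
llm = openllm.LLM('facebook/opt-1.3b', backend='vllm', max_model_len=2048)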
76 changes: 66 additions & 10 deletions openllm-core/src/openllm_core/_configuration.py
@@ -7,6 +7,7 @@
import sys
import types
import typing as t
import warnings

import attr
import click_option_group as cog
@@ -27,9 +28,11 @@
LiteralBackend,
LiteralSerialisation,
LiteralString,
M,
NotRequired,
Required,
Self,
T,
overload,
)
from .exceptions import ForbiddenAttributeError, MissingDependencyError
@@ -42,10 +45,16 @@
import vllm
from attrs import AttrsInstance

import openllm
from openllm.protocol.cohere import CohereChatRequest, CohereGenerateRequest
from openllm.protocol.openai import ChatCompletionRequest, CompletionRequest
else:
vllm = LazyLoader('vllm', globals(), 'vllm')
vllm = LazyLoader(
'vllm',
globals(),
'vllm',
exc_msg='vLLM is not installed. Make sure to install it with `pip install "openllm[vllm]"`',
)
transformers = LazyLoader('transformers', globals(), 'transformers')
peft = LazyLoader('peft', globals(), 'peft')

@@ -268,6 +277,9 @@ class SamplingParams(ReprMixin):
top_k: int
top_p: float
logprobs: int
repetition_penalty: float
length_penalty: float
early_stopping: bool

def __init__(self, *, _internal: bool = False, **attrs: t.Any):
if not _internal:
@@ -352,7 +364,7 @@ def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.
),
)

_SamplingParamsT = t.TypeVar('_SamplingParams', bound=SamplingParams)
_SamplingParamsT = t.TypeVar('_SamplingParamsT', bound=SamplingParams)

# cached it here to save one lookup per assignment
_object_getattribute = object.__getattribute__
@@ -841,8 +853,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
- Automatic environment conversion: Each fields will automatically be provisioned with an environment
variable, make it easy to work with ahead-of-time or during serving time
- Familiar API: It is compatible with cattrs as well as providing a few Pydantic-2 like API,
i.e: ``model_construct_env``, ``to_generation_config``, ``to_click_options``
- Familiar API: It is compatible with cattrs as well as providing a few Pydantic-2 like API, i.e: ``model_construct_env``
- Automatic CLI generation: It can identify each fields and convert it to compatible Click options.
This means developers can use any of the LLMConfig to create CLI with compatible-Python
CLI library (click, typer, ...)
@@ -1447,18 +1458,63 @@ def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, DictStrAny]:
def make_fine_tune_config(self, adapter_type: AdapterType, **attrs: t.Any) -> FineTuneConfig:
return FineTuneConfig(adapter_type=adapter_type, llm_config_class=self.__class__).with_config(**attrs)

@overload
def to_generation_config(self, return_as_dict: t.Literal[False] = False) -> transformers.GenerationConfig: ...
def inference_options(self, llm: openllm.LLM[M, T], backend: str | None = None) -> tuple[Self, t.Any]:
backend = backend if backend is not None else llm.__llm_backend__
cls = getattr(self, backend, None)
if cls is None:
raise ValueError(f'Unknown backend {backend}')
try:
return self, cls.build(self)
except AttributeError:
raise RuntimeError(f'Unknown backend {llm.__llm_backend__}') from None

class vllm:
@staticmethod
def build(config: LLMConfig) -> vllm.SamplingParams:
if config['temperature'] <= 1e-5:
top_p = 1.0
else:
top_p = config['top_p']
_object_setattr(config.sampling_config, 'top_p', top_p)
return config.sampling_config.build()

class ctranslate:
@staticmethod
def build(config: LLMConfig) -> dict[str, t.Any]:
return dict(
max_length=config['max_new_tokens'],
min_length=config['min_length'],
sampling_topk=config['top_k'],
sampling_topp=config['top_p'],
sampling_temperature=config['temperature'],
return_log_prob=config['logprobs'] > 0,
repetition_penalty=config['repetition_penalty'],
no_repeat_ngram_size=config['no_repeat_ngram_size'],
end_token=config['stop'],
)

@overload
def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny: ...
class pt:
@staticmethod
def build(config: LLMConfig) -> LLMConfig:
return config

class hf:
@staticmethod
def build(config: LLMConfig) -> transformers.GenerationConfig:
return transformers.GenerationConfig(**converter.unstructure(config.generation_config))

def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny:
config = transformers.GenerationConfig(**converter.unstructure(self.generation_config))
warnings.warn(
"'to_generation_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3
)
_, config = self.inference_options(None, 'hf')
return config.to_dict() if return_as_dict else config

def to_sampling_config(self) -> vllm.SamplingParams:
return self.sampling_config.build()
warnings.warn(
"'to_sampling_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3
)
return self.inference_options(None, 'vllm')[-1]

@overload
def with_request(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: ...
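
The hunk above replaces the to_generation_config / to_sampling_config overloads with a single backend-aware inference_options entry point; the old helpers now emit a DeprecationWarning and delegate to it. A rough sketch of the new call, assuming an already constructed openllm.LLM instance named llm (only names visible in the diff are used):

# Sketch based on the diff above; llm is assumed to be a constructed openllm.LLM.
config = llm.config.model_construct_env(temperature=0.7, top_p=0.95)

# Resolves the backend from llm.__llm_backend__ unless overridden, and returns
# (config, backend-specific options) -- e.g. vllm.SamplingParams for the vLLM backend.
config, options = config.inference_options(llm)

# Forcing the 'hf' backend yields a transformers.GenerationConfig instead.
_, hf_options = config.inference_options(llm, backend='hf')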
93 changes: 10 additions & 83 deletions openllm-python/src/openllm/_llm.py
@@ -2,7 +2,6 @@
import functools
import logging
import os
import types
import typing as t

import attr
@@ -11,7 +10,6 @@

import bentoml
import openllm
from bentoml._internal.models.model import ModelSignature
from openllm_core._schemas import GenerationOutput
from openllm_core._typing_compat import (
AdapterMap,
@@ -34,7 +32,6 @@
ReprMixin,
apply,
codegen,
converter,
first_not_none,
flatten_attrs,
gen_random_uuid,
@@ -131,6 +128,7 @@ class LLM(t.Generic[M, T], ReprMixin):
_adapter_map: AdapterMap | None
_serialisation: LiteralSerialisation
_local: bool
_max_model_len: int | None
_prompt_template: PromptTemplate | None
_system_message: str | None

@@ -163,6 +161,7 @@ def __init__(
embedded=False,
dtype='auto',
low_cpu_mem_usage=True,
max_model_len=None,
_eager=True,
**attrs,
):
@@ -192,6 +191,7 @@ def __init__(
adapter_map=_resolve_peft_config_type(adapter_map) if adapter_map is not None else None,
serialisation=serialisation,
local=_local,
max_model_len=max_model_len,
prompt_template=PromptTemplate(prompt_template) if isinstance(prompt_template, str) else prompt_template,
system_message=system_message,
LLM__model_attrs=model_attrs,
@@ -327,7 +327,8 @@ def tokenizer(self):
return self.__llm_tokenizer__
@property
def runner(self):
if self.__llm_runner__ is None:self.__llm_runner__=_RunnerFactory(self)
from ._runners import runner
if self.__llm_runner__ is None:self.__llm_runner__=runner(self)
return self.__llm_runner__
def prepare(self,adapter_type='lora',use_gradient_checking=True,**attrs):
if self.__llm_backend__!='pt':raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
@@ -421,6 +422,8 @@ def config(self):
async def generate(
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
) -> GenerationOutput:
if adapter_name is not None and self.__llm_backend__ != 'pt':
raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')
config = self.config.model_construct_env(**attrs)
texts, token_ids = [[]] * config['n'], [[]] * config['n']
final_result = None
@@ -446,6 +449,9 @@ async def generate_iterator(
) -> t.AsyncGenerator[GenerationOutput, None]:
from bentoml._internal.runner.runner_handle import DummyRunnerHandle

if adapter_name is not None and self.__llm_backend__ != 'pt':
raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')

if isinstance(self.runner._runner_handle, DummyRunnerHandle):
if os.getenv('BENTO_PATH') is not None:
raise RuntimeError('Runner client failed to set up correctly.')
@@ -487,82 +493,3 @@ async def generate_iterator(
previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
yield generated.with_options(outputs=delta_outputs)


def _RunnerFactory(
llm, /, models=None, max_batch_size=None, max_latency_ms=None, scheduling_strategy=None, *, backend=None
):
from ._runners import runnable

if scheduling_strategy is None:
from ._strategies import CascadingResourceStrategy

scheduling_strategy = CascadingResourceStrategy

backend = first_not_none(getenv('backend', default=backend), default=llm.__llm_backend__)

models = models if models is not None else []
try:
models.append(llm.bentomodel)
except bentoml.exceptions.NotFound as err:
raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err

if llm._prompt_template:
prompt_template = llm._prompt_template.to_string()
elif hasattr(llm.config, 'default_prompt_template'):
prompt_template = llm.config.default_prompt_template
else:
prompt_template = None
if llm._system_message:
system_message = llm._system_message
elif hasattr(llm.config, 'default_system_message'):
system_message = llm.config.default_system_message
else:
system_message = None
return types.new_class(
llm.config.__class__.__name__[:-6] + 'Runner',
(bentoml.Runner,),
exec_body=lambda ns: ns.update(
{
'llm_type': llm.llm_type,
'identifying_params': llm.identifying_params,
'llm_tag': llm.tag,
'llm': llm,
'config': llm.config,
'backend': backend,
'__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}',
'__module__': llm.__module__,
'__repr__': ReprMixin.__repr__,
'__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}),
'__repr_args__': lambda _: (
(
'runner_methods',
{
method.name: {
'batchable': method.config.batchable,
'batch_dim': method.config.batch_dim if method.config.batchable else None,
}
for method in _.runner_methods
},
),
('config', llm.config.model_dump(flatten=True)),
('llm_type', llm.llm_type),
('backend', backend),
('llm_tag', llm.tag),
),
'has_adapters': llm.has_adapters,
'prompt_template': prompt_template,
'system_message': system_message,
}
),
)(
runnable(llm, backend),
name=llm.runner_name,
embedded=False,
models=models,
max_batch_size=max_batch_size,
max_latency_ms=max_latency_ms,
scheduling_strategy=scheduling_strategy,
runnable_init_params={'llm': llm},
method_configs=converter.unstructure({'generate_iterator': ModelSignature(batchable=False)}),
)
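
With _RunnerFactory removed above, the runner is now built by openllm._runners.runner the first time llm.runner is accessed, and adapters are rejected up front in generate / generate_iterator for non-PyTorch backends. A hedged sketch of the resulting wiring; the BentoML Service name and setup are assumptions, not part of this diff:

# Sketch only: the Service wiring below is an assumption, not taken from this commit.
import bentoml
import openllm

llm = openllm.LLM('facebook/opt-1.3b', backend='pt', max_model_len=2048)
runner = llm.runner  # lazily constructed via openllm._runners.runner(llm)

svc = bentoml.Service('llm-service', runners=[runner])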