From ab68a1b9f90a5724230ad1ba77e6b2b80a873343 Mon Sep 17 00:00:00 2001
From: cpcdoy
Date: Wed, 15 Jan 2025 15:12:49 +0100
Subject: [PATCH 1/2] fix: Lighteval communication with TGI

---
 src/lighteval/models/endpoints/tgi_model.py | 20 ++++++++++++++++++--
 src/lighteval/models/model_loader.py        |  4 +---
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/src/lighteval/models/endpoints/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py
index f0bb712b..903241ef 100644
--- a/src/lighteval/models/endpoints/tgi_model.py
+++ b/src/lighteval/models/endpoints/tgi_model.py
@@ -101,7 +101,7 @@ def __init__(self, config: TGIModelConfig) -> None:
 
         model_name = str(self.model_info["model_id"])
         model_sha = self.model_info["model_sha"]
-        model_precision = self.model_info["model_dtype"]
+        model_precision = self.model_info.get("model_dtype")
         self.model_info = ModelInfo(
             model_name=model_name,
             model_sha=model_sha,
@@ -127,7 +127,23 @@ def _async_process_request(
             grammar=grammar,
         )
 
-        generated_text = self.client.generate(prompt=context, generation_config=generation_config)
+        generated_text = self.client.generate(
+            prompt=context,
+            do_sample=generation_config.do_sample or False,
+            max_new_tokens=generation_config.max_new_tokens,
+            best_of=generation_config.best_of,
+            repetition_penalty=generation_config.repetition_penalty,
+            return_full_text=generation_config.return_full_text or False,
+            seed=generation_config.seed,
+            stop_sequences=generation_config.stop,
+            temperature=generation_config.temperature,
+            top_k=generation_config.top_k,
+            top_p=generation_config.top_p,
+            truncate=generation_config.truncate,
+            typical_p=generation_config.typical_p,
+            watermark=generation_config.watermark or False,
+            decoder_input_details=generation_config.decoder_input_details,
+        )
 
         return generated_text
 
diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py
index 68835fda..d24e8404 100644
--- a/src/lighteval/models/model_loader.py
+++ b/src/lighteval/models/model_loader.py
@@ -108,9 +108,7 @@ def load_model_with_tgi(config: TGIModelConfig):
         raise ImportError(NO_TGI_ERROR_MSG)
 
     logger.info(f"Load model from inference server: {config.inference_server_address}")
-    model = ModelClient(
-        address=config.inference_server_address, auth_token=config.inference_server_auth, model_id=config.model_id
-    )
+    model = ModelClient(config=config)
 
     return model
 

From f442a29aaf9e74a0b42cb6c2ad7d5d3e25521a07 Mon Sep 17 00:00:00 2001
From: cpcdoy
Date: Wed, 15 Jan 2025 18:00:27 +0100
Subject: [PATCH 2/2] fix: JSON grammar constrained generation

---
 pyproject.toml                                   | 2 +-
 src/lighteval/models/endpoints/endpoint_model.py | 2 ++
 src/lighteval/models/endpoints/tgi_model.py      | 1 +
 src/lighteval/models/model_input.py              | 2 ++
 4 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index f60e610e..a4729105 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -83,7 +83,7 @@ dependencies = [
 
 [project.optional-dependencies]
 litellm = ["litellm", "diskcache"]
-tgi = ["text-generation==0.6.0"]
+tgi = ["text-generation==0.7.0"]
 optimum = ["optimum==1.12.0"]
 quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
 adapters = ["peft==0.3.0"]

diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py
index 37bb9754..ad8d7968 100644
--- a/src/lighteval/models/endpoints/endpoint_model.py
+++ b/src/lighteval/models/endpoints/endpoint_model.py
@@ -478,6 +478,7 @@ async def _async_process_batch_logprob(
                     context=request.context if rolling else request.context + request.choice,
                     stop_tokens=[],
                     max_tokens=1,
+                    grammar=request.generation_grammar,
                 )
                 for request in requests
             ]
@@ -491,6 +492,7 @@ def _process_batch_logprob(
                 context=request.context if rolling else request.context + request.choice,
                 stop_tokens=[],
                 max_tokens=1,
+                grammar=request.generation_grammar,
             )
             for request in requests
         ]

diff --git a/src/lighteval/models/endpoints/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py
index 903241ef..fc1083aa 100644
--- a/src/lighteval/models/endpoints/tgi_model.py
+++ b/src/lighteval/models/endpoints/tgi_model.py
@@ -143,6 +143,7 @@ def _async_process_request(
             typical_p=generation_config.typical_p,
             watermark=generation_config.watermark or False,
             decoder_input_details=generation_config.decoder_input_details,
+            grammar=generation_config.grammar,
         )
 
         return generated_text

diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py
index 04e35be1..c552a7ae 100644
--- a/src/lighteval/models/model_input.py
+++ b/src/lighteval/models/model_input.py
@@ -42,6 +42,7 @@ class GenerationParameters:
     min_p: Optional[float] = None  # vllm, transformers
     top_p: Optional[int] = None  # vllm, transformers, tgi
     truncate_prompt: Optional[bool] = None  # vllm, tgi
+    grammar: Optional[str] = None  # tgi
 
     @classmethod
     def from_dict(cls, config_dict: dict):
@@ -117,5 +118,6 @@ def to_tgi_ie_dict(self) -> dict:
             "top_k": self.top_k,
             "top_p": self.top_p,
             "truncate": self.truncate_prompt,
+            "grammar": self.grammar,
         }
         return {k: v for k, v in args.items() if v is not None}
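Illustrative usage (not part of the patch): a minimal sketch of how the new grammar field would flow through GenerationParameters.to_tgi_ie_dict() once both patches are applied. It assumes GenerationParameters can be constructed with keyword arguments, and the JSON-schema grammar string is only a hypothetical example; the exact grammar payload a given TGI server accepts may differ.

    from lighteval.models.model_input import GenerationParameters

    # Hypothetical grammar payload; the patch types `grammar` as Optional[str]
    # and forwards it as-is, so whatever string the TGI server expects goes here.
    params = GenerationParameters(
        grammar='{"type": "json", "value": {"properties": {"answer": {"type": "string"}}}}',
    )

    # to_tgi_ie_dict() now carries the grammar alongside the other sampling
    # parameters; entries left at None are filtered out before the request is sent.
    print(params.to_tgi_ie_dict())
    # e.g. {'grammar': '{"type": "json", "value": {...}}'}

Since the second patch also passes generation_config.grammar straight through to client.generate(...), the same string would reach the text-generation client when evaluating against a TGI endpoint.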