From 5115403153a1f3a50cbf068bb07a14571160cc80 Mon Sep 17 00:00:00 2001
From: Joel Niklaus
Date: Tue, 17 Dec 2024 14:56:05 +0100
Subject: [PATCH] Made litellm inference robust to content management errors.

---
 src/lighteval/models/litellm_model.py | 58 +++++++++++++++++++--------
 1 file changed, 41 insertions(+), 17 deletions(-)

diff --git a/src/lighteval/models/litellm_model.py b/src/lighteval/models/litellm_model.py
index 35f30ee7d..4ac19229f 100644
--- a/src/lighteval/models/litellm_model.py
+++ b/src/lighteval/models/litellm_model.py
@@ -27,6 +27,7 @@
 from dataclasses import dataclass
 from typing import Optional
 
+from litellm.utils import ModelResponse
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
@@ -92,28 +93,41 @@ def __init__(self, config, env_config) -> None:
         litellm.drop_params = True
         litellm.verbose = True
 
+    def _prepare_stop_sequence(self, stop_sequence):
+        """Prepare and validate stop sequence."""
+        if self.provider == "anthropic":
+            # Filter out whitespace-only stop sequences
+            if stop_sequence:
+                stop_sequence = [s for s in stop_sequence if s.strip()]
+                if not stop_sequence:  # If empty after filtering
+                    stop_sequence = ["\n"]
+        return stop_sequence
+
+    def _prepare_max_new_tokens(self, max_new_tokens):
+        """Calculate completion tokens based on max_new_tokens."""
+        if not max_new_tokens or max_new_tokens <= 0:
+            return None
+
+        if "o1" in self.model:
+            # We need to allow more tokens to include reasoning tokens
+            max_new_tokens = min(max_new_tokens * 10, 32000)
+        return max_new_tokens
+
     def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt):
+        """Make API call with retries."""
+        response = ModelResponse()
         for attempt in range(self.API_MAX_RETRY):
             try:
-                if self.provider == "anthropic":
-                    # Filter out whitespace-only stop sequences
-                    if stop_sequence:
-                        stop_sequence = [s for s in stop_sequence if s.strip()]
-                        if not stop_sequence:  # If empty after filtering
-                            stop_sequence = ["\n"]
-
-                # Handle max_new_tokens
-                completion_tokens = None
-                if max_new_tokens and max_new_tokens > 0:
-                    completion_tokens = max_new_tokens
-                    if "o1" in self.model:
-                        # We need to allow more tokens to include reasoning tokens
-                        completion_tokens = min(max_new_tokens * 10, 32000)
+                stop_sequence = self._prepare_stop_sequence(stop_sequence)
+                max_new_tokens = self._prepare_max_new_tokens(max_new_tokens)
+
+                # Remove system prompt from the main prompt if it's duplicated
+                prompt = prompt.replace(system_prompt, "")
 
                 response = litellm.completion(
                     model=self.model,
                     messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
-                    max_completion_tokens=completion_tokens,
+                    max_completion_tokens=max_new_tokens,
                     logprobs=return_logits if self.provider == "openai" else None,
                     stop=stop_sequence,
                     base_url=self.base_url,
@@ -123,6 +137,14 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
                     caching=True,
                 )
                 return response
+            except litellm.BadRequestError as e:
+                if "message" in e.__dict__:
+                    error_string = (
+                        "The response was filtered due to the prompt triggering Microsoft's content management policy"
+                    )
+                    if error_string in e.__dict__["message"]:
+                        logger.warning(f"{error_string}. Returning empty response.")
+                        return ModelResponse()
             except Exception as e:
                 wait_time = min(64, self.API_RETRY_SLEEP * (2**attempt))  # Exponential backoff with max 64s
                 logger.warning(
@@ -130,7 +152,8 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
                 )
                 time.sleep(wait_time)
 
-        logger.error(f"API call failed after {self.API_MAX_RETRY} attempts, skipping entry.")
+        logger.error(f"API call failed after {self.API_MAX_RETRY} attempts, returning empty response.")
+        return ModelResponse()
 
     def __call_api_parallel(
         self,
@@ -220,7 +243,8 @@ def greedy_until(
             result: list[str] = [choice.message.content for choice in response.choices]
 
             cur_response = GenerativeResponse(
-                result=result,
+                # In empty responses, the model should return an empty string instead of None
+                result=result if result[0] else [""],
                 logits=None,
                 generated_tokens=[],
                 input_tokens=[],
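
Reviewer note (not part of the patch): a minimal sketch of how the
empty-response fallback is expected to flow through greedy_until. The
[None] input is an assumption about what the fallback ModelResponse
yields once its choices are unpacked; the guard itself is the one added
in the last hunk.

    # On a content-filtered prompt, __call_api now returns ModelResponse(),
    # whose default choice carries no content, so the comprehension in
    # greedy_until produces a list with a falsy first element.
    result = [None]
    # The added guard normalises that to [""] so downstream metrics always
    # receive a string rather than None.
    result = result if result[0] else [""]
    assert result == [""]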