Skip to content

Commit

Permalink
Merge branch 'add_litellm_inference' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
JoelNiklaus committed Dec 17, 2024
2 parents 296c3de + 5115403 commit dd12702
Showing 1 changed file with 41 additions and 17 deletions.
58 changes: 41 additions & 17 deletions src/lighteval/models/litellm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dataclasses import dataclass
from typing import Optional

from litellm.utils import ModelResponse
from tqdm import tqdm
from transformers import AutoTokenizer

Expand Down Expand Up @@ -92,28 +93,41 @@ def __init__(self, config, env_config) -> None:
litellm.drop_params = True
litellm.verbose = True

def _prepare_stop_sequence(self, stop_sequence):
"""Prepare and validate stop sequence."""
if self.provider == "anthropic":
# Filter out whitespace-only stop sequences
if stop_sequence:
stop_sequence = [s for s in stop_sequence if s.strip()]
if not stop_sequence: # If empty after filtering
stop_sequence = ["\n"]
return stop_sequence

def _prepare_max_new_tokens(self, max_new_tokens):
"""Calculate completion tokens based on max_new_tokens."""
if not max_new_tokens or max_new_tokens <= 0:
return None

if "o1" in self.model:
# We need to allow more tokens to include reasoning tokens
max_new_tokens = min(max_new_tokens * 10, 32000)
return max_new_tokens

def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt):
"""Make API call with retries."""
response = ModelResponse()
for attempt in range(self.API_MAX_RETRY):
try:
if self.provider == "anthropic":
# Filter out whitespace-only stop sequences
if stop_sequence:
stop_sequence = [s for s in stop_sequence if s.strip()]
if not stop_sequence: # If empty after filtering
stop_sequence = ["\n"]

# Handle max_new_tokens
completion_tokens = None
if max_new_tokens and max_new_tokens > 0:
completion_tokens = max_new_tokens
if "o1" in self.model:
# We need to allow more tokens to include reasoning tokens
completion_tokens = min(max_new_tokens * 10, 32000)
stop_sequence = self._prepare_stop_sequence(stop_sequence)
max_new_tokens = self._prepare_max_new_tokens(max_new_tokens)

# Remove system prompt from the main prompt if it's duplicated
prompt = prompt.replace(system_prompt, "")

response = litellm.completion(
model=self.model,
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
max_completion_tokens=completion_tokens,
max_completion_tokens=max_new_tokens,
logprobs=return_logits if self.provider == "openai" else None,
stop=stop_sequence,
base_url=self.base_url,
Expand All @@ -123,14 +137,23 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
caching=True,
)
return response
except litellm.BadRequestError as e:
if "message" in e.__dict__:
error_string = (
"The response was filtered due to the prompt triggering Microsoft's content management policy"
)
if error_string in e.__dict__["message"]:
logger.warning(f"{error_string}. Returning empty response.")
return ModelResponse()
except Exception as e:
wait_time = min(64, self.API_RETRY_SLEEP * (2**attempt)) # Exponential backoff with max 64s
logger.warning(
f"Error in API call: {e}, waiting {wait_time} seconds before retry {attempt + 1}/{self.API_MAX_RETRY}"
)
time.sleep(wait_time)

logger.error(f"API call failed after {self.API_MAX_RETRY} attempts, skipping entry.")
logger.error(f"API call failed after {self.API_MAX_RETRY} attempts, returning empty response.")
return ModelResponse()

def __call_api_parallel(
self,
Expand Down Expand Up @@ -220,7 +243,8 @@ def greedy_until(
result: list[str] = [choice.message.content for choice in response.choices]

cur_response = GenerativeResponse(
result=result,
# In empty responses, the model should return an empty string instead of None
result=result if result[0] else [""],
logits=None,
generated_tokens=[],
input_tokens=[],
Expand Down

0 comments on commit dd12702

Please sign in to comment.