diff --git a/src/lighteval/metrics/llm_as_judge.py b/src/lighteval/metrics/llm_as_judge.py
index 81e1d7d30..e98a64aa4 100644
--- a/src/lighteval/metrics/llm_as_judge.py
+++ b/src/lighteval/metrics/llm_as_judge.py
@@ -28,7 +28,6 @@
 
 from tqdm import tqdm
 
-from lighteval.models.model_output import ModelResponse
 from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
 
 
@@ -196,30 +195,20 @@ def __call_litellm(self, prompts):
         def __call_api(prompt):
             for _ in range(self.API_MAX_RETRY):
                 try:
-                    kwargs = {
-                        "model": self.model,
-                        "messages": prompt,
-                        "response_format": {"type": "text"},
-                        "max_tokens": 512,
-                        "n": 1,
-                        "caching": True,
-                    }
-                    response = litellm.completion(**kwargs)
+                    response = litellm.completion(
+                        model=self.model,
+                        messages=prompt,
+                        response_format={"type": "text"},
+                        max_tokens=512,
+                        n=1,
+                        caching=True,
+                    )
                     text = response.choices[0].message.content
-                    if not text or response.failed:
-                        kwargs["caching"] = False
-                        response = litellm.completion(**kwargs)
-                        text = response.choices[0].message.content
-                        if not text or response.failed:
-                            # Just return an error response if the second attempt fails too
-                            return ModelResponse(
-                                text="Failed to get response from the API.", model=self.model, failed=True
-                            )
                     return text
                 except Exception as e:
                     logger.warning(f"{type(e), e}")
                     time.sleep(self.API_RETRY_SLEEP)
-            return ModelResponse(text="Failed to get response from the API.", model=self.model, failed=True)
+            raise Exception("Failed to get response from the API")
 
         results = []
         with ThreadPoolExecutor(100) as executor:
diff --git a/src/lighteval/models/model_output.py b/src/lighteval/models/model_output.py
index b485371ca..7d0ba4818 100644
--- a/src/lighteval/models/model_output.py
+++ b/src/lighteval/models/model_output.py
@@ -33,7 +33,6 @@ class ModelResponse:
     generated_tokens: list[int] = field(default_factory=list)  # model generations
     truncated_tokens_count: Optional[int] = 0  # How many tokens truncated
     padded_tokens_count: Optional[int] = 0  # How many tokens of padding
-    failed: bool = False
 
     def get_result_for_eval(self):
         raise NotImplementedError()
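
For reviewers, a minimal standalone sketch of the error-handling contract this patch adopts: `__call_api` no longer encodes failure as a sentinel `ModelResponse(failed=True)`; after `API_MAX_RETRY` attempts it raises, and the caller is expected to handle (or propagate) that exception. The names `call_with_retries`, `make_request`, and `flaky` below are illustrative only, not part of the repo.

```python
import logging
import time

logger = logging.getLogger(__name__)

API_MAX_RETRY = 3     # stands in for the class-level constant used in __call_api
API_RETRY_SLEEP = 1.0


def call_with_retries(make_request):
    """Retry a flaky callable, then raise instead of returning a sentinel.

    Mirrors the patched __call_api: exhausting the retries now surfaces as
    an exception rather than a ModelResponse with failed=True.
    """
    for _ in range(API_MAX_RETRY):
        try:
            return make_request()
        except Exception as e:
            logger.warning(f"{type(e), e}")
            time.sleep(API_RETRY_SLEEP)
    raise Exception("Failed to get response from the API")


if __name__ == "__main__":
    # Fails once, then succeeds: call_with_retries recovers on attempt two.
    attempts = iter([RuntimeError("transient"), "ok"])

    def flaky():
        item = next(attempts)
        if isinstance(item, Exception):
            raise item
        return item

    print(call_with_retries(flaky))  # -> "ok"
```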