Humanity's last exam #520

Open
wants to merge 8 commits into base: main
5 changes: 5 additions & 0 deletions examples/model_configs/serverless_model_with_openai.yaml
@@ -0,0 +1,5 @@
+model:
+  model_name: "deepseek-ai/DeepSeek-R1"  # alternatives: "meta-llama/Llama-3.1-8B-Instruct", "Qwen/Qwen2.5-14B", "Qwen/Qwen2.5-7B"
+api:
+  base_url: "https://huggingface.co/api/inference-proxy/together"
+  api_key: "hf_"
38 changes: 32 additions & 6 deletions src/lighteval/metrics/llm_as_judge.py
@@ -26,16 +26,21 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, Literal
 
+from pydantic import BaseModel
 from tqdm import tqdm
 
 from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
+from lighteval.utils.utils import as_list
 
 
 logging.getLogger("openai").setLevel(logging.ERROR)
 logging.getLogger("httpx").setLevel(logging.ERROR)
 logger = logging.getLogger(__name__)
 
 
+DEFAULT_FORMAT = {"type": "text"}
+
+
 class JudgeLM:
     """
     A class representing a judge for evaluating answers using either the OpenAI or Transformers library.
@@ -76,6 +81,7 @@ def __init__(
         judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"],
         url: str | None = None,
         api_key: str | None = None,
+        response_format: BaseModel | None = None,
     ):
         self.model = model
         self.template = templates
@@ -91,6 +97,8 @@ def __init__(
         self.api_key = api_key
         self.backend = judge_backend
 
+        self.response_format = response_format if response_format is not None else DEFAULT_FORMAT
+
     def __lazy_load_client(self):
         match self.backend:
             # Whether we use openai or TGI models, we go through the openai API
@@ -244,16 +252,34 @@ def __call_api_parallel(self, prompts):
     def __call_api(self, prompt):
         for _ in range(self.API_MAX_RETRY):
             try:
-                response = self.client.chat.completions.create(
+                # Base model: parse the completion into the pydantic response_format
+                response = self.client.beta.chat.completions.parse(
                     model=self.model,
-                    messages=prompt,
-                    response_format={"type": "text"},
-                    max_tokens=512,
+                    messages=as_list(prompt),
+                    response_format=self.response_format,
+                    max_tokens=4096,
+                    temperature=0.0,
+                    n=1,
                 )
-                text = response.choices[0].message.content
-                return text
+                answer = response.choices[0].message.parsed
+                return answer
+            except TypeError:
+                try:
+                    # Finetune: fall back to plain chat completions when parsing is unsupported
+                    response = self.client.chat.completions.create(
+                        model=self.model,
+                        messages=as_list(prompt),
+                        response_format=self.response_format,
+                        max_tokens=512,
+                        n=1,
+                    )
+                    text = response.choices[0].message.content
+                    return text
+                except Exception as e:
+                    logger.warning(f"{type(e), e}")
+                    time.sleep(self.API_RETRY_SLEEP)
             except Exception as e:
                 logger.warning(f"{type(e), e}")
                 time.sleep(self.API_RETRY_SLEEP)
 
         raise Exception("Failed to get response from the API")
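
For context, a minimal usage sketch of the structured-output path this adds, assuming an OpenAI-compatible endpoint and an invented JudgeVerdict schema (not part of this PR). beta.chat.completions.parse validates the completion against the pydantic model and exposes it on message.parsed, which is what the first branch above returns; the TypeError branch degrades to plain chat completions.

from openai import OpenAI
from pydantic import BaseModel

class JudgeVerdict(BaseModel):
    # Illustrative judge schema; any pydantic model works as a response_format
    reasoning: str
    correct: bool

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

response = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "Is '4' a correct answer to 2+2? Judge it."}],
    response_format=JudgeVerdict,
)
verdict = response.choices[0].message.parsed  # a JudgeVerdict instance, or None on refusal
print(verdict.correct if verdict else "refused")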
5 changes: 4 additions & 1 deletion src/lighteval/metrics/metrics_sample.py
@@ -35,6 +35,7 @@
 from nltk.tokenize import word_tokenize
 from nltk.tokenize.treebank import TreebankWordTokenizer
 from nltk.translate.bleu_score import sentence_bleu
+from pydantic import BaseModel
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 from lighteval.metrics.imports.bert_scorer import BERTScorer
@@ -852,7 +853,7 @@ def edit_similarity(self, s1, s2):
 
 
 class JudgeLLM:
-    available_models_openai = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4"]
+    available_models_openai = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-4o-2024-08-06"]

     def __init__(
         self,
@@ -861,6 +862,7 @@ def __init__(
         process_judge_response: Callable,
         judge_backend: Literal["litellm", "openai", "transformers", "vllm", "tgi"],
         short_judge_name: str | None = None,
+        response_format: BaseModel | None = None,
     ) -> None:
         match judge_backend:
             case "openai":
@@ -893,6 +895,7 @@ def __init__(
             api_key=api_key,
             url=url,
             judge_backend=judge_backend,
+            response_format=response_format,
         )
 
     def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
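
Downstream, JudgeLLM simply threads response_format through to JudgeLM, so a task supplies a pydantic schema together with a process_judge_response callable that consumes either the parsed object (structured path) or raw text (fallback path). A hedged sketch of a compatible pair; ExtractedAnswer and the scoring rule are invented for illustration:

from pydantic import BaseModel

class ExtractedAnswer(BaseModel):
    # Illustrative schema; not defined in this PR
    extracted_final_answer: str
    correct: bool

def process_judge_response(answer) -> float:
    # Structured path hands over message.parsed (an ExtractedAnswer);
    # the fallback path hands over the raw completion text instead.
    if isinstance(answer, ExtractedAnswer):
        return 1.0 if answer.correct else 0.0
    return 1.0 if "correct: yes" in str(answer).lower() else 0.0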
22 changes: 16 additions & 6 deletions src/lighteval/models/endpoints/openai_model.py
@@ -28,6 +28,7 @@
 from typing import Optional
 
 from tqdm import tqdm
+from transformers import AutoTokenizer
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel
@@ -64,6 +65,8 @@
 class OpenAIModelConfig:
     model: str
     generation_parameters: GenerationParameters = None
+    base_url: str = "https://api.openai.com/v1"
+    api_key: str | None = os.environ.get("OPENAI_API_KEY", None)
 
     def __post_init__(self):
         if not self.generation_parameters:
@@ -74,17 +77,19 @@ def from_path(cls, path: str) -> "OpenAIModelConfig":
         import yaml
 
         with open(path, "r") as f:
-            config = yaml.safe_load(f)["model"]
+            loaded_file = yaml.safe_load(f)
+            config = loaded_file["model"]
+            api = loaded_file.get("api", {})
         generation_parameters = GenerationParameters.from_dict(config)
-        return cls(model=config["model_name"], generation_parameters=generation_parameters)
+        return cls(model=config["model_name"], generation_parameters=generation_parameters, **api)
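
Tying this to the YAML at the top of the diff: the optional api: block is unpacked into the dataclass via **api, so base_url and api_key override the OpenAI defaults, while a file without that block leaves them untouched. A round-trip sketch, assuming the example file above is on disk:

from lighteval.models.endpoints.openai_model import OpenAIModelConfig

config = OpenAIModelConfig.from_path(
    "examples/model_configs/serverless_model_with_openai.yaml"
)
print(config.model)     # deepseek-ai/DeepSeek-R1 (from model.model_name)
print(config.base_url)  # https://huggingface.co/api/inference-proxy/together (from api.base_url)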


 class OpenAIClient(LightevalModel):
     _DEFAULT_MAX_LENGTH: int = 4096
 
     def __init__(self, config: OpenAIModelConfig, env_config) -> None:
-        api_key = os.environ["OPENAI_API_KEY"]
-        self.client = OpenAI(api_key=api_key)
+        self.client = OpenAI(api_key=config.api_key, base_url=config.base_url)
+        self.config = config
         self.generation_parameters = config.generation_parameters
         self.sampling_params = self.generation_parameters.to_vllm_openai_dict()
 
@@ -99,22 +104,27 @@ def __init__(self, config: OpenAIModelConfig, env_config) -> None:
         self.API_RETRY_MULTIPLIER = 2
         self.CONCURENT_CALLS = 100
         self.model = config.model
-        self._tokenizer = tiktoken.encoding_for_model(self.model)
+        try:
+            self._tokenizer = tiktoken.encoding_for_model(self.model)
+        except KeyError:  # model id unknown to tiktoken (e.g. a Hugging Face repo)
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model)
         self.pairwise_tokenization = False

     def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
         for _ in range(self.API_MAX_RETRY):
             try:
+                response_format = {"response_format": {"type": "text"}} if "openai" in self.config.base_url else {}
                 response = self.client.chat.completions.create(
                     model=self.model,
                     messages=[{"role": "user", "content": prompt}],
-                    response_format={"type": "text"},
                     max_tokens=max_new_tokens if max_new_tokens > 0 else None,
                     logprobs=return_logits,
                     logit_bias=logit_bias,
                     n=num_samples,
                     **self.sampling_params,
+                    **response_format,
                 )
+                self.API_RETRY_SLEEP = 3
                 return response
             except Exception as e:
                 logger.warning(f"{type(e), e}")
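
One subtlety in __call_api: response_format={"type": "text"} is only sent when the configured base_url points at the official OpenAI API, because OpenAI-compatible proxies (such as the inference proxy in the example config) may reject the field. A standalone restatement of the predicate with example URLs:

def response_format_kwargs(base_url: str) -> dict:
    # Mirrors the gating above: include response_format only for openai.com-style URLs
    return {"response_format": {"type": "text"}} if "openai" in base_url else {}

assert response_format_kwargs("https://api.openai.com/v1") == {"response_format": {"type": "text"}}
assert response_format_kwargs("https://huggingface.co/api/inference-proxy/together") == {}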
3 changes: 2 additions & 1 deletion src/lighteval/tasks/extended/__init__.py
@@ -24,12 +24,13 @@
 
 
 if can_load_extended_tasks():
+    import lighteval.tasks.extended.hle.main as hle
     import lighteval.tasks.extended.ifeval.main as ifeval
     import lighteval.tasks.extended.mix_eval.main as mix_eval
     import lighteval.tasks.extended.mt_bench.main as mt_bench
     import lighteval.tasks.extended.tiny_benchmarks.main as tiny_benchmarks
 
-    AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval]
+    AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval, hle]
 
 else:
     AVAILABLE_EXTENDED_TASKS_MODULES = []
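
Registration is purely additive: the new hle module only needs to follow the same contract as its sibling modules. Assuming each extended task module exposes a TASKS_TABLE list of task configs (suggested by the registry pattern, but not shown in this diff), aggregation amounts to this sketch:

def collect_extended_tasks(modules: list) -> list:
    # Hypothetical aggregation mirroring what the registry is assumed to do
    # with AVAILABLE_EXTENDED_TASKS_MODULES; TASKS_TABLE is an assumed contract.
    tasks = []
    for module in modules:
        tasks.extend(module.TASKS_TABLE)
    return tasks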