diff --git a/community_tasks/arabic_evals.py b/community_tasks/arabic_evals.py
index 86ab69e28..a68abbe6a 100644
--- a/community_tasks/arabic_evals.py
+++ b/community_tasks/arabic_evals.py
@@ -28,8 +28,11 @@
"""
import random
import re
+from typing import Any, Dict, List, Optional, Union
-from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.llm_as_judge import JudgeLM
+from lighteval.metrics.metrics import Metric, MetricCategory, Metrics
+from lighteval.metrics.utils.metric_utils import MetricUseCase
from lighteval.tasks.default_prompts import LETTER_INDICES
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
@@ -832,6 +835,215 @@ def __init__(
]
+class JudgeMetricWrapper(Metric):
+ """Wrapper class for LLM-based judge metric implementation."""
+
+ def __init__(self, judge: JudgeLM):
+ """
+ Initializes the judge metric wrapper.
+
+ Args:
+ judge (JudgeLM): The LLM judge instance to use for evaluation.
+ """
+ self.judge = judge
+ self.metric_name = "llm_as_judge"
+ self.category = MetricCategory.LLM_AS_JUDGE
+ self.corpus_level_fn = self.aggregate_scores
+ self.sample_level_fn = self._sample_level_fn
+        self.higher_is_better = True
+ self.use_case = MetricUseCase.NONE
+
+    def compute(self, responses: list[str], formatted_docs: list[Doc], **kwargs) -> list[dict[str, float]]:
+ """
+ Computes evaluation scores using the judge's evaluate_answer method.
+
+ Args:
+ responses (list[str]): The predicted answers
+ formatted_docs (list[Doc]): Documents containing questions and gold answers
+
+ Returns:
+            list[dict[str, float]]: Per-sample dictionaries mapping the metric name to its score
+ """
+ results = []
+ for i, doc in enumerate(formatted_docs):
+ question = doc.query
+ gold = doc.choices[doc.gold_index] if doc.gold_index is not None else None
+ answer = responses[i][0].result[0]
+
+ score, _, _ = self.judge.evaluate_answer(question=question, answer=answer, options=None, gold=gold)
+ results.append({self.metric_name: score})
+
+ return results
+
+ def aggregate_scores(self, scores: list[dict]) -> float:
+ return sum(scores) / len(scores) if scores else 0.0
+
+ def _sample_level_fn(self):
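+        # No standalone sample-level scoring: per-sample scores are produced in compute() above.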
+ return None
+
+
+def parse_candidates(candidates: Union[List[str], str]) -> List[str]:
+ """
+ Parses and validates candidate answers from either list or string format.
+
+ Args:
+ candidates: Either a list of candidate answers or a newline-separated string
+
+ Returns:
+ List[str]: List of validated candidate answers
+
+ Raises:
+ ValueError: If candidates cannot be parsed or are empty
+ """
+ try:
+ if isinstance(candidates, list):
+ parsed_candidates = [str(c).strip() for c in candidates if c]
+ else:
+ parsed_candidates = [c.strip() for c in str(candidates).split("\n") if c.strip()]
+
+ if not parsed_candidates:
+ raise ValueError("No valid candidates found after parsing")
+
+ return parsed_candidates
+ except Exception as e:
+ raise ValueError(f"Failed to parse candidates: {str(e)}")
+
+
+def qa_prompt_arabic(line: Dict[str, Any], task_name: str = None) -> Doc:
+ """
+ Formats the prompt for Arabic question answering with candidates.
+
+ Args:
+ line: Dictionary containing question and candidate information
+ task_name: Optional name for the task
+
+ Returns:
+ Doc: Formatted document for evaluation
+
+ Raises:
+ ValueError: If required fields are missing or invalid
+ """
+ try:
+ # Validates and extracts the question
+ if not isinstance(line.get("question"), str):
+ raise ValueError("Question must be a string")
+ question = line["question"]
+
+ # Processes candidate answers
+ candidates = parse_candidates(line["candidates"])
+
+ # Validates gold answer
+ if "gold_answer" not in line:
+ raise ValueError("Gold answer is required")
+ gold_answer = str(line["gold_answer"])
+
+ # Constructs the prompt
+ instruction = "بناءً على السياقات المقترحة التالية، اجب عن السؤال التالي"
+ query = f"{instruction}\n\nالسؤال:\n{question}\n\nالسياقات المقترحة:\n{', '.join(candidates)}\n"
+
+ return Doc(
+ task_name=task_name or "alrage",
+ query=query,
+ instruction=instruction,
+ choices=[gold_answer], # Gold answer is used as the only valid choice
+ gold_index=0, # Index of the correct answer in choices
+ )
+ except Exception as e:
+ raise ValueError(f"Failed to create QA prompt: {str(e)}")
+
+
+def judge_template(question: str, answer: str, gold: str, options: Optional[List[str]] = None) -> List[Dict[str, str]]:
+ """
+ Template for the Arabic judge prompt.
+
+ System prompt translation:
+ You are a neutral expert evaluator. Your tasks are:
+ 1. Evaluate the answer's accuracy compared to the correct answer
+ 2. Verify that the answer is supported by the provided context
+ 3. Evaluate the quality and comprehensiveness of the answer
+ Rate the answer on a scale from 0 to 10.
+
+ Args:
+ question: The question being evaluated
+ answer: The provided answer
+ gold: The correct answer
+ options: Optional list of answer choices
+
+ Returns:
+ List[Dict[str, str]]: Formatted messages for the judge
+ """
+ messages = [
+ {
+ "role": "system",
+ "content": """أنت مقيّم محايد خبير باللغة العربية. يجب عليك:
+1. تقييم دقة الإجابة مقارنة بالإجابة الصحيحة
+2. التحقق من أن الإجابة مدعومة بالسياق المقدم
+3. تقييم جودة وشمولية الإجابة
+
+مهم جداً: يجب أن يكون ردك رقماً فقط من 0 إلى 10. لا تضف أي نص أو تفسير.""",
+ },
+ {
+ "role": "user",
+ "content": f"""السؤال: {question}
+
+الإجابة المقدمة: {answer}
+
+الإجابة الصحيحة: {gold}
+
+أعط تقييماً من 0 إلى 10:
+0-2: إجابة خاطئة تماماً
+3-4: إجابة جزئية مع أخطاء
+5-6: إجابة متوسطة
+7-8: إجابة جيدة
+9-10: إجابة ممتازة
+
+اكتب رقماً فقط من 0 إلى 10 بدون أي نص إضافي:""",
+ },
+ ]
+ return messages
+
+
+def process_judge_response(response) -> float:
+ """Process the judge's response to extract the score"""
+ # If response is a list, extract the content from the user role
+ if isinstance(response, list):
+ response_content = " ".join(item["content"] for item in response if item["role"] == "user")
+ else:
+ response_content = response # If it's not a list, use it directly
+
+ try:
+ # Extract the score from the response content
+ score = float(next(num for num in response_content.split() if num.replace(".", "", 1).isdigit()))
+ return min(max(score / 10.0, 0.0), 1.0)
+ except (StopIteration, ValueError):
+ return 0.0
+
+
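+# Judge setup: Qwen2.5-72B-Instruct rates each answer from 0 to 10; process_judge_response rescales the score to [0, 1].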
+judge = JudgeLM(
+ model="Qwen/Qwen2.5-72B-Instruct",
+ templates=judge_template,
+ process_judge_response=process_judge_response,
+ judge_backend="vllm",
+)
+
+wrapped_judge = JudgeMetricWrapper(judge)
+
+# Task configuration
+alrage_qa_task = LightevalTaskConfig(
+ name="alrage_qa",
+ prompt_function=qa_prompt_arabic,
+ suite=["community"],
+ hf_repo="OALL/ALRAGE",
+ hf_subset=None,
+ hf_avail_splits=["train"],
+ evaluation_splits=["train"],
+ metric=[wrapped_judge],
+ trust_dataset=True,
+ generation_size=200,
+ stop_sequence=[],
+ version=0,
+)
+
TASKS_TABLE = (
ARABIC_MMLU_TASKS
+ ARABIC_MMLU_HT_TASKS
@@ -852,4 +1064,5 @@ def __init__(
+ [hellaswag_okapi_ar_task]
+ [toxigen_ar_task]
+ [sciq_ar_task]
+ + [alrage_qa_task]
)
diff --git a/community_tasks/french_evals.py b/community_tasks/french_evals.py
new file mode 100644
index 000000000..7fda8dc7a
--- /dev/null
+++ b/community_tasks/french_evals.py
@@ -0,0 +1,121 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# ruff: noqa: F405, F403, F401
+"""
+Custom evaluation tasks for lighteval.
+
+This file generally creates just a TASKS_TABLE and TASKS_GROUPS which are then imported by LightEval.
+
+This module implements tasks for French-specific datasets.
+See: https://huggingface.co/fr-gouv-coordination-ia
+"""
+
+import random
+
+import numpy as np
+from aenum import extend_enum
+
+import lighteval.tasks.extended.ifeval.instructions_registry as instructions_registry
+from lighteval.metrics.metrics import Metrics, SampleLevelMetric
+from lighteval.metrics.utils.metric_utils import (
+ MetricCategory,
+ MetricUseCase,
+ SampleLevelMetricGrouping,
+)
+from lighteval.tasks.default_prompts import LETTER_INDICES
+from lighteval.tasks.extended.ifeval.main import ifeval_metrics
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.requests import Doc
+
+
+# Ifeval-fr prompt function
+def prompt_ifeval_fr(line, task_name: str = None):
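+    """Formats an IFEval-fr sample: the raw prompt becomes the query, and the instruction ids/kwargs are stored in `specific` for the ifeval metrics."""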
+ return Doc(
+ task_name=task_name,
+ query=line["prompt"],
+ choices=[""],
+ gold_index=0,
+ instruction="",
+ specific={"instructions_id_list": line["instruction_id_list"], "kwargs": line["kwargs"]},
+ )
+
+
+# gpqa-fr prompt function
+def prompt_gpqa_fr(line, task_name: str = None):
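+    """Builds a multiple-choice prompt, inserting the correct answer at a random position among the three incorrect ones."""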
+ gold_index = random.randint(0, 3)
+ choices = [line["Réponse incorrecte 1"], line["Réponse incorrecte 2"], line["Réponse incorrecte 3"]]
+ choices.insert(gold_index, line["Réponse correcte"])
+
+ instruction = "Choisissez la réponse correcte aux questions suivantes.\n\n"
+
+ query = f"Question: {line['Question']}\n"
+ query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, choices)])
+ query += "Answer: "
+ return Doc(
+ task_name=task_name,
+ query=f"{instruction}{query}",
+ choices=LETTER_INDICES[: len(choices)],
+ gold_index=gold_index,
+ instruction=instruction,
+ )
+
+
+# IFEVal-fr task
+
+
+ifeval_fr_task = LightevalTaskConfig(
+ name="ifeval-fr",
+ prompt_function=prompt_ifeval_fr, # must be defined in the file or imported from src/lighteval/tasks/tasks_prompt_formatting.py
+ suite=["community"],
+ hf_repo="fr-gouv-coordination-ia/IFEval-fr",
+ hf_subset="default",
+ metric=[ifeval_metrics],
+ hf_avail_splits=["train"],
+ evaluation_splits=["train"],
+ few_shots_split="train",
+ few_shots_select="random_sampling",
+ generation_size=1280,
+ stop_sequence=[], # no stop sequence, will use eot token
+ version="0.1", # select your metric in Metrics
+)
+
+# GPQA-fr task
+gpqa_fr_task = LightevalTaskConfig(
+ name="gpqa-fr",
+ suite=["community"],
+ prompt_function=prompt_gpqa_fr,
+ hf_repo="fr-gouv-coordination-ia/gpqa-fr",
+ hf_subset="default",
+ hf_avail_splits=["train"],
+ evaluation_splits=["train"],
+ few_shots_split=None,
+ few_shots_select="random_sampling",
+ generation_size=1,
+ metric=[Metrics.loglikelihood_acc],
+ stop_sequence=["\n"],
+ trust_dataset=True,
+ version=0,
+)
+
+# STORE YOUR EVALS
+TASKS_TABLE = [ifeval_fr_task, gpqa_fr_task]
diff --git a/examples/tasks/OALL_v2_tasks.txt b/examples/tasks/OALL_v2_tasks.txt
index fc1b4f7e9..176b662d7 100644
--- a/examples/tasks/OALL_v2_tasks.txt
+++ b/examples/tasks/OALL_v2_tasks.txt
@@ -115,3 +115,4 @@ community|arabic_mmlu_ht:sociology|0|0
community|arabic_mmlu_ht:us_foreign_policy|0|0
community|arabic_mmlu_ht:virology|0|0
community|arabic_mmlu_ht:world_religions|0|0
+community|alrage_qa|0|0
diff --git a/pyproject.toml b/pyproject.toml
index f407895a0..126d66244 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,7 +57,7 @@ dependencies = [
"transformers>=4.38.0",
"accelerate",
"huggingface_hub>=0.23.0",
- "torch>=2.0,<2.5",
+ "torch>=2.0,<3.0",
"GitPython>=3.1.41", # for logging
"datasets>=2.14.0",
"numpy<2", # pinned to avoid incompatibilities
diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py
index 6cad9189f..af5784bc2 100644
--- a/src/lighteval/logging/evaluation_tracker.py
+++ b/src/lighteval/logging/evaluation_tracker.py
@@ -149,6 +149,28 @@ def __init__(
self.public = public
+ @property
+ def results(self):
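+        """Returns the aggregated results, versions, task configs and summaries as a JSON-serializable dict."""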
+ config_general = asdict(self.general_config_logger)
+ # We remove the config from logging, which contains context/accelerator objects
+ config_general.pop("config")
+ results = {
+ "config_general": config_general,
+ "results": self.metrics_logger.metric_aggregated,
+ "versions": self.versions_logger.versions,
+ "config_tasks": self.task_config_logger.tasks_configs,
+ "summary_tasks": self.details_logger.compiled_details,
+ "summary_general": asdict(self.details_logger.compiled_details_over_all_tasks),
+ }
+ return results
+
+ @property
+ def details(self):
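+        """Returns per-task lists of sample details, serialized to dicts and keyed by task name."""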
+ return {
+ task_name: [asdict(detail) for detail in task_details]
+ for task_name, task_details in self.details_logger.details.items()
+ }
+
def save(self) -> None:
"""Saves the experiment information and results to files, and to the hub if requested."""
logger.info("Saving experiment tracker")
@@ -281,6 +303,31 @@ def push_to_hub(
self.recreate_metadata_card(repo_id)
+ def push_results_to_hub(self, repo_id: str, path_in_repo: str, private: bool | None = None):
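+        """Creates the results dataset repo on the Hub if needed and uploads the aggregated results as a single JSON file."""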
+ repo_id = repo_id if "/" in repo_id else f"{self.hub_results_org}/{repo_id}"
+ private = private if private is not None else not self.public
+ self.api.create_repo(repo_id, private=private, repo_type="dataset", exist_ok=True)
+ results_json = json.dumps(self.results, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False)
+ self.api.upload_file(
+ repo_id=repo_id,
+ path_or_fileobj=results_json.encode(),
+ path_in_repo=path_in_repo,
+ repo_type="dataset",
+ )
+
+ def push_details_to_hub(self, repo_id: str, path_in_repo: str, private: bool | None = None):
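+        """Creates the details dataset repo on the Hub if needed and uploads each task's sample details as a JSON-lines file."""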
+ repo_id = repo_id if "/" in repo_id else f"{self.hub_results_org}/{repo_id}"
+ private = private if private is not None else not self.public
+ self.api.create_repo(repo_id, private=private, repo_type="dataset", exist_ok=True)
+        for task_name, details in self.details.items():
+ details_json = "\n".join([json.dumps(detail) for detail in details])
+ self.api.upload_file(
+ repo_id=repo_id,
+ path_or_fileobj=details_json.encode(),
+ path_in_repo=path_in_repo.format(task_name=task_name),
+ repo_type="dataset",
+ )
+
def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901
"""Fully updates the details repository metadata card for the currently evaluated model
diff --git a/src/lighteval/metrics/llm_as_judge.py b/src/lighteval/metrics/llm_as_judge.py
index 81e1d7d30..23beda76f 100644
--- a/src/lighteval/metrics/llm_as_judge.py
+++ b/src/lighteval/metrics/llm_as_judge.py
@@ -28,7 +28,6 @@
from tqdm import tqdm
-from lighteval.models.model_output import ModelResponse
from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
@@ -194,6 +193,7 @@ def __call_litellm(self, prompts):
import litellm
def __call_api(prompt):
+ error_message = "ERROR: Failed to get response from the API."
for _ in range(self.API_MAX_RETRY):
try:
kwargs = {
@@ -206,20 +206,19 @@ def __call_api(prompt):
}
response = litellm.completion(**kwargs)
text = response.choices[0].message.content
- if not text or response.failed:
+ if not text or text == error_message:
kwargs["caching"] = False
response = litellm.completion(**kwargs)
text = response.choices[0].message.content
- if not text or response.failed:
+ if not text or text == error_message:
# Just return an error response if the second attempt fails too
- return ModelResponse(
- text="Failed to get response from the API.", model=self.model, failed=True
- )
+ logger.error(f"Failed to get response from the API for prompt: {prompt}")
+ return error_message
return text
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)
- return ModelResponse(text="Failed to get response from the API.", model=self.model, failed=True)
+ return error_message
results = []
with ThreadPoolExecutor(100) as executor:
diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py
index 03b1b2c5e..3c2de418f 100644
--- a/src/lighteval/metrics/metrics_corpus.py
+++ b/src/lighteval/metrics/metrics_corpus.py
@@ -26,6 +26,7 @@
"""
import logging
import math
+from typing import Literal
import numpy as np
import sacrebleu
@@ -89,33 +90,38 @@ def compute(self, items: list[LogprobCorpusMetricInput]):
class CorpusLevelTranslationMetric:
- def __init__(self, metric_type: str):
+ def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
"""Stores the relevant parameters for a corpus level translation metric.
Args:
metric_type (str): Can be any of bleu, chrf, or ter depending on the metric to use.
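+            lang (Literal["zh", "ja", "ko", ""]): Target language code, used to enable language-specific tokenization and handling for Asian languages. Defaults to "" (standard behavior).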
"""
- if metric_type == "bleu":
- self.metric = sacrebleu.corpus_bleu
- elif metric_type == "chrf":
- self.metric = sacrebleu.corpus_chrf
- elif metric_type == "ter":
- self.metric = sacrebleu.corpus_ter
+ self.metric_type = metric_type
+ self.lang = lang
+
+ def get_metric(self):
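+        """Instantiates the sacrebleu metric object for the configured metric type and target language."""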
+ if self.metric_type == "bleu":
+ return sacrebleu.BLEU(trg_lang=self.lang)
+ elif self.metric_type == "chrf":
+ return sacrebleu.CHRF()
+ elif self.metric_type == "ter":
+ return sacrebleu.TER(asian_support=True if self.lang != "" else False)
else:
- raise ValueError(f"Unknown corpus level translation metric type : {metric_type}")
+ raise ValueError(f"Unknown corpus level translation metric type : {self.metric_type}")
def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
"""Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
+ metric = self.get_metric()
golds = [i.golds for i in items]
preds = []
for i in items:
pred = as_list(i.preds)
if len(pred) > 1:
logger.info(
- f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{self.metric.__name__})."
+                    f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{type(metric).__name__})."
)
preds.append(pred[0])
- return float(self.metric(hypotheses=preds, references=golds).score)
+ return float(metric.corpus_score(hypotheses=preds, references=golds).score)
class CorpusLevelPerplexityMetric:
diff --git a/src/lighteval/models/litellm_model.py b/src/lighteval/models/litellm_model.py
index 9e29f569d..195221aa5 100644
--- a/src/lighteval/models/litellm_model.py
+++ b/src/lighteval/models/litellm_model.py
@@ -91,7 +91,7 @@ def __init__(self, config, env_config) -> None:
self._tokenizer = encode
self.pairwise_tokenization = False
litellm.drop_params = True
- litellm.verbose = True
+ litellm.set_verbose = False
def _prepare_stop_sequence(self, stop_sequence):
"""Prepare and validate stop sequence."""
@@ -130,13 +130,16 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
"messages": prompt,
"max_completion_tokens": max_new_tokens,
"logprobs": return_logits if self.provider == "openai" else None,
- "stop": stop_sequence,
"base_url": self.base_url,
"n": num_samples,
- "temperature": self.TEMPERATURE,
- "top_p": self.TOP_P,
"caching": True,
}
+ if "o1" in self.model:
+ logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
+ else:
+ kwargs["temperature"] = self.TEMPERATURE
+ kwargs["top_p"] = self.TOP_P
+ kwargs["stop"] = stop_sequence
response = litellm.completion(**kwargs)
diff --git a/src/lighteval/models/model_output.py b/src/lighteval/models/model_output.py
index b485371ca..7d0ba4818 100644
--- a/src/lighteval/models/model_output.py
+++ b/src/lighteval/models/model_output.py
@@ -33,7 +33,6 @@ class ModelResponse:
generated_tokens: list[int] = field(default_factory=list) # model generations
truncated_tokens_count: Optional[int] = 0 # How many tokens truncated
padded_tokens_count: Optional[int] = 0 # How many tokens of padding
- failed: bool = False
def get_result_for_eval(self):
raise NotImplementedError()
diff --git a/src/lighteval/tasks/default_prompts.py b/src/lighteval/tasks/default_prompts.py
index c5395281c..66c3d53b4 100644
--- a/src/lighteval/tasks/default_prompts.py
+++ b/src/lighteval/tasks/default_prompts.py
@@ -120,7 +120,7 @@ def asdiv(line, task_name: str = None):
def babi_qa(line, task_name: str = None): # HELM
def process_path(path: str) -> str:
- """Turn a path string (task 19) from the original format 's,w' to a verbal model-friendly format 'south west'"""
+ """Turn a path string (task 19) from the original format 's,w' into a verbal model-friendly format 'south west'"""
steps = path.split(",")
directions = {"s": "south", "n": "north", "e": "east", "w": "west"}
path = " ".join([directions[step] for step in steps])
@@ -281,7 +281,7 @@ def bbh_logical_deduction_three_objects(line, task_name: str = None):
def bbh_movie_recommendation(line, task_name: str = None):
if line["target"] == "Monsters, Inc": # this line is not correctly formatted
logger.warning(
- "One sample removed from task bbh:movie_recommentation because its line is incorrectly formatted."
+ "One sample removed from task bbh:movie_recommendation because its line is incorrectly formatted."
)
return []
instruction = "Recommend movies similar to the given list of movies.\n\n"
@@ -500,7 +500,7 @@ def civil_comments(line, task_name: str = None):
def cnn_dm(line, task_name: str = None):
return Doc(
task_name=task_name,
- query=f"###\nArticle:{line['article']}\n\nSummarize the above article in 3 sentence.\n",
+ query=f"###\nArticle:{line['article']}\n\nSummarize the above article in 3 sentences.\n",
choices=[str(line["summary"])],
gold_index=0,
specific={"text": line["article"]},
@@ -730,7 +730,7 @@ def gpqa(line, task_name: str = None):
def gsm8k(line, task_name: str = None):
- # Has special analysis in metric for number decomposiition
+ # Has special analysis in metric for number decomposition
return Doc(
task_name=task_name,
query=f"Question: {line['question']}\nAnswer:",
@@ -2076,7 +2076,7 @@ def rte(line, task_name: str = None):
return Doc(
task_name=task_name,
query=f"{line['sentence1']}\nQuestion: {line['sentence2']} True or False?\nAnswer:",
- choices=[" True", " False"], # 0 = entailement, 1 = not entailment
+ choices=[" True", " False"], # 0 = entailment, 1 = not entailment
gold_index=int(line["label"]),
# "metric": "choices_loglikelihood",
)
diff --git a/src/lighteval/tasks/extended/mix_eval/main.py b/src/lighteval/tasks/extended/mix_eval/main.py
index 8684e910c..eaa58f2a5 100644
--- a/src/lighteval/tasks/extended/mix_eval/main.py
+++ b/src/lighteval/tasks/extended/mix_eval/main.py
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
+import logging
import re
import numpy as np
@@ -37,6 +38,9 @@
from lighteval.tasks.requests import Doc
+logger = logging.getLogger(__name__)
+
+
def mixeval_freeform_prompt(line, task_name: str = ""):
prompt = construct_prompt_freeform(line)
return Doc(
@@ -71,19 +75,30 @@ def mixeval_multichoice_prompt(line, task_name: str = ""):
def process_judge_response(x):
- search = re.search(r"\s(\d)\s", x)
- return int(search.group(1)) if search else 0
+ try:
+ search = re.search(r"\s(\d)\s", x)
+ return int(search.group(1)) if search else 0
+ except Exception as e:
+ logger.warning(f"Error processing judge response for flow: {e}")
+ return 0
def process_judge_response_multichoice_gpt(x):
- search = re.search(r"\[\[([01])\]\]", x)
- return int(search.group(1)) if search else 0
+ try:
+ search = re.search(r"\[\[([01])\]\]", x)
+ return int(search.group(1)) if search else 0
+ except Exception as e:
+ logger.warning(f"Error processing judge response for multichoice GPT: {e}")
+ return 0
def process_judge_response_freeform_gpt(x):
- search = re.search(r"\[\[(\d.\d)\]\]", x)
- answer = float(search.group(1) if search else 0)
- return answer
+ try:
+ search = re.search(r"\[\[(\d.\d)\]\]", x)
+ return float(search.group(1)) if search else 0
+ except Exception as e:
+ logger.warning(f"Error processing judge response for freeform GPT: {e}")
+ return 0
llm_judge_mixeval_multichoice_flow_judge = SampleLevelMetricGrouping(
diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index 09886e4db..c187f653f 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -621,7 +621,7 @@ def create_requests_from_tasks( # noqa: C901
n_samples = min(max_samples, len(task_docs)) if max_samples else len(task_docs)
evaluation_tracker.task_config_logger.log_num_docs(task_name, len(task_docs), n_samples)
- # logs out the diferent versions of the tasks for every few shot
+ # logs out the different versions of the tasks for every few shot
for num_fewshot, _ in fewshot_dict[task_name]:
cur_task_name = f"{task_name}|{num_fewshot}"
evaluation_tracker.versions_logger.log(cur_task_name, task.version)
@@ -633,7 +633,7 @@ def create_requests_from_tasks( # noqa: C901
prompt_manager = PromptManager(lm=lm, task=task)
seeds = prompt_manager.few_shot_sampler.get_fewshot_seeds(num_fewshot_seeds)
- # We can do several round of fewshots sampling to get some variance informations
+ # We can do several round of fewshots sampling to get some variance information
for seed in seeds:
for doc_id in range(n_samples):
doc_id_seed = f"{doc_id}_{seed}" # if we do several rounds of few shot sampling we have several seeds
diff --git a/src/lighteval/tasks/prompt_manager.py b/src/lighteval/tasks/prompt_manager.py
index af55c3184..6d066c921 100644
--- a/src/lighteval/tasks/prompt_manager.py
+++ b/src/lighteval/tasks/prompt_manager.py
@@ -132,7 +132,7 @@ def _multi_turn_contexts(self, doc: Doc, use_chat_template: bool, system_prompt:
Multi turn tasks need use chat templating.
Args:
- doc (Doc): Formated document.
+ doc (Doc): Formatted document.
use_chat_template (bool): wether or not to use chat template. Will fail if false.
system_prompt (Optional[str]): The system prompt to use
tokenizer (PreTrainedTokenizer): The tokenizer used for the chat template
diff --git a/src/lighteval/tasks/registry.py b/src/lighteval/tasks/registry.py
index 834e81706..174a98d33 100644
--- a/src/lighteval/tasks/registry.py
+++ b/src/lighteval/tasks/registry.py
@@ -166,7 +166,7 @@ def _task_superset_dict(self):
"lighteval|mmlu" -> ["lighteval|mmlu:abstract_algebra", "lighteval|mmlu:college_biology", ...]
}
"""
- # Note: sorted before groupby is imporant as the python implementation of groupby does not
+ # Note: sorted before groupby is important as the python implementation of groupby does not
# behave like sql groupby. For more info see the docs of itertools.groupby
superset_dict = {k: list(v) for k, v in groupby(sorted(self.task_registry.keys()), lambda x: x.split(":")[0])}
# Only consider supersets with more than one task
diff --git a/src/lighteval/tasks/templates/continuation.py b/src/lighteval/tasks/templates/continuation.py
index 6435fc8f2..c9cd5d1bc 100644
--- a/src/lighteval/tasks/templates/continuation.py
+++ b/src/lighteval/tasks/templates/continuation.py
@@ -112,7 +112,7 @@ def get_continuation_prompt_function(
C. Continuation 3
Answer: A/B/C
- This template is very similar to the `Multiple Choice` template, except that it only takes context/continuations as input and don't use the anchor labels (Question/Answer)
+ This template is very similar to the `Multiple Choice` template, except that it only takes context/continuations as input and doesn't use the anchor labels (Question/Answer)
Args:
language (Language): The language of the Continuation task.
diff --git a/src/lighteval/tasks/templates/copa.py b/src/lighteval/tasks/templates/copa.py
index 2129332f8..a4d82c4de 100644
--- a/src/lighteval/tasks/templates/copa.py
+++ b/src/lighteval/tasks/templates/copa.py
@@ -86,17 +86,17 @@ def get_copa_prompt_function(
Format:
*CF*
- Context Premise thefore/cause | (Continuation 1, Continuation 2, Continuation 3)
+ Context Premise therefore/cause | (Continuation 1, Continuation 2, Continuation 3)
*Hybrid*
- Context Premise thefore/cause
+ Context Premise therefore/cause
A. Continuation 1
B. Continuation 2
C. Continuation 3
Answer: | Continuation 1/Continuation 2/Continuation 3
*MCF*
- Context Premise thefore/cause
+ Context Premise therefore/cause
A. Continuation 1
B. Continuation 2
C. Continuation 3
diff --git a/src/lighteval/tasks/templates/hellaswag.py b/src/lighteval/tasks/templates/hellaswag.py
index 43a3061b6..f5c4ba3db 100644
--- a/src/lighteval/tasks/templates/hellaswag.py
+++ b/src/lighteval/tasks/templates/hellaswag.py
@@ -70,7 +70,7 @@ def get_hellaswag_prompt_function(
Create a templated prompt function for a Hellaswag task.
Format:
- Context Premise thefore/cause | (Continuation 1, Continuation 2, Continuation 3)
+ Context Premise therefore/cause | (Continuation 1, Continuation 2, Continuation 3)
Args:
language (Language): The language of the Hellaswag task.
@@ -126,7 +126,7 @@ def hellaswag_prompt(
if ctx_b:
ctx_a = join_ctxs(ctx_a, ctx_b)
- # Removoal of the [header] can happen and we need the first letter to be capital afterwards
+ # Removal of the [header] can happen and we need the first letter to be capital afterwards
full_context = HELLASWAG_QUERY.format(activity_label=activity_label, ctx=ctx_a)
choices = [
hellaswag_preprocess(
diff --git a/src/lighteval/tasks/templates/nli.py b/src/lighteval/tasks/templates/nli.py
index 842460306..e8809e17b 100644
--- a/src/lighteval/tasks/templates/nli.py
+++ b/src/lighteval/tasks/templates/nli.py
@@ -228,7 +228,7 @@ def prompt_fn(line: dict, task_name: str):
if input_data is None:
return None
- # Template based on dicussion here: https://github.com/EleutherAI/lm-evaluation-harness/issues/450
+ # Template based on discussion here: https://github.com/EleutherAI/lm-evaluation-harness/issues/450
labels = [capitalize(get_relation_label(label, translation_literals)) for label in relations]
premise, hypothesis, gold_idx = input_data["premise"], input_data["hypothesis"], input_data["gold_idx"]
@@ -236,15 +236,15 @@ def prompt_fn(line: dict, task_name: str):
hypothesis = input_data["hypothesis"]
if isinstance(formulation, HybridFormulation):
# If we have the neither option move it to the end to be consistent with standard NLI evaluation
- rearanged_labales = labels
+ rearranged_labels = labels
if "neutral" in relations:
neutral_idx = relations.index("neutral")
- rearanged_labales = labels[:neutral_idx] + labels[neutral_idx + 1 :] + [labels[neutral_idx]]
+ rearranged_labels = labels[:neutral_idx] + labels[neutral_idx + 1 :] + [labels[neutral_idx]]
- choices_str = f"{translation_literals.comma}{translation_literals.word_space}".join(rearanged_labales[:-1])
- hypothesis = f"{hypothesis.rstrip(PUNCT)}{translation_literals.sentence_space}{choices_str}{translation_literals.word_space}{translation_literals.or_word}{translation_literals.word_space}{rearanged_labales[-1]}{translation_literals.question_mark}"
+ choices_str = f"{translation_literals.comma}{translation_literals.word_space}".join(rearranged_labels[:-1])
+ hypothesis = f"{hypothesis.rstrip(PUNCT)}{translation_literals.sentence_space}{choices_str}{translation_literals.word_space}{translation_literals.or_word}{translation_literals.word_space}{rearranged_labels[-1]}{translation_literals.question_mark}"
- # (hynky1999): Ideally we would not compute logprobs of the Yes/No/Also in CF fomulation. However as of right now lighteval doesn't allow to
+ # (hynky1999): Ideally we would not compute logprobs of the Yes/No/Also in CF formulation. However as of right now lighteval doesn't allow to
# use multi-context.
row = {
"instruction": input_data.get("instruction", ""),
diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py
index 51daf7198..756285e62 100644
--- a/src/lighteval/tasks/templates/utils/translation_literals.py
+++ b/src/lighteval/tasks/templates/utils/translation_literals.py
@@ -178,7 +178,29 @@ def __getattribute__(self, name: str) -> str:
Language.BRETON: TranslationLiterals(language=Language.BRETON),
Language.BULGARIAN: TranslationLiterals(language=Language.BULGARIAN),
Language.BURMESE: TranslationLiterals(language=Language.BURMESE),
- Language.CATALAN: TranslationLiterals(language=Language.CATALAN),
+ Language.CATALAN: TranslationLiterals(
+ language=Language.CATALAN,
+ question_word="pregunta",
+ answer="resposta",
+ confirmation_word="cert",
+ yes="sí",
+ no="no",
+ also="també",
+ cause_word="perquè",
+ effect_word="per tant",
+ or_word="o",
+ true="veritable",
+ false="fals",
+ neither="cap",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.CEBUANO: TranslationLiterals(language=Language.CEBUANO),
Language.CHINESE: TranslationLiterals(
language=Language.CHINESE,
@@ -348,7 +370,29 @@ def __getattribute__(self, name: str) -> str:
sentence_space=" ",
colon=":",
),
- Language.GALICIAN: TranslationLiterals(language=Language.GALICIAN),
+ Language.GALICIAN: TranslationLiterals(
+ language=Language.GALICIAN,
+ question_word="pregunta",
+ answer="resposta",
+ confirmation_word="certo",
+ yes="si",
+ no="non",
+ also="tamén",
+ cause_word="porque",
+ effect_word="polo tanto",
+ or_word="ou",
+ true="verdadeiro",
+ false="falso",
+ neither="ningún",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.GEORGIAN: TranslationLiterals(language=Language.GEORGIAN),
Language.GERMAN: TranslationLiterals(
language=Language.GERMAN,