Commit 5815119
Merge branch 'main' into Document-Custom-Model-Files
clefourrier authored Jan 23, 2025
2 parents 4c33274 + 620873b
Showing 12 changed files with 283 additions and 25 deletions.
215 changes: 214 additions & 1 deletion community_tasks/arabic_evals.py
@@ -28,8 +28,11 @@
"""
import random
import re
from typing import Any, Dict, List, Optional, Union

from lighteval.metrics.metrics import Metrics
from lighteval.metrics.llm_as_judge import JudgeLM
from lighteval.metrics.metrics import Metric, MetricCategory, Metrics
from lighteval.metrics.utils.metric_utils import MetricUseCase
from lighteval.tasks.default_prompts import LETTER_INDICES
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
@@ -832,6 +835,215 @@ def __init__(
]


class JudgeMetricWrapper(Metric):
"""Wrapper class for LLM-based judge metric implementation."""

def __init__(self, judge: JudgeLM):
"""
Initializes the judge metric wrapper.
Args:
judge (JudgeLM): The LLM judge instance to use for evaluation.
"""
self.judge = judge
self.metric_name = "llm_as_judge"
self.category = MetricCategory.LLM_AS_JUDGE
self.corpus_level_fn = self.aggregate_scores
self.sample_level_fn = self._sample_level_fn
self.higher_is_better = True # Fixed tuple syntax
self.use_case = MetricUseCase.NONE

def compute(self, responses: list[str], formatted_docs: list[Doc], **kwargs) -> dict[str, float]:
"""
Computes evaluation scores using the judge's evaluate_answer method.
Args:
responses (list[str]): The predicted answers
formatted_docs (list[Doc]): Documents containing questions and gold answers
Returns:
dict[str, float]: Dictionary containing evaluation scores
"""
results = []
for i, doc in enumerate(formatted_docs):
question = doc.query
gold = doc.choices[doc.gold_index] if doc.gold_index is not None else None
answer = responses[i][0].result[0]

score, _, _ = self.judge.evaluate_answer(question=question, answer=answer, options=None, gold=gold)
results.append({self.metric_name: score})

return results

def aggregate_scores(self, scores: list[dict]) -> float:
return sum(scores) / len(scores) if scores else 0.0

def _sample_level_fn(self):
return None


def parse_candidates(candidates: Union[List[str], str]) -> List[str]:
"""
Parses and validates candidate answers from either list or string format.
Args:
candidates: Either a list of candidate answers or a newline-separated string
Returns:
List[str]: List of validated candidate answers
Raises:
ValueError: If candidates cannot be parsed or are empty
"""
try:
if isinstance(candidates, list):
parsed_candidates = [str(c).strip() for c in candidates if c]
else:
parsed_candidates = [c.strip() for c in str(candidates).split("\n") if c.strip()]

if not parsed_candidates:
raise ValueError("No valid candidates found after parsing")

return parsed_candidates
except Exception as e:
raise ValueError(f"Failed to parse candidates: {str(e)}")
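# Illustrative sketch (not part of the commit): the two input forms accepted by parse_candidates.
parse_candidates(["باريس عاصمة فرنسا", "تقع باريس على نهر السين"])  # -> the same two strings, stripped
parse_candidates("باريس عاصمة فرنسا\nتقع باريس على نهر السين")  # -> the same two-element list
parse_candidates("")  # -> ValueError("Failed to parse candidates: No valid candidates found after parsing")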


def qa_prompt_arabic(line: Dict[str, Any], task_name: str = None) -> Doc:
"""
Formats the prompt for Arabic question answering with candidates.
Args:
line: Dictionary containing question and candidate information
task_name: Optional name for the task
Returns:
Doc: Formatted document for evaluation
Raises:
ValueError: If required fields are missing or invalid
"""
try:
# Validates and extracts the question
if not isinstance(line.get("question"), str):
raise ValueError("Question must be a string")
question = line["question"]

# Processes candidate answers
candidates = parse_candidates(line["candidates"])

# Validates gold answer
if "gold_answer" not in line:
raise ValueError("Gold answer is required")
gold_answer = str(line["gold_answer"])

# Constructs the prompt
instruction = "بناءً على السياقات المقترحة التالية، اجب عن السؤال التالي"
query = f"{instruction}\n\nالسؤال:\n{question}\n\nالسياقات المقترحة:\n{', '.join(candidates)}\n"

return Doc(
task_name=task_name or "alrage",
query=query,
instruction=instruction,
choices=[gold_answer], # Gold answer is used as the only valid choice
gold_index=0, # Index of the correct answer in choices
)
except Exception as e:
raise ValueError(f"Failed to create QA prompt: {str(e)}")
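# Illustrative sketch (hypothetical record, not taken from the dataset): how qa_prompt_arabic maps a
# raw ALRAGE-style row onto a Doc whose only choice is the gold answer.
_example_doc = qa_prompt_arabic(
    {
        "question": "ما هي عاصمة فرنسا؟",
        "candidates": "باريس هي عاصمة فرنسا\nتقع باريس على نهر السين",
        "gold_answer": "باريس",
    },
    task_name="alrage_qa",
)
# _example_doc.query embeds the instruction, the question, and the comma-joined candidate contexts;
# _example_doc.choices == ["باريس"] and _example_doc.gold_index == 0.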


def judge_template(question: str, answer: str, gold: str, options: Optional[List[str]] = None) -> List[Dict[str, str]]:
"""
Template for the Arabic judge prompt.
System prompt translation:
You are a neutral expert evaluator. Your tasks are:
1. Evaluate the answer's accuracy compared to the correct answer
2. Verify that the answer is supported by the provided context
3. Evaluate the quality and comprehensiveness of the answer
Rate the answer on a scale from 0 to 10.
Args:
question: The question being evaluated
answer: The provided answer
gold: The correct answer
options: Optional list of answer choices
Returns:
List[Dict[str, str]]: Formatted messages for the judge
"""
messages = [
{
"role": "system",
"content": """أنت مقيّم محايد خبير باللغة العربية. يجب عليك:
1. تقييم دقة الإجابة مقارنة بالإجابة الصحيحة
2. التحقق من أن الإجابة مدعومة بالسياق المقدم
3. تقييم جودة وشمولية الإجابة
مهم جداً: يجب أن يكون ردك رقماً فقط من 0 إلى 10. لا تضف أي نص أو تفسير.""",
},
{
"role": "user",
"content": f"""السؤال: {question}
الإجابة المقدمة: {answer}
الإجابة الصحيحة: {gold}
أعط تقييماً من 0 إلى 10:
0-2: إجابة خاطئة تماماً
3-4: إجابة جزئية مع أخطاء
5-6: إجابة متوسطة
7-8: إجابة جيدة
9-10: إجابة ممتازة
اكتب رقماً فقط من 0 إلى 10 بدون أي نص إضافي:""",
},
]
return messages
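# Illustrative sketch: the template yields a two-message chat (system rubric plus a user turn asking for
# a bare 0-10 score); note that `options` is accepted for signature compatibility but is never referenced
# in the prompt itself.
_msgs = judge_template(question="سؤال", answer="إجابة النموذج", gold="الإجابة المرجعية")
# len(_msgs) == 2, _msgs[0]["role"] == "system", _msgs[1]["role"] == "user"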


def process_judge_response(response) -> float:
"""Process the judge's response to extract the score"""
# If response is a list, extract the content from the user role
if isinstance(response, list):
response_content = " ".join(item["content"] for item in response if item["role"] == "user")
else:
response_content = response # If it's not a list, use it directly

try:
# Extract the score from the response content
score = float(next(num for num in response_content.split() if num.replace(".", "", 1).isdigit()))
return min(max(score / 10.0, 0.0), 1.0)
except (StopIteration, ValueError):
return 0.0
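# Illustrative sketch: raw judge replies are mapped onto [0, 1] by taking the first numeric token,
# dividing by 10, and clamping.
process_judge_response("8")  # -> 0.8
process_judge_response("التقييم: 10")  # -> 1.0 (the first numeric token wins)
process_judge_response("غير واضح")  # -> 0.0 (no numeric token found)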


judge = JudgeLM(
model="Qwen/Qwen2.5-72B-Instruct",
templates=judge_template,
process_judge_response=process_judge_response,
judge_backend="vllm",
)
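# Note: judge_backend="vllm" assumes vLLM is installed locally and that there is enough GPU memory to
# serve Qwen/Qwen2.5-72B-Instruct; a lighter judge model, or another backend supported by JudgeLM,
# could be substituted when that is not available.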

wrapped_judge = JudgeMetricWrapper(judge)

# Task configuration
alrage_qa_task = LightevalTaskConfig(
name="alrage_qa",
prompt_function=qa_prompt_arabic,
suite=["community"],
hf_repo="OALL/ALRAGE",
hf_subset=None,
hf_avail_splits=["train"],
evaluation_splits=["train"],
metric=[wrapped_judge],
trust_dataset=True,
generation_size=200,
stop_sequence=[],
version=0,
)
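# Illustrative end-to-end sketch of the scoring path wired up above (hypothetical values; the judge call
# itself is elided, since it needs a running backend):
_doc = qa_prompt_arabic(
    {"question": "...", "candidates": ["سياق 1", "سياق 2"], "gold_answer": "..."},
    task_name="alrage_qa",
)
_judge_messages = judge_template(question=_doc.query, answer="إجابة النموذج", gold=_doc.choices[0])
_sample_score = process_judge_response("9")  # -> 0.9, the per-sample value JudgeMetricWrapper aggregates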

TASKS_TABLE = (
ARABIC_MMLU_TASKS
+ ARABIC_MMLU_HT_TASKS
@@ -852,4 +1064,5 @@ def __init__(
+ [hellaswag_okapi_ar_task]
+ [toxigen_ar_task]
+ [sciq_ar_task]
+ [alrage_qa_task]
)
1 change: 1 addition & 0 deletions examples/tasks/OALL_v2_tasks.txt
@@ -115,3 +115,4 @@ community|arabic_mmlu_ht:sociology|0|0
community|arabic_mmlu_ht:us_foreign_policy|0|0
community|arabic_mmlu_ht:virology|0|0
community|arabic_mmlu_ht:world_religions|0|0
community|alrage_qa|0|0
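For context: entries in these task lists follow lighteval's suite|task|num_few_shot|flag convention, where the trailing 0/1 tells lighteval whether to automatically reduce the few-shot count if the prompt grows too long, so the new line registers a zero-shot run of community|alrage_qa.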
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -57,7 +57,7 @@ dependencies = [
"transformers>=4.38.0",
"accelerate",
"huggingface_hub>=0.23.0",
"torch>=2.0,<2.5",
"torch>=2.0,<3.0",
"GitPython>=3.1.41", # for logging
"datasets>=2.14.0",
"numpy<2", # pinned to avoid incompatibilities
10 changes: 5 additions & 5 deletions src/lighteval/tasks/default_prompts.py
@@ -120,7 +120,7 @@ def asdiv(line, task_name: str = None):

def babi_qa(line, task_name: str = None): # HELM
def process_path(path: str) -> str:
"""Turn a path string (task 19) from the original format 's,w' to a verbal model-friendly format 'south west'"""
"""Turn a path string (task 19) from the original format 's,w' into a verbal model-friendly format 'south west'"""
steps = path.split(",")
directions = {"s": "south", "n": "north", "e": "east", "w": "west"}
path = " ".join([directions[step] for step in steps])
@@ -281,7 +281,7 @@ def bbh_logical_deduction_three_objects(line, task_name: str = None):
def bbh_movie_recommendation(line, task_name: str = None):
if line["target"] == "Monsters, Inc": # this line is not correctly formatted
logger.warning(
"One sample removed from task bbh:movie_recommentation because its line is incorrectly formatted."
"One sample removed from task bbh:movie_recommendation because its line is incorrectly formatted."
)
return []
instruction = "Recommend movies similar to the given list of movies.\n\n"
@@ -500,7 +500,7 @@ def civil_comments(line, task_name: str = None):
def cnn_dm(line, task_name: str = None):
return Doc(
task_name=task_name,
query=f"###\nArticle:{line['article']}\n\nSummarize the above article in 3 sentence.\n",
query=f"###\nArticle:{line['article']}\n\nSummarize the above article in 3 sentences.\n",
choices=[str(line["summary"])],
gold_index=0,
specific={"text": line["article"]},
@@ -730,7 +730,7 @@ def gpqa(line, task_name: str = None):


def gsm8k(line, task_name: str = None):
# Has special analysis in metric for number decomposiition
# Has special analysis in metric for number decomposition
return Doc(
task_name=task_name,
query=f"Question: {line['question']}\nAnswer:",
@@ -2076,7 +2076,7 @@ def rte(line, task_name: str = None):
return Doc(
task_name=task_name,
query=f"{line['sentence1']}\nQuestion: {line['sentence2']} True or False?\nAnswer:",
choices=[" True", " False"], # 0 = entailement, 1 = not entailment
choices=[" True", " False"], # 0 = entailment, 1 = not entailment
gold_index=int(line["label"]),
# "metric": "choices_loglikelihood",
)
4 changes: 2 additions & 2 deletions src/lighteval/tasks/lighteval_task.py
@@ -621,7 +621,7 @@ def create_requests_from_tasks( # noqa: C901
n_samples = min(max_samples, len(task_docs)) if max_samples else len(task_docs)
evaluation_tracker.task_config_logger.log_num_docs(task_name, len(task_docs), n_samples)

# logs out the diferent versions of the tasks for every few shot
# logs out the different versions of the tasks for every few shot
for num_fewshot, _ in fewshot_dict[task_name]:
cur_task_name = f"{task_name}|{num_fewshot}"
evaluation_tracker.versions_logger.log(cur_task_name, task.version)
@@ -633,7 +633,7 @@
prompt_manager = PromptManager(lm=lm, task=task)
seeds = prompt_manager.few_shot_sampler.get_fewshot_seeds(num_fewshot_seeds)

# We can do several round of fewshots sampling to get some variance informations
# We can do several round of fewshots sampling to get some variance information
for seed in seeds:
for doc_id in range(n_samples):
doc_id_seed = f"{doc_id}_{seed}" # if we do several rounds of few shot sampling we have several seeds
2 changes: 1 addition & 1 deletion src/lighteval/tasks/prompt_manager.py
@@ -132,7 +132,7 @@ def _multi_turn_contexts(self, doc: Doc, use_chat_template: bool, system_prompt:
Multi turn tasks need use chat templating.
Args:
doc (Doc): Formated document.
doc (Doc): Formatted document.
use_chat_template (bool): wether or not to use chat template. Will fail if false.
system_prompt (Optional[str]): The system prompt to use
tokenizer (PreTrainedTokenizer): The tokenizer used for the chat template
2 changes: 1 addition & 1 deletion src/lighteval/tasks/registry.py
@@ -166,7 +166,7 @@ def _task_superset_dict(self):
"lighteval|mmlu" -> ["lighteval|mmlu:abstract_algebra", "lighteval|mmlu:college_biology", ...]
}
"""
# Note: sorted before groupby is imporant as the python implementation of groupby does not
# Note: sorted before groupby is important as the python implementation of groupby does not
# behave like sql groupby. For more info see the docs of itertools.groupby
superset_dict = {k: list(v) for k, v in groupby(sorted(self.task_registry.keys()), lambda x: x.split(":")[0])}
# Only consider supersets with more than one task
2 changes: 1 addition & 1 deletion src/lighteval/tasks/templates/continuation.py
@@ -112,7 +112,7 @@ def get_continuation_prompt_function(
C. Continuation 3
Answer: A/B/C
This template is very similar to the `Multiple Choice` template, except that it only takes context/continuations as input and don't use the anchor labels (Question/Answer)
This template is very similar to the `Multiple Choice` template, except that it only takes context/continuations as input and doesn't use the anchor labels (Question/Answer)
Args:
language (Language): The language of the Continuation task.
6 changes: 3 additions & 3 deletions src/lighteval/tasks/templates/copa.py
@@ -86,17 +86,17 @@ def get_copa_prompt_function(
Format:
*CF*
Context Premise thefore/cause | (Continuation 1, Continuation 2, Continuation 3)
Context Premise therefore/cause | (Continuation 1, Continuation 2, Continuation 3)
*Hybrid*
Context Premise thefore/cause
Context Premise therefore/cause
A. Continuation 1
B. Continuation 2
C. Continuation 3
Answer: | Continuation 1/Continuation 2/Continuation 3
*MCF*
Context Premise thefore/cause
Context Premise therefore/cause
A. Continuation 1
B. Continuation 2
C. Continuation 3
4 changes: 2 additions & 2 deletions src/lighteval/tasks/templates/hellaswag.py
@@ -70,7 +70,7 @@ def get_hellaswag_prompt_function(
Create a templated prompt function for a Hellaswag task.
Format:
Context Premise thefore/cause | (Continuation 1, Continuation 2, Continuation 3)
Context Premise therefore/cause | (Continuation 1, Continuation 2, Continuation 3)
Args:
language (Language): The language of the Hellaswag task.
@@ -126,7 +126,7 @@ def hellaswag_prompt(
if ctx_b:
ctx_a = join_ctxs(ctx_a, ctx_b)

# Removoal of the [header] can happen and we need the first letter to be capital afterwards
# Removal of the [header] can happen and we need the first letter to be capital afterwards
full_context = HELLASWAG_QUERY.format(activity_label=activity_label, ctx=ctx_a)
choices = [
hellaswag_preprocess(