Skip to content

Commit

Permalink
Merge branch 'main' into math_extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
hynky1999 committed Jan 23, 2025
2 parents 1d1c636 + c82143a commit 571937c
Show file tree
Hide file tree
Showing 19 changed files with 502 additions and 54 deletions.
215 changes: 214 additions & 1 deletion community_tasks/arabic_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@
"""
import random
import re
from typing import Any, Dict, List, Optional, Union

from lighteval.metrics.metrics import Metrics
from lighteval.metrics.llm_as_judge import JudgeLM
from lighteval.metrics.metrics import Metric, MetricCategory, Metrics
from lighteval.metrics.utils.metric_utils import MetricUseCase
from lighteval.tasks.default_prompts import LETTER_INDICES
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
Expand Down Expand Up @@ -832,6 +835,215 @@ def __init__(
]


class JudgeMetricWrapper(Metric):
"""Wrapper class for LLM-based judge metric implementation."""

def __init__(self, judge: JudgeLM):
"""
Initializes the judge metric wrapper.
Args:
judge (JudgeLM): The LLM judge instance to use for evaluation.
"""
self.judge = judge
self.metric_name = "llm_as_judge"
self.category = MetricCategory.LLM_AS_JUDGE
self.corpus_level_fn = self.aggregate_scores
self.sample_level_fn = self._sample_level_fn
self.higher_is_better = True # Fixed tuple syntax
self.use_case = MetricUseCase.NONE

def compute(self, responses: list[str], formatted_docs: list[Doc], **kwargs) -> dict[str, float]:
"""
Computes evaluation scores using the judge's evaluate_answer method.
Args:
responses (list[str]): The predicted answers
formatted_docs (list[Doc]): Documents containing questions and gold answers
Returns:
dict[str, float]: Dictionary containing evaluation scores
"""
results = []
for i, doc in enumerate(formatted_docs):
question = doc.query
gold = doc.choices[doc.gold_index] if doc.gold_index is not None else None
answer = responses[i][0].result[0]

score, _, _ = self.judge.evaluate_answer(question=question, answer=answer, options=None, gold=gold)
results.append({self.metric_name: score})

return results

def aggregate_scores(self, scores: list[dict]) -> float:
return sum(scores) / len(scores) if scores else 0.0

def _sample_level_fn(self):
return None


def parse_candidates(candidates: Union[List[str], str]) -> List[str]:
"""
Parses and validates candidate answers from either list or string format.
Args:
candidates: Either a list of candidate answers or a newline-separated string
Returns:
List[str]: List of validated candidate answers
Raises:
ValueError: If candidates cannot be parsed or are empty
"""
try:
if isinstance(candidates, list):
parsed_candidates = [str(c).strip() for c in candidates if c]
else:
parsed_candidates = [c.strip() for c in str(candidates).split("\n") if c.strip()]

if not parsed_candidates:
raise ValueError("No valid candidates found after parsing")

return parsed_candidates
except Exception as e:
raise ValueError(f"Failed to parse candidates: {str(e)}")


def qa_prompt_arabic(line: Dict[str, Any], task_name: str = None) -> Doc:
"""
Formats the prompt for Arabic question answering with candidates.
Args:
line: Dictionary containing question and candidate information
task_name: Optional name for the task
Returns:
Doc: Formatted document for evaluation
Raises:
ValueError: If required fields are missing or invalid
"""
try:
# Validates and extracts the question
if not isinstance(line.get("question"), str):
raise ValueError("Question must be a string")
question = line["question"]

# Processes candidate answers
candidates = parse_candidates(line["candidates"])

# Validates gold answer
if "gold_answer" not in line:
raise ValueError("Gold answer is required")
gold_answer = str(line["gold_answer"])

# Constructs the prompt
instruction = "بناءً على السياقات المقترحة التالية، اجب عن السؤال التالي"
query = f"{instruction}\n\nالسؤال:\n{question}\n\nالسياقات المقترحة:\n{', '.join(candidates)}\n"

return Doc(
task_name=task_name or "alrage",
query=query,
instruction=instruction,
choices=[gold_answer], # Gold answer is used as the only valid choice
gold_index=0, # Index of the correct answer in choices
)
except Exception as e:
raise ValueError(f"Failed to create QA prompt: {str(e)}")


def judge_template(question: str, answer: str, gold: str, options: Optional[List[str]] = None) -> List[Dict[str, str]]:
"""
Template for the Arabic judge prompt.
System prompt translation:
You are a neutral expert evaluator. Your tasks are:
1. Evaluate the answer's accuracy compared to the correct answer
2. Verify that the answer is supported by the provided context
3. Evaluate the quality and comprehensiveness of the answer
Rate the answer on a scale from 0 to 10.
Args:
question: The question being evaluated
answer: The provided answer
gold: The correct answer
options: Optional list of answer choices
Returns:
List[Dict[str, str]]: Formatted messages for the judge
"""
messages = [
{
"role": "system",
"content": """أنت مقيّم محايد خبير باللغة العربية. يجب عليك:
1. تقييم دقة الإجابة مقارنة بالإجابة الصحيحة
2. التحقق من أن الإجابة مدعومة بالسياق المقدم
3. تقييم جودة وشمولية الإجابة
مهم جداً: يجب أن يكون ردك رقماً فقط من 0 إلى 10. لا تضف أي نص أو تفسير.""",
},
{
"role": "user",
"content": f"""السؤال: {question}
الإجابة المقدمة: {answer}
الإجابة الصحيحة: {gold}
أعط تقييماً من 0 إلى 10:
0-2: إجابة خاطئة تماماً
3-4: إجابة جزئية مع أخطاء
5-6: إجابة متوسطة
7-8: إجابة جيدة
9-10: إجابة ممتازة
اكتب رقماً فقط من 0 إلى 10 بدون أي نص إضافي:""",
},
]
return messages


def process_judge_response(response) -> float:
"""Process the judge's response to extract the score"""
# If response is a list, extract the content from the user role
if isinstance(response, list):
response_content = " ".join(item["content"] for item in response if item["role"] == "user")
else:
response_content = response # If it's not a list, use it directly

try:
# Extract the score from the response content
score = float(next(num for num in response_content.split() if num.replace(".", "", 1).isdigit()))
return min(max(score / 10.0, 0.0), 1.0)
except (StopIteration, ValueError):
return 0.0


judge = JudgeLM(
model="Qwen/Qwen2.5-72B-Instruct",
templates=judge_template,
process_judge_response=process_judge_response,
judge_backend="vllm",
)

wrapped_judge = JudgeMetricWrapper(judge)

# Task configuration
alrage_qa_task = LightevalTaskConfig(
name="alrage_qa",
prompt_function=qa_prompt_arabic,
suite=["community"],
hf_repo="OALL/ALRAGE",
hf_subset=None,
hf_avail_splits=["train"],
evaluation_splits=["train"],
metric=[wrapped_judge],
trust_dataset=True,
generation_size=200,
stop_sequence=[],
version=0,
)

TASKS_TABLE = (
ARABIC_MMLU_TASKS
+ ARABIC_MMLU_HT_TASKS
Expand All @@ -852,4 +1064,5 @@ def __init__(
+ [hellaswag_okapi_ar_task]
+ [toxigen_ar_task]
+ [sciq_ar_task]
+ [alrage_qa_task]
)
121 changes: 121 additions & 0 deletions community_tasks/french_evals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# ruff: noqa: F405, F403, F401
"""
Custom evaluation tasks for lighteval.
This file generally creates just a TASKS_TABLE and TASKS_GROUPS which are then imported by LightEval.
This module implements tasks for the french specific datasets
See : https://huggingface.co/fr-gouv-coordination-ia
"""

import random

import numpy as np
from aenum import extend_enum

import lighteval.tasks.extended.ifeval.instructions_registry as instructions_registry
from lighteval.metrics.metrics import Metrics, SampleLevelMetric
from lighteval.metrics.utils.metric_utils import (
MetricCategory,
MetricUseCase,
SampleLevelMetricGrouping,
)
from lighteval.tasks.default_prompts import LETTER_INDICES
from lighteval.tasks.extended.ifeval.main import ifeval_metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


# Ifeval-fr prompt function
def prompt_ifeval_fr(line, task_name: str = None):
return Doc(
task_name=task_name,
query=line["prompt"],
choices=[""],
gold_index=0,
instruction="",
specific={"instructions_id_list": line["instruction_id_list"], "kwargs": line["kwargs"]},
)


# qpqa-fr prompt function
def prompt_gpqa_fr(line, task_name: str = None):
gold_index = random.randint(0, 3)
choices = [line["Réponse incorrecte 1"], line["Réponse incorrecte 2"], line["Réponse incorrecte 3"]]
choices.insert(gold_index, line["Réponse correcte"])

instruction = "Choisissez la réponse correcte aux questions suivantes.\n\n"

query = f"Question: {line['Question']}\n"
query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, choices)])
query += "Answer: "
return Doc(
task_name=task_name,
query=f"{instruction}{query}",
choices=LETTER_INDICES[: len(choices)],
gold_index=gold_index,
instruction=instruction,
)


# IFEVal-fr task


ifeval_fr_task = LightevalTaskConfig(
name="ifeval-fr",
prompt_function=prompt_ifeval_fr, # must be defined in the file or imported from src/lighteval/tasks/tasks_prompt_formatting.py
suite=["community"],
hf_repo="fr-gouv-coordination-ia/IFEval-fr",
hf_subset="default",
metric=[ifeval_metrics],
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split="train",
few_shots_select="random_sampling",
generation_size=1280,
stop_sequence=[], # no stop sequence, will use eot token
version="0.1", # select your metric in Metrics
)

# GPQA-fr task
gpqa_fr_task = LightevalTaskConfig(
name="gpqa-fr",
suite=["community"],
prompt_function=prompt_gpqa_fr,
hf_repo="fr-gouv-coordination-ia/gpqa-fr",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select="random_sampling",
generation_size=1,
metric=[Metrics.loglikelihood_acc],
stop_sequence=["\n"],
trust_dataset=True,
version=0,
)

# STORE YOUR EVALS
TASKS_TABLE = [ifeval_fr_task, gpqa_fr_task]
1 change: 1 addition & 0 deletions examples/tasks/OALL_v2_tasks.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,4 @@ community|arabic_mmlu_ht:sociology|0|0
community|arabic_mmlu_ht:us_foreign_policy|0|0
community|arabic_mmlu_ht:virology|0|0
community|arabic_mmlu_ht:world_religions|0|0
community|alrage_qa|0|0
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ dependencies = [
"transformers>=4.38.0",
"accelerate",
"huggingface_hub>=0.23.0",
"torch>=2.0,<2.5",
"torch>=2.0,<3.0",
"GitPython>=3.1.41", # for logging
"datasets>=2.14.0",
"numpy<2", # pinned to avoid incompatibilities
Expand Down
Loading

0 comments on commit 571937c

Please sign in to comment.