Fix formatting and linting issues via pre-commit hooks
Manel-Hik committed Jan 13, 2025
1 parent 72553f6 commit d2c15c1
Showing 2 changed files with 34 additions and 37 deletions.
70 changes: 34 additions & 36 deletions community_tasks/arabic_evals.py
@@ -28,16 +28,15 @@
 """
 import random
 import re
+from typing import Any, Dict, List, Optional, Union
 
-from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.llm_as_judge import JudgeLM
+from lighteval.metrics.metrics import Metric, MetricCategory, Metrics
+from lighteval.metrics.utils.metric_utils import MetricUseCase
 from lighteval.tasks.default_prompts import LETTER_INDICES
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
-from typing import List, Dict, Optional, Union, Any
-from lighteval.metrics.metrics import Metric, MetricCategory
-
-from lighteval.metrics.llm_as_judge import JudgeLM
-from lighteval.metrics.utils.metric_utils import MetricUseCase
 
 # fmt: off
 LETTER_INDICES_AR = ["أ", "ب", "ج", "د", "هـ", "و", "ز", "ح", "ط", "ي", "ك", "ل", "م", "ن", "س", "ع", "ف", "ص", "ق", "ر", "ش", "ت", "ث", "خ", "ذ", "ض", "ظ", "غ"]
@@ -835,13 +834,14 @@ def __init__(
     CustomMadinahQATask(name=f"madinah_qa:{subset}", hf_subset=subset) for subset in MADINAH_QA_SUBSETS
 ]
 
+
 class JudgeMetricWrapper(Metric):
     """Wrapper class for LLM-based judge metric implementation."""
 
     def __init__(self, judge: JudgeLM):
         """
         Initializes the judge metric wrapper.
         Args:
             judge (JudgeLM): The LLM judge instance to use for evaluation.
         """
@@ -856,11 +856,11 @@ def __init__(self, judge: JudgeLM):
     def compute(self, responses: list[str], formatted_docs: list[Doc], **kwargs) -> dict[str, float]:
         """
         Computes evaluation scores using the judge's evaluate_answer method.
         Args:
             responses (list[str]): The predicted answers
             formatted_docs (list[Doc]): Documents containing questions and gold answers
         Returns:
             dict[str, float]: Dictionary containing evaluation scores
         """
@@ -870,60 +870,56 @@ def compute(self, responses: list[str], formatted_docs: list[Doc], **kwargs) -> dict[str, float]:
             gold = doc.choices[doc.gold_index] if doc.gold_index is not None else None
             answer = responses[i][0].result[0]
 
-            score, _, _ = self.judge.evaluate_answer(
-                question=question,
-                answer=answer,
-                options=None,
-                gold=gold
-            )
+            score, _, _ = self.judge.evaluate_answer(question=question, answer=answer, options=None, gold=gold)
             results.append({self.metric_name: score})
 
         return results
 
     def aggregate_scores(self, scores: list[dict]) -> float:
-
         return sum(scores) / len(scores) if scores else 0.0
 
     def _sample_level_fn(self):
         return None
 
+
 def parse_candidates(candidates: Union[List[str], str]) -> List[str]:
     """
     Parses and validates candidate answers from either list or string format.
     Args:
         candidates: Either a list of candidate answers or a newline-separated string
     Returns:
         List[str]: List of validated candidate answers
     Raises:
         ValueError: If candidates cannot be parsed or are empty
     """
     try:
         if isinstance(candidates, list):
             parsed_candidates = [str(c).strip() for c in candidates if c]
         else:
-            parsed_candidates = [c.strip() for c in str(candidates).split('\n') if c.strip()]
+            parsed_candidates = [c.strip() for c in str(candidates).split("\n") if c.strip()]
 
         if not parsed_candidates:
             raise ValueError("No valid candidates found after parsing")
 
         return parsed_candidates
     except Exception as e:
         raise ValueError(f"Failed to parse candidates: {str(e)}")
 
+
 def qa_prompt_arabic(line: Dict[str, Any], task_name: str = None) -> Doc:
     """
     Formats the prompt for Arabic question answering with candidates.
     Args:
         line: Dictionary containing question and candidate information
         task_name: Optional name for the task
     Returns:
         Doc: Formatted document for evaluation
     Raises:
         ValueError: If required fields are missing or invalid
     """
@@ -950,28 +946,29 @@ def qa_prompt_arabic(line: Dict[str, Any], task_name: str = None) -> Doc:
             query=query,
             instruction=instruction,
             choices=[gold_answer],  # Gold answer is used as the only valid choice
-            gold_index=0  # Index of the correct answer in choices
+            gold_index=0,  # Index of the correct answer in choices
         )
     except Exception as e:
         raise ValueError(f"Failed to create QA prompt: {str(e)}")
 
+
 def judge_template(question: str, answer: str, gold: str, options: Optional[List[str]] = None) -> List[Dict[str, str]]:
     """
     Template for the Arabic judge prompt.
     System prompt translation:
     You are a neutral expert evaluator. Your tasks are:
     1. Evaluate the answer's accuracy compared to the correct answer
     2. Verify that the answer is supported by the provided context
     3. Evaluate the quality and comprehensiveness of the answer
     Rate the answer on a scale from 0 to 10.
     Args:
         question: The question being evaluated
         answer: The provided answer
         gold: The correct answer
         options: Optional list of answer choices
     Returns:
         List[Dict[str, str]]: Formatted messages for the judge
     """
@@ -983,7 +980,7 @@ def judge_template(question: str, answer: str, gold: str, options: Optional[List[str]] = None) -> List[Dict[str, str]]:
 2. التحقق من أن الإجابة مدعومة بالسياق المقدم
 3. تقييم جودة وشمولية الإجابة
-مهم جداً: يجب أن يكون ردك رقماً فقط من 0 إلى 10. لا تضف أي نص أو تفسير."""
+مهم جداً: يجب أن يكون ردك رقماً فقط من 0 إلى 10. لا تضف أي نص أو تفسير.""",
         },
         {
             "role": "user",
@@ -1000,8 +997,8 @@ def judge_template(question: str, answer: str, gold: str, options: Optional[List[str]] = None) -> List[Dict[str, str]]:
 7-8: إجابة جيدة
 9-10: إجابة ممتازة
-اكتب رقماً فقط من 0 إلى 10 بدون أي نص إضافي:"""
-        }
+اكتب رقماً فقط من 0 إلى 10 بدون أي نص إضافي:""",
+        },
     ]
     return messages

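Taken together, judge_template builds a two-message chat: an Arabic system prompt that pins the judge to a bare 0-10 numeric reply, and an Arabic user prompt that presumably interpolates the question, the gold answer, and the candidate answer. Only the closing lines of those string literals are touched in this commit (trailing commas added by the formatter). A hypothetical sanity check of the returned structure, assuming the roles shown in the snippets above:

    # Illustrative only; not part of this commit.
    messages = judge_template(question="ما هي عاصمة فرنسا؟", answer="باريس", gold="باريس")
    assert [m["role"] for m in messages] == ["system", "user"]
    assert all(isinstance(m["content"], str) for m in messages)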
@@ -1010,22 +1007,23 @@ def process_judge_response(response) -> float:
     """Process the judge's response to extract the score"""
     # If response is a list, extract the content from the user role
     if isinstance(response, list):
-        response_content = ' '.join(item['content'] for item in response if item['role'] == 'user')
+        response_content = " ".join(item["content"] for item in response if item["role"] == "user")
     else:
         response_content = response  # If it's not a list, use it directly
 
     try:
         # Extract the score from the response content
-        score = float(next(num for num in response_content.split() if num.replace('.', '', 1).isdigit()))
+        score = float(next(num for num in response_content.split() if num.replace(".", "", 1).isdigit()))
         return min(max(score / 10.0, 0.0), 1.0)
     except (StopIteration, ValueError):
         return 0.0
 
+
 judge = JudgeLM(
     model="Qwen/Qwen2.5-72B-Instruct",
     templates=judge_template,
     process_judge_response=process_judge_response,
-    judge_backend="vllm"
+    judge_backend="vllm",
 )
 
 wrapped_judge = JudgeMetricWrapper(judge)
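
Since the system prompt asks the judge for a bare number from 0 to 10, process_judge_response simply picks the first numeric token, scales it to the 0-1 range, and clamps it, falling back to 0.0 when nothing numeric is found. A few illustrative calls (not part of this commit) whose expected values follow directly from the function body above:

    # Illustrative only -- expected values follow from the function body above.
    process_judge_response("8")                            # -> 0.8
    process_judge_response("التقييم: 7.5")                 # -> 0.75 (first numeric token wins)
    process_judge_response("لا يمكن التقييم")              # -> 0.0  (no numeric token found)
    process_judge_response("15")                           # -> 1.0  (clamped to [0, 1])
    process_judge_response([{"role": "user", "content": "9"}])  # -> 0.9 (user-role contents are joined)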
@@ -1043,7 +1041,7 @@ def process_judge_response(response) -> float:
     trust_dataset=True,
     generation_size=200,
     stop_sequence=[],
-    version=0
+    version=0,
 )
 
 TASKS_TABLE = (
1 change: 0 additions & 1 deletion examples/tasks/OALL_v2_tasks.txt
@@ -116,4 +116,3 @@ community|arabic_mmlu_ht:us_foreign_policy|0|0
 community|arabic_mmlu_ht:virology|0|0
 community|arabic_mmlu_ht:world_religions|0|0
 community|alrage_qa|0|0
-