From d2c15c1ae6d87a222684356bd716af8796f09000 Mon Sep 17 00:00:00 2001 From: Manel-Hik Date: Mon, 13 Jan 2025 12:18:32 +0100 Subject: [PATCH] Fix formatting and linting issues via pre-commit hooks --- community_tasks/arabic_evals.py | 70 ++++++++++++++++---------------- examples/tasks/OALL_v2_tasks.txt | 1 - 2 files changed, 34 insertions(+), 37 deletions(-) diff --git a/community_tasks/arabic_evals.py b/community_tasks/arabic_evals.py index 8d303767..a68abbe6 100644 --- a/community_tasks/arabic_evals.py +++ b/community_tasks/arabic_evals.py @@ -28,16 +28,15 @@ """ import random import re +from typing import Any, Dict, List, Optional, Union -from lighteval.metrics.metrics import Metrics +from lighteval.metrics.llm_as_judge import JudgeLM +from lighteval.metrics.metrics import Metric, MetricCategory, Metrics +from lighteval.metrics.utils.metric_utils import MetricUseCase from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc -from typing import List, Dict, Optional, Union, Any -from lighteval.metrics.metrics import Metric, MetricCategory -from lighteval.metrics.llm_as_judge import JudgeLM -from lighteval.metrics.utils.metric_utils import MetricUseCase # fmt: off LETTER_INDICES_AR = ["أ", "ب", "ج", "د", "هـ", "و", "ز", "ح", "ط", "ي", "ك", "ل", "م", "ن", "س", "ع", "ف", "ص", "ق", "ر", "ش", "ت", "ث", "خ", "ذ", "ض", "ظ", "غ"] @@ -835,13 +834,14 @@ def __init__( CustomMadinahQATask(name=f"madinah_qa:{subset}", hf_subset=subset) for subset in MADINAH_QA_SUBSETS ] + class JudgeMetricWrapper(Metric): """Wrapper class for LLM-based judge metric implementation.""" - + def __init__(self, judge: JudgeLM): """ Initializes the judge metric wrapper. - + Args: judge (JudgeLM): The LLM judge instance to use for evaluation. """ @@ -856,11 +856,11 @@ def __init__(self, judge: JudgeLM): def compute(self, responses: list[str], formatted_docs: list[Doc], **kwargs) -> dict[str, float]: """ Computes evaluation scores using the judge's evaluate_answer method. - + Args: responses (list[str]): The predicted answers formatted_docs (list[Doc]): Documents containing questions and gold answers - + Returns: dict[str, float]: Dictionary containing evaluation scores """ @@ -870,33 +870,28 @@ def compute(self, responses: list[str], formatted_docs: list[Doc], **kwargs) -> gold = doc.choices[doc.gold_index] if doc.gold_index is not None else None answer = responses[i][0].result[0] - score, _, _ = self.judge.evaluate_answer( - question=question, - answer=answer, - options=None, - gold=gold - ) + score, _, _ = self.judge.evaluate_answer(question=question, answer=answer, options=None, gold=gold) results.append({self.metric_name: score}) return results def aggregate_scores(self, scores: list[dict]) -> float: - return sum(scores) / len(scores) if scores else 0.0 def _sample_level_fn(self): return None + def parse_candidates(candidates: Union[List[str], str]) -> List[str]: """ Parses and validates candidate answers from either list or string format. - + Args: candidates: Either a list of candidate answers or a newline-separated string - + Returns: List[str]: List of validated candidate answers - + Raises: ValueError: If candidates cannot be parsed or are empty """ @@ -904,26 +899,27 @@ def parse_candidates(candidates: Union[List[str], str]) -> List[str]: if isinstance(candidates, list): parsed_candidates = [str(c).strip() for c in candidates if c] else: - parsed_candidates = [c.strip() for c in str(candidates).split('\n') if c.strip()] - + parsed_candidates = [c.strip() for c in str(candidates).split("\n") if c.strip()] + if not parsed_candidates: raise ValueError("No valid candidates found after parsing") - + return parsed_candidates except Exception as e: raise ValueError(f"Failed to parse candidates: {str(e)}") + def qa_prompt_arabic(line: Dict[str, Any], task_name: str = None) -> Doc: """ Formats the prompt for Arabic question answering with candidates. - + Args: line: Dictionary containing question and candidate information task_name: Optional name for the task - + Returns: Doc: Formatted document for evaluation - + Raises: ValueError: If required fields are missing or invalid """ @@ -950,28 +946,29 @@ def qa_prompt_arabic(line: Dict[str, Any], task_name: str = None) -> Doc: query=query, instruction=instruction, choices=[gold_answer], # Gold answer is used as the only valid choice - gold_index=0 # Index of the correct answer in choices + gold_index=0, # Index of the correct answer in choices ) except Exception as e: raise ValueError(f"Failed to create QA prompt: {str(e)}") + def judge_template(question: str, answer: str, gold: str, options: Optional[List[str]] = None) -> List[Dict[str, str]]: """ Template for the Arabic judge prompt. - + System prompt translation: You are a neutral expert evaluator. Your tasks are: 1. Evaluate the answer's accuracy compared to the correct answer 2. Verify that the answer is supported by the provided context 3. Evaluate the quality and comprehensiveness of the answer Rate the answer on a scale from 0 to 10. - + Args: question: The question being evaluated answer: The provided answer gold: The correct answer options: Optional list of answer choices - + Returns: List[Dict[str, str]]: Formatted messages for the judge """ @@ -983,7 +980,7 @@ def judge_template(question: str, answer: str, gold: str, options: Optional[List 2. التحقق من أن الإجابة مدعومة بالسياق المقدم 3. تقييم جودة وشمولية الإجابة -مهم جداً: يجب أن يكون ردك رقماً فقط من 0 إلى 10. لا تضف أي نص أو تفسير.""" +مهم جداً: يجب أن يكون ردك رقماً فقط من 0 إلى 10. لا تضف أي نص أو تفسير.""", }, { "role": "user", @@ -1000,8 +997,8 @@ def judge_template(question: str, answer: str, gold: str, options: Optional[List 7-8: إجابة جيدة 9-10: إجابة ممتازة -اكتب رقماً فقط من 0 إلى 10 بدون أي نص إضافي:""" - } +اكتب رقماً فقط من 0 إلى 10 بدون أي نص إضافي:""", + }, ] return messages @@ -1010,22 +1007,23 @@ def process_judge_response(response) -> float: """Process the judge's response to extract the score""" # If response is a list, extract the content from the user role if isinstance(response, list): - response_content = ' '.join(item['content'] for item in response if item['role'] == 'user') + response_content = " ".join(item["content"] for item in response if item["role"] == "user") else: response_content = response # If it's not a list, use it directly try: # Extract the score from the response content - score = float(next(num for num in response_content.split() if num.replace('.', '', 1).isdigit())) + score = float(next(num for num in response_content.split() if num.replace(".", "", 1).isdigit())) return min(max(score / 10.0, 0.0), 1.0) except (StopIteration, ValueError): return 0.0 + judge = JudgeLM( model="Qwen/Qwen2.5-72B-Instruct", templates=judge_template, process_judge_response=process_judge_response, - judge_backend="vllm" + judge_backend="vllm", ) wrapped_judge = JudgeMetricWrapper(judge) @@ -1043,7 +1041,7 @@ def process_judge_response(response) -> float: trust_dataset=True, generation_size=200, stop_sequence=[], - version=0 + version=0, ) TASKS_TABLE = ( diff --git a/examples/tasks/OALL_v2_tasks.txt b/examples/tasks/OALL_v2_tasks.txt index f2ac392f..176b662d 100644 --- a/examples/tasks/OALL_v2_tasks.txt +++ b/examples/tasks/OALL_v2_tasks.txt @@ -116,4 +116,3 @@ community|arabic_mmlu_ht:us_foreign_policy|0|0 community|arabic_mmlu_ht:virology|0|0 community|arabic_mmlu_ht:world_religions|0|0 community|alrage_qa|0|0 -