diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py
index 6cad9189f..0fe638c30 100644
--- a/src/lighteval/logging/evaluation_tracker.py
+++ b/src/lighteval/logging/evaluation_tracker.py
@@ -209,9 +209,45 @@ def save_results(self, date_id: str, results_dict: dict):
         with self.fs.open(output_results_file, "w") as f:
             f.write(json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False))
 
-    def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
+    def _get_details_sub_folder(self, date_id: str):
         output_dir_details = Path(self.output_dir) / "details" / self.general_config_logger.model_name
-        output_dir_details_sub_folder = output_dir_details / date_id
+        if date_id == "latest":
+            # Get all folders in output_dir_details
+            if not self.fs.exists(output_dir_details):
+                raise FileNotFoundError(f"Details directory {output_dir_details} does not exist")
+
+            # List all folders and filter out files
+            folders = [f["name"] for f in self.fs.listdir(output_dir_details) if f["type"] == "directory"]
+
+            if not folders:
+                raise FileNotFoundError(f"No timestamp folders found in {output_dir_details}")
+
+            # Folder names are timestamps, so the lexicographic max is the latest
+            date_id = max(folders)
+        return output_dir_details / date_id
+
+    def load_details_datasets(self, date_id: str, task_names: list[str]) -> dict[str, Dataset]:
+        output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
+        logger.info(f"Loading details from {output_dir_details_sub_folder}")
+        date_id = output_dir_details_sub_folder.name  # Overwrite date_id in case "latest" was passed
+        details_datasets = {}
+        for file in self.fs.glob(str(output_dir_details_sub_folder / f"details_*_{date_id}.parquet")):
+            task_name = Path(file).stem.replace("details_", "").replace(f"_{date_id}", "")
+            if "|".join(task_name.split("|")[:-1]) not in task_names:
+                logger.info(f"Skipping {task_name} because it is not in the task_names list")
+                continue
+            dataset = load_dataset("parquet", data_files=file, split="train")
+            details_datasets[task_name] = dataset
+
+        for task_name in task_names:
+            if not any(loaded_task_name.startswith(task_name) for loaded_task_name in details_datasets.keys()):
+                raise ValueError(
+                    f"Task {task_name} not found in details datasets. Check the tasks to be evaluated or the date_id used to load the details ({date_id})."
+                )
+        return details_datasets
+
+    def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
+        output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
         self.fs.mkdirs(output_dir_details_sub_folder, exist_ok=True)
         logger.info(f"Saving details to {output_dir_details_sub_folder}")
         for task_name, dataset in details_datasets.items():
diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py
index fe7f98d6f..d8d69f30f 100644
--- a/src/lighteval/main_accelerate.py
+++ b/src/lighteval/main_accelerate.py
@@ -67,6 +67,9 @@ def accelerate(  # noqa C901
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -137,6 +140,7 @@ def accelerate(  # noqa C901
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
 
     # TODO (nathan): better handling of model_args
diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py
index 2c51fe15f..858cdcde3 100644
--- a/src/lighteval/main_endpoint.py
+++ b/src/lighteval/main_endpoint.py
@@ -179,6 +179,9 @@ def inference_endpoint(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -247,6 +250,7 @@ def inference_endpoint(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
     pipeline = Pipeline(
         tasks=tasks,
@@ -292,6 +296,9 @@ def tgi(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -355,6 +362,7 @@ def tgi(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
     pipeline = Pipeline(
         tasks=tasks,
@@ -400,6 +408,9 @@ def litellm(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
    ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -464,6 +475,7 @@ def litellm(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
     pipeline = Pipeline(
         tasks=tasks,
diff --git a/src/lighteval/main_vllm.py b/src/lighteval/main_vllm.py
index 89311b5ae..d063c3fa8 100644
--- a/src/lighteval/main_vllm.py
+++ b/src/lighteval/main_vllm.py
@@ -63,6 +63,9 @@ def vllm(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -124,6 +127,7 @@ def vllm(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
 
     if model_args.endswith(".yaml"):
diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py
index 6a40d2801..ed8d71457 100644
--- a/src/lighteval/pipeline.py
+++ b/src/lighteval/pipeline.py
@@ -20,6 +20,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import ast
 import collections
 import os
 import random
@@ -28,16 +29,28 @@
 from dataclasses import dataclass, field
 from datetime import timedelta
 from enum import Enum, auto
+from typing import Dict
 
 import numpy as np
+from tqdm import tqdm
 
 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.utils.metric_utils import MetricCategory
+from lighteval.models.abstract_model import ModelInfo
 from lighteval.models.model_loader import TransformersModel, load_model
-from lighteval.models.model_output import ModelResponse
+from lighteval.models.model_output import (
+    GenerativeMultiturnResponse,
+    GenerativeResponse,
+    LoglikelihoodResponse,
+    LoglikelihoodSingleTokenResponse,
+    ModelResponse,
+)
+from lighteval.models.transformers.transformers_model import TransformersModelConfig
+from lighteval.models.utils import _simplify_name
+from lighteval.models.vllm.vllm_model import VLLMModelConfig
 from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
 from lighteval.tasks.registry import Registry, taskinfo_selector
-from lighteval.tasks.requests import SampleUid
+from lighteval.tasks.requests import RequestType, SampleUid
 from lighteval.utils.imports import (
     NO_ACCELERATE_ERROR_MSG,
     NO_NANOTRON_ERROR_MSG,
@@ -95,6 +108,7 @@ class PipelineParameters:
     max_samples: int | None = None
     use_chat_template: bool = False
     system_prompt: str | None = None
+    load_responses_from_details_date_id: str | None = None
 
     def __post_init__(self):  # noqa C901
         if self.launcher_type == ParallelismManager.ACCELERATE:
@@ -133,42 +147,49 @@ def __init__(
                 "--max_samples WAS SET. THESE NUMBERS ARE ONLY PARTIAL AND SHOULD NOT BE USED FOR COMPARISON UNLESS YOU KNOW WHAT YOU ARE DOING."
             )
 
+        self.tasks = tasks
+        self.model = model
         self.model_config = model_config
         self.evaluation_tracker = evaluation_tracker
-        self.accelerator, self.parallel_context = self._init_parallelism_manager()
-        self.model = self._init_model(model_config, model)
-
-        self.evaluation_tracker.general_config_logger.log_model_info(self.model.model_info)
-        self._init_tasks_and_requests(tasks=tasks)
+        self._init_parallelism_manager()
         self._init_random_seeds()
+
+        self.evaluation_tracker.general_config_logger.log_model_info(self._get_model_info())
+        self._init_tasks()
+
         # Final results
         self.final_dict: dict = None
 
+    def _get_model_info(self):
+        if isinstance(self.model_config, (VLLMModelConfig, TransformersModelConfig)):
+            # At this point we only need the model name to know the details path
+            return ModelInfo(model_name=_simplify_name(self.model_config.pretrained))
+        else:
+            return self._init_model().model_info
+
     def _init_parallelism_manager(self):
-        accelerator, parallel_context = None, None
+        self.accelerator, self.parallel_context = None, None
         if self.launcher_type == ParallelismManager.ACCELERATE:
             if not is_accelerate_available():
                 raise ValueError("You are trying to launch an accelerate model, but accelerate is not installed")
-            accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
-            test_all_gather(accelerator=accelerator)
+            self.accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
+            test_all_gather(accelerator=self.accelerator)
         elif self.launcher_type == ParallelismManager.NANOTRON:
             if not is_nanotron_available():
                 raise ValueError("You are trying to launch a nanotron model, but nanotron is not installed")
             dist.initialize_torch_distributed()
-            parallel_context = ParallelContext(
+            self.parallel_context = ParallelContext(
                 tensor_parallel_size=self.model_config.lighteval_config.parallelism.tp,
                 pipeline_parallel_size=self.model_config.lighteval_config.parallelism.pp,
                 data_parallel_size=self.model_config.lighteval_config.parallelism.dp,
             )
-            test_all_gather(parallel_context=parallel_context)
-
-        return accelerator, parallel_context
+            test_all_gather(parallel_context=self.parallel_context)
 
-    def _init_model(self, model_config, model):
+    def _init_model(self):
         logger.info("--- LOADING MODEL ---")
-        if model_config is not None:
+        if self.model_config is not None:
             if self.parallel_context:
-                return NanotronLightevalModel(
+                self.model = NanotronLightevalModel(
                     checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path)
                     if self.pipeline_parameters.nanotron_checkpoint_path
                     else "",
@@ -179,46 +200,42 @@ def _init_model(self, model_config, model):
                     env_config=self.pipeline_parameters.env_config,
                 )
             else:
-                return load_model(config=model_config, env_config=self.pipeline_parameters.env_config)
-        if isinstance(model, TransformersModel):
-            return model
-        else:
-            return TransformersModel.from_model(
-                model=model,
+                self.model = load_model(config=self.model_config, env_config=self.pipeline_parameters.env_config)
+        elif not isinstance(self.model, TransformersModel):
+            self.model = TransformersModel.from_model(
+                model=self.model,
                 use_chat_template=self.pipeline_parameters.use_chat_template,
                 env_config=self.pipeline_parameters.env_config,
                 accelerator=self.accelerator,
             )
+        return self.model
 
-    def _init_tasks_and_requests(self, tasks: str):
+    def _init_tasks(self):
         with local_ranks_zero_first() if self.launcher_type == ParallelismManager.NANOTRON else nullcontext():
             logger.info("--- LOADING TASKS ---")
             registry = Registry(
                 cache_dir=self.pipeline_parameters.env_config.cache_dir,
                 custom_tasks=self.pipeline_parameters.custom_tasks_directory,
             )
-            task_names_list, fewshots_dict = taskinfo_selector(tasks, registry)
-            task_dict = registry.get_task_dict(task_names_list)
-            LightevalTask.load_datasets(list(task_dict.values()), self.pipeline_parameters.dataset_loading_processes)
-
-            self.evaluation_tracker.task_config_logger.log(task_dict)
-
-            requests, docs = create_requests_from_tasks(
-                task_dict=task_dict,
-                fewshot_dict=fewshots_dict,
-                num_fewshot_seeds=self.pipeline_parameters.num_fewshot_seeds,
-                lm=self.model,
-                max_samples=self.pipeline_parameters.max_samples,
-                evaluation_tracker=self.evaluation_tracker,
-                use_chat_template=self.pipeline_parameters.use_chat_template,
-                system_prompt=self.pipeline_parameters.system_prompt,
+            self.task_names_list, self.fewshots_dict = taskinfo_selector(self.tasks, registry)
+            self.task_dict = registry.get_task_dict(self.task_names_list)
+            LightevalTask.load_datasets(
+                list(self.task_dict.values()), self.pipeline_parameters.dataset_loading_processes
             )
-            self.task_names_list = task_names_list
-            self.task_dict = task_dict
-            self.fewshot_dict = fewshots_dict
-            self.requests = requests
-            self.docs = docs
+            self.evaluation_tracker.task_config_logger.log(self.task_dict)
+
+    def _init_requests(self):
+        self.requests, self.docs = create_requests_from_tasks(
+            task_dict=self.task_dict,
+            fewshot_dict=self.fewshots_dict,
+            num_fewshot_seeds=self.pipeline_parameters.num_fewshot_seeds,
+            lm=self.model,
+            max_samples=self.pipeline_parameters.max_samples,
+            evaluation_tracker=self.evaluation_tracker,
+            use_chat_template=self.pipeline_parameters.use_chat_template,
+            system_prompt=self.pipeline_parameters.system_prompt,
+        )
 
     def _init_random_seeds(self):
         logger.info("--- INIT SEEDS ---")
@@ -245,7 +262,17 @@ def evaluate(self):
             config=self.model_config,
         )
 
-        sample_id_to_responses = self._run_model()
+        if self.pipeline_parameters.load_responses_from_details_date_id:
+            try:
+                sample_id_to_responses = self._load_responses_from_details()
+            except FileNotFoundError as e:
+                logger.warning(
+                    f"No responses found for {self.pipeline_parameters.load_responses_from_details_date_id} in details directory: {e}. Running model instead."
+                )
+                sample_id_to_responses = self._run_model()
+        else:
+            sample_id_to_responses = self._run_model()
+
         self._compute_metrics(sample_id_to_responses)
 
         if self.is_main_process():
@@ -261,7 +288,99 @@ def evaluate(self):
                 except OSError:
                     pass
 
+    @staticmethod
+    def _metric_category_to_request_type() -> Dict[MetricCategory, RequestType]:
+        """Maps each MetricCategory to its corresponding RequestType."""
+        return {
+            MetricCategory.TARGET_PERPLEXITY: RequestType.LOGLIKELIHOOD,
+            MetricCategory.PERPLEXITY: RequestType.LOGLIKELIHOOD_ROLLING,
+            MetricCategory.GENERATIVE_SAMPLING: RequestType.GREEDY_UNTIL,
+            MetricCategory.GENERATIVE: RequestType.GREEDY_UNTIL,
+            MetricCategory.GENERATIVE_LOGPROB: RequestType.GREEDY_UNTIL,
+            MetricCategory.MULTICHOICE: RequestType.LOGLIKELIHOOD,
+            MetricCategory.MULTICHOICE_PMI: RequestType.LOGLIKELIHOOD,
+            MetricCategory.MULTICHOICE_ONE_TOKEN: RequestType.LOGLIKELIHOOD_SINGLE_TOKEN,
+            MetricCategory.LLM_AS_JUDGE_MULTI_TURN: RequestType.GREEDY_UNTIL_MULTI_TURN,
+            MetricCategory.LLM_AS_JUDGE: RequestType.GREEDY_UNTIL,
+        }
+
+    @staticmethod
+    def _request_type_to_response() -> Dict[RequestType, type[ModelResponse]]:
+        return {
+            RequestType.LOGLIKELIHOOD: LoglikelihoodResponse,
+            RequestType.LOGLIKELIHOOD_SINGLE_TOKEN: LoglikelihoodSingleTokenResponse,
+            RequestType.LOGLIKELIHOOD_ROLLING: LoglikelihoodResponse,
+            RequestType.GREEDY_UNTIL_MULTI_TURN: GenerativeMultiturnResponse,
+            RequestType.GREEDY_UNTIL: GenerativeResponse,
+        }
+
+    def _load_responses_from_details(self):
+        logger.info("--- LOADING RESPONSES FROM DETAILS ---")
+        sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
+
+        model_response_type = self._get_model_response_type()
+
+        details_datasets = self.evaluation_tracker.load_details_datasets(
+            self.pipeline_parameters.load_responses_from_details_date_id, self.task_names_list
+        )
+
+        for task_name, dataset in tqdm(details_datasets.items(), desc="Loading responses from details for tasks"):
+            task: LightevalTask = self._get_task(task_name)
+            num_samples = len(set(dataset["specifics"]))
+            max_samples = self.pipeline_parameters.max_samples if self.pipeline_parameters.max_samples else num_samples
+            if num_samples > max_samples:
+                logger.warning(
+                    f"Skipping {num_samples - max_samples} samples for {task_name} when loading responses from details because max_samples is set to {max_samples}"
+                )
+                num_samples = self.pipeline_parameters.max_samples
+
+            predictions = [ast.literal_eval(p) for p in dataset["predictions"][:num_samples]]
+            input_tokens = [ast.literal_eval(t) for t in dataset["input_tokens"][:num_samples]]
+            cont_tokens = [ast.literal_eval(t) for t in dataset["cont_tokens"][:num_samples]]
+            truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]]
+            padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]]
+
+            if model_response_type == GenerativeResponse:
+                logits = [ast.literal_eval(p) for p in dataset["pred_logits"][:num_samples]]
+
+            for metric_category, has_metric_category in task.has_metric_category.items():
+                if not has_metric_category:
+                    continue
+
+                for idx in range(num_samples):
+                    kwargs = {
+                        "result": predictions[idx],
+                        "input_tokens": input_tokens[idx],
+                        "generated_tokens": cont_tokens[idx],
+                        "truncated_tokens_count": truncated[idx],
+                        "padded_tokens_count": padded[idx],
+                    }
+                    if model_response_type == GenerativeResponse:
+                        kwargs["logits"] = logits[idx]
+
+                    response = model_response_type(**kwargs)
+                    sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response]
+        return sample_id_to_responses
+
+    def _get_model_response_type(self):
+        model_response_type = None
+        for task in self.task_dict.values():
+            for metric_category, has_metric_category in task.has_metric_category.items():
+                if has_metric_category:
+                    request_type = self._metric_category_to_request_type()[metric_category]
+                    new_model_response_type = self._request_type_to_response()[request_type]
+                    if model_response_type and new_model_response_type != model_response_type:
+                        raise ValueError(
+                            f"Loading responses from details with multiple model response types ({model_response_type} and {new_model_response_type}) is currently not supported"
+                        )
+                    model_response_type = new_model_response_type
+        return model_response_type
+
     def _run_model(self):
+        # Init the model lazily to avoid loading it when it is not needed
+        self._init_model()
+        self._init_requests()  # Needs the model to be initialized
+
         # Running all requests depending on the model call type (log likelihood, generative, ...)
         # to be able to batch them
         logger.info("--- RUNNING MODEL ---")
@@ -283,6 +402,10 @@ def _run_model(self):
 
         return sample_id_to_responses
 
+    def _get_task(self, task_name: str):
+        short_task_name = task_name.rsplit("|", 1)[0]
+        return self.task_dict[short_task_name]
+
     def _compute_metrics(self, sample_id_to_responses):
         # To compute the metrics we first group the samples and task and then by metrics.
         # This way we can batch the metrics computation for each task and metric category
@@ -307,8 +430,7 @@ def _compute_metrics(self, sample_id_to_responses):
             task_metric_category_groups[sample_id.task_name][metric_category]["docs"].append(self.docs[sample_id])
 
         for task_name, samples_per_metric in task_metric_category_groups.items():
-            short_task_name = task_name.rsplit("|", 1)[0]
-            task: LightevalTask = self.task_dict[short_task_name]
+            task: LightevalTask = self._get_task(task_name)
 
             for metric_category, samples in samples_per_metric.items():
                 sample_ids = samples["ids"]
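
Usage sketch: a minimal, hedged example of the loading path added to evaluation_tracker.py above. It assumes the documented EvaluationTracker(output_dir=...) constructor and a tracker that already knows the evaluated model's name (inside Pipeline this holds once log_model_info() has run); the output directory and task name below are placeholders, and "latest" resolves to the most recent timestamp folder under details/<model_name>/.

    from lighteval.logging.evaluation_tracker import EvaluationTracker

    # Point the tracker at an existing evaluation output directory (placeholder path).
    tracker = EvaluationTracker(output_dir="./results")

    # Load the parquet details written by a previous run; keys are detailed task names
    # such as "leaderboard|truthfulqa:mc|0".
    details = tracker.load_details_datasets(
        date_id="latest",                          # or an explicit timestamp folder name
        task_names=["leaderboard|truthfulqa:mc"],  # placeholder task, suite|task form
    )
    for task_name, dataset in details.items():
        print(task_name, dataset.num_rows)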