From ca8331a97879b18d4f49e5030abf9ae2731a3695 Mon Sep 17 00:00:00 2001
From: Joel Niklaus
Date: Sat, 11 Jan 2025 03:36:46 -0800
Subject: [PATCH] Fixed gnarly bug with details loading to prevent loading too many examples.

---
 src/lighteval/pipeline.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py
index f023cd35..3b476453 100644
--- a/src/lighteval/pipeline.py
+++ b/src/lighteval/pipeline.py
@@ -297,27 +297,27 @@ def _load_responses_from_details(self):
         for task_name, dataset in tqdm(details_datasets.items(), desc="Loading responses from details for tasks"):
             task: LightevalTask = self._get_task(task_name)
-            num_samples = len(dataset["predictions"])
+            num_samples = len(set(dataset["specifics"]))
             max_samples = self.pipeline_parameters.max_samples if self.pipeline_parameters.max_samples else num_samples
             if num_samples > max_samples:
                 logger.warning(
                     f"Skipping {num_samples - max_samples} samples for {task_name} when loading responses from details because max_samples is set to {max_samples}"
                 )
                 num_samples = self.pipeline_parameters.max_samples
+
+            predictions = [ast.literal_eval(p) for p in dataset["predictions"][:num_samples]]
+            input_tokens = [ast.literal_eval(t) for t in dataset["input_tokens"][:num_samples]]
+            cont_tokens = [ast.literal_eval(t) for t in dataset["cont_tokens"][:num_samples]]
+            truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]]
+            padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]]
+
+            if model_response_type == GenerativeResponse:
+                logits = [ast.literal_eval(p) for p in dataset["pred_logits"][:num_samples]]
+
             for metric_category, has_metric_category in task.has_metric_category.items():
                 if not has_metric_category:
                     continue
-                # Pre-evaluate all the literal strings once
-                predictions = [ast.literal_eval(p) for p in dataset["predictions"][:num_samples]]
-                input_tokens = [ast.literal_eval(t) for t in dataset["input_tokens"][:num_samples]]
-                cont_tokens = [ast.literal_eval(t) for t in dataset["cont_tokens"][:num_samples]]
-                truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]]
-                padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]]
-
-                if model_response_type == GenerativeResponse:
-                    logits = [ast.literal_eval(p) for p in dataset["pred_logits"][:num_samples]]
-
                 for idx in range(num_samples):
                     kwargs = {
                         "result": predictions[idx],
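
Below is a minimal, self-contained sketch (not taken from the patch itself) of the pattern the new hunk follows, using a hypothetical in-memory dataset dict in place of the loaded details split: the sample count is derived from the unique "specifics" entries, capped by max_samples, and the string-encoded columns are parsed once with ast.literal_eval before any per-metric loop.

import ast

# Hypothetical stand-in for a loaded details split: each column holds
# stringified Python literals, and rows may repeat per example.
dataset = {
    "specifics": ['{"id": 0}', '{"id": 1}', '{"id": 0}', '{"id": 1}'],
    "predictions": ["['answer A']", "['answer B']", "['answer A']", "['answer B']"],
    "input_tokens": ["[[1, 2, 3]]", "[[4, 5]]", "[[1, 2, 3]]", "[[4, 5]]"],
}
max_samples = None  # stand-in for pipeline_parameters.max_samples

# Count unique examples instead of raw rows, then cap at max_samples.
num_samples = len(set(dataset["specifics"]))
if max_samples is not None and num_samples > max_samples:
    print(f"Skipping {num_samples - max_samples} samples because max_samples is set to {max_samples}")
    num_samples = max_samples

# Parse the string-encoded columns once, outside any per-metric loop.
predictions = [ast.literal_eval(p) for p in dataset["predictions"][:num_samples]]
input_tokens = [ast.literal_eval(t) for t in dataset["input_tokens"][:num_samples]]

for idx in range(num_samples):
    print(predictions[idx], input_tokens[idx])

Hoisting the parsing out of the metric-category loop, as the hunk above does, avoids re-running ast.literal_eval over the same columns once per metric category.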