diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index a114a5d4..0e6282ef 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -24,6 +24,7 @@ import collections import os import random +import re import shutil from contextlib import nullcontext from dataclasses import dataclass, field @@ -288,6 +289,79 @@ def _unpack(self, x): else: raise ValueError(f"Unknown type {type(x)} of prediction {x}") + def _parse_tensor_string(self, tensor_string): + """ + Convert a string containing PyTorch-like `tensor([...], device='cuda:0', ...)` + into a Python list (or nested lists) of numbers. + + Example: + "[tensor([1, 2, 3], device='cuda:0'), tensor([[4,5],[6,7]], dtype=torch.int64)]" + -> [[1, 2, 3], [[4, 5], [6, 7]]] + """ + + # Regex explanation: + # - tensor\(\s*: Matches "tensor(" (possibly with spaces after), literally. + # - (.*?): Captures everything lazily into group(1), until the first subsequent part matches. + # We rely on the next pattern to anchor the end of this capture. + # - \): The literal closing parenthesis, but we anchor the match by ignoring + # further arguments (device=..., dtype=..., etc.) inside. + # + # The tricky part: a tensor might look like + # tensor([ ... ], device='cuda:0', dtype=torch.int64) + # so the bracket portion is `[ ... ]`, but it can have newlines, etc. + # + # We'll handle that by first capturing the entire content up to the final parenthesis, + # then parse out the bracket portion. This can be done in a function-based re.sub. + + pattern = re.compile( + r"tensor\s*\(\s*(.*?)\s*\)", # capture everything inside tensor(...) + flags=re.DOTALL, + ) + + def tensor_replacer(match): + inside = match.group(1).strip() + # `inside` might look like: [1, 2, 3], device='cuda:0' + # or: + # [ + # 1, 2, 3, + # 4, 5, ... + # ], device='cuda:0', dtype=torch.int64 + # + # 1) Extract the bracketed array portion: the first [ ... ] block + # which might be multi-line. We'll use another regex for that. + + # We look for the bracketed portion from the first '[' to its matching ']'. + # Because the inside can be multi-line, we use DOTALL. But we still need + # to ensure we don't accidentally go beyond the matching bracket. + # + # A robust approach to properly match brackets can be done with a small parser, + # but for typical well-formed strings, a lazy match of the form + # r"\[.*?\]" DOTALL often suffices, assuming no nested brackets inside. + + bracket_pattern = re.compile(r"\[.*?\]", re.DOTALL) + bracket_match = bracket_pattern.search(inside) + if not bracket_match: + # If we fail to find a bracket, just return something safe. + # This means the string didn't match the expected format. + return "[]" + + # The bracketed portion (e.g. "[1, 2, 3\n, 4]"). + bracketed_content = bracket_match.group(0) + + # Return just the bracketed content, + # effectively replacing "tensor(...)" with "[...]". + return bracketed_content + + # Step 1: Replace every `tensor(...)` occurrence with just the bracketed list. + processed = pattern.sub(tensor_replacer, tensor_string) + + # Step 2: Now we can safely parse the result with literal_eval. + # If there's still something weird, it may throw ValueError. + try: + return ast.literal_eval(processed) + except Exception as e: + raise ValueError(f"Failed to parse after preprocessing. " f"Processed string:\n{processed}\n\nError: {e}") + def _load_responses_from_details(self): logger.info("--- LOADING RESPONSES FROM DETAILS ---") sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list) @@ -314,8 +388,8 @@ def _load_responses_from_details(self): num_samples = self.pipeline_parameters.max_samples predictions = [self._unpack(ast.literal_eval(p)) for p in dataset["predictions"][:num_samples]] - input_tokens = [ast.literal_eval(t) for t in dataset["input_tokens"][:num_samples]] - cont_tokens = [ast.literal_eval(t) for t in dataset["cont_tokens"][:num_samples]] + input_tokens = [self._parse_tensor_string(t) for t in dataset["input_tokens"][:num_samples]] + cont_tokens = [self._parse_tensor_string(t) for t in dataset["cont_tokens"][:num_samples]] truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]] padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]]