Pass@k (#519)
* init

* correct typing

* added defaults

* small fix
clefourrier authored Feb 6, 2025
1 parent 15bdbb8 commit 441d7a4
Showing 2 changed files with 141 additions and 1 deletion.
25 changes: 25 additions & 0 deletions src/lighteval/metrics/metrics.py
@@ -48,6 +48,7 @@
Faithfulness,
LoglikelihoodAcc,
MajAtK,
PassAtK,
Recall,
StringDistance,
acc_golds_likelihood,
@@ -369,6 +370,30 @@ class Metrics(Enum):
corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute,
higher_is_better=True,
)
pass_at_1 = SampleLevelMetric(
metric_name="pass@1:32_samples",
sample_level_fn=PassAtK(k=1, n=32, strip_strings=True).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
pass_at_10 = SampleLevelMetric(
metric_name="pass@10:32_samples",
sample_level_fn=PassAtK(k=10, n=32, strip_strings=True).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
pass_at_100 = SampleLevelMetric(
metric_name="pass@100:32_samples",
sample_level_fn=PassAtK(k=100, n=32, strip_strings=True).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
perfect_exact_match = SampleLevelMetric(
metric_name="perfect_em",
sample_level_fn=ExactMatches().compute,
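Each metric name above encodes the success threshold k and the number of generated samples, so "pass@1:32_samples" is pass@1 estimated from 32 generations per item. A minimal standalone sketch of the estimator these entries rely on (the same closed form that PassAtK.pass_at_k implements below), with made-up illustrative numbers:

import numpy as np

def pass_at_k_estimate(n: int, c: int, k: int) -> float:
    # Unbiased pass@k from n samples containing c correct ones (Chen et al., 2021).
    if n - c < k:
        return 1.0
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

# Illustrative values only: 10 correct generations out of 32 samples.
# pass_at_k_estimate(32, 10, 1)   -> 0.3125
# pass_at_k_estimate(32, 10, 10)  -> ~0.99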
117 changes: 116 additions & 1 deletion src/lighteval/metrics/metrics_sample.py
@@ -26,7 +26,7 @@

import logging
import os
from typing import Callable, Literal
from typing import Callable, Literal, Union

import nltk
import numpy as np
@@ -1055,3 +1055,118 @@ def compute_score(self, pred: str, gold: str) -> int:
if self.type_exact_match == "suffix":
return 1 if pred.endswith(gold) else 0
return 1 if gold == pred else 0


class PassAtK:
def __init__(
self,
k: int,
n: int = None,
normalize_gold: Callable = None,
normalize_pred: Callable = None,
strip_strings: bool = False,
sample_scoring_function: Union[Callable[[str, str], float], str] = None,
):
"""Computing pass at k
Args:
k (int): Threshold for the number of successful attempts.
n (int): Number of samples to generate
normalize_gold (callable, optional): Function to use to normalize the reference strings.
Defaults to None if no normalization is applied.
normalize_pred (callable, optional): Function to use to normalize the predicted strings.
Defaults to None if no normalization is applied.
strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
sample_scoring_function (callable or str, optional): Function to use to score each sample.
Either pass the full function (should take a string prediction and a string gold, and return a score between 0 and 1)
a string (any of `prefix`, `suffix` or `full`) to define the type of exact match that you want, or nothing to defaults to "full".
`prefix` checks if the prediction starts with the gold,
`suffix` if the prediction ends with the gold,
`full` if the prediction and gold are equal
"""
self.k = k
self.n = n
self.normalize_gold = normalize_gold
self.normalize_pred = normalize_pred
self.strip_strings = strip_strings

# Manage the logic of per-prediction sample scoring
if callable(sample_scoring_function):
self.score_sample = sample_scoring_function
self.type_exact_match = None
else:
if isinstance(sample_scoring_function, str):
if sample_scoring_function not in ["prefix", "suffix", "full"]:
raise ValueError(
f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
)
self.type_exact_match = sample_scoring_function
else:
self.type_exact_match = "full"
self.score_sample = self.default_sample_scoring

def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
"""Computes the metric over a list of golds and predictions for one single item with possibly many samples.
It applies normalization (if needed) to the model predictions and the gold, computes the per-prediction scores,
then aggregates them over the samples using the pass@k estimator.
Args:
golds (list[str]): Reference targets (exactly one is expected)
predictions (list[str]): Predicted strings (up to n samples)
Returns:
float: Aggregated pass@k score for the current item.
"""
if len(golds) > 1:
raise Exception("Cannot compute pass@k with several golds")

if self.n is None:
self.n = len(predictions)
logger.warning("n undefined in the pass@k. We assume it's the same as the sample's number of predictions.")
elif len(predictions) < self.n:
logger.warning(f"Number of predictions is less than {self.n} for pass@k.")

gold = self.get_processed_gold(golds[0])

all_scores = []
for pred in predictions[: self.n]:
cur_pred = self.get_processed_pred(pred=pred)
all_scores.append(self.score_sample(cur_pred, gold))

return self.pass_at_k(all_scores)

def get_processed_gold(self, gold: str) -> str:
if self.strip_strings:
gold = gold.strip()

if self.normalize_gold:
gold = self.normalize_gold(gold)

return gold

def get_processed_pred(self, pred: str) -> str:
if not pred:
return ""

if self.strip_strings:
pred = pred.strip()

if self.normalize_pred:
pred = self.normalize_pred(pred)

return pred

def default_sample_scoring(self, pred: str, gold: str) -> int:
if self.type_exact_match == "prefix":
return 1 if pred.startswith(gold) else 0
if self.type_exact_match == "suffix":
return 1 if pred.endswith(gold) else 0
return 1 if gold == pred else 0

def pass_at_k(self, all_scores: list[int]) -> float:
"""Algo from https://arxiv.org/pdf/2107.03374"""
c: int = all_scores.count(1)
if self.n - c < self.k:
return 1.0

return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))
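A minimal usage sketch of the class above, with made-up gold/prediction strings (the default "full" exact-match scoring is used here; a custom sample_scoring_function, e.g. a code-execution check returning 0 or 1, could be passed instead):

pass_at_2 = PassAtK(k=2, n=4, strip_strings=True)
score = pass_at_2.compute(
    golds=["42"],
    predictions=["41", "42 ", "forty-two", "42"],
)
# Two of the four samples match the gold after stripping, so the estimator gives
# 1 - C(2, 2) / C(4, 2) = 1 - 1/6 ≈ 0.833.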
