Skip to content

Commit

Permalink
feat: support selective eval task
Browse files: browse the repository at this point in the history
  • Loading branch information
terryyz committed Jan 14, 2025
1 parent 8cdcdfe commit 342aed8
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions bigcodebench/evaluate.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -119,6 +119,7 @@ def evaluate(
samples: Optional[str] = None,
no_execute: bool = False,
local_execute: bool = False,
selective_evaluate: str = "",
remote_execute_api: str = "https://bigcode-bigcodebench-evaluator.hf.space/",
pass_k: str = "1,5,10",
save_pass_rate: bool = True,
Expand Down Expand Up @@ -168,6 +169,7 @@ def evaluate(
calibrated=calibrated,
check_gt_only=check_gt_only,
no_gt=no_gt,
selective_evaluate=selective_evaluate,
api_name="/predict"
)
break
Expand All @@ -193,6 +195,14 @@ def evaluate(
samples = "__dummy__.jsonl"

problems = get_bigcodebench(subset=subset)

# Restrict evaluation to the comma-separated task IDs given in `selective_evaluate`, if any
if selective_evaluate:
selected_ids = set(selective_evaluate.split(","))
problems = {k: v for k, v in problems.items() if k in selected_ids}
if not problems:
raise ValueError(f"None of the provided task IDs {selected_ids} were found in the dataset")

dataset_hash = get_bigcodebench_hash(subset=subset)

if not no_gt:
Expand Down Expand Up @@ -240,10 +250,9 @@ def evaluate(
task_id = sample["task_id"]

if task_id not in problems:
warn(
f"Task {task_id} is found in the samples but not found in the dataset"
)
# Skip if task is not in problems (either not in dataset or filtered out by selective_evaluate)
continue

solution = (
sample["solution"]
if "solution" in sample
Expand All @@ -267,8 +276,10 @@ def evaluate(
completion_id[task_id] += 1
n_samples += 1

# Consistency checks on the loaded samples; `problems` may already be narrowed by `selective_evaluate`
assert n_samples == len(remainings), "Missing problems in unfinished"
assert len(completion_id) == len(problems), "Missing problems in samples"
# Only check against problems that weren't filtered out
assert len(completion_id) == len(problems), f"Missing problems in samples. Expected {len(problems)} problems, got {len(completion_id)}"

def stucking_checker():
while remainings:
Expand Down

0 comments on commit 342aed8

Please sign in to comment.