
Commit

Math judge for openrlhf (#365)
Signed-off-by: Igor Gitman <[email protected]>
Kipok authored Feb 8, 2025
1 parent aa75742 commit aafacf3
Showing 6 changed files with 23 additions and 397 deletions.
5 changes: 0 additions & 5 deletions nemo_skills/pipeline/openrlhf/ppo.py
@@ -15,7 +15,6 @@
 import logging
 import os
 from dataclasses import dataclass
-from datetime import datetime
 from typing import List, Optional
 
 import nemo_run as run
@@ -89,7 +88,6 @@ def format_train_args(self):
f" --max_ckpt_mem 10000000000 "
f" --save_path {os.path.join(self.output_dir, 'checkpoints')} "
f" --save_steps -1 "
# f" --max_samples 500000 "
f" --max_epochs 1 "
f" --max_time_per_run {self.timeout} "
)
@@ -300,9 +298,6 @@ def ppo_openrlhf(
     if prompt_data.startswith("/"):  # could ask to download from HF
         check_if_mounted(cluster_config, prompt_data)
 
-    if cluster_config["executor"] == "local":
-        assert "HF_HOME" in os.environ, "HF_HOME must be set when running locally"
-
     # Check if custom PPOOpenRLHFTask is provided via ctx.obj['ppo_task'], use that if available
     if hasattr(ctx, 'obj') and ctx.obj is not None and isinstance(ctx.obj, dict) and 'ppo_task' in ctx.obj:
         ppo_task = ctx.obj['ppo_task']  # type: type(PPOOpenRLHFTask)
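The check retained above lets callers swap in their own task class through the command context. A minimal sketch of such an override, assuming PPOOpenRLHFTask is importable from the module shown in this diff (CustomPPOTask and the extra flag it adds are hypothetical):

from nemo_skills.pipeline.openrlhf.ppo import PPOOpenRLHFTask


class CustomPPOTask(PPOOpenRLHFTask):
    """Hypothetical subclass that tweaks the generated training flags."""

    def format_train_args(self):
        # e.g. re-introduce a sample cap on top of the default arguments
        return super().format_train_args() + " --max_samples 500000 "


# Whatever invokes the command must seed the context object so the
# ctx.obj['ppo_task'] lookup above resolves to the custom class; with a
# Click/Typer app that is typically:
#     app(["ppo", ...], obj={"ppo_task": CustomPPOTask})
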
2 changes: 1 addition & 1 deletion nemo_skills/pipeline/openrlhf/sft.py
@@ -141,7 +141,7 @@ def get_cmd(cluster_config, params: TrainingParams):
f"cd /nemo_run/code && "
f"echo 'Starting SFT' && "
f'echo "Torch run cmd: {torchrun_cmd}" && '
f"{torchrun_cmd} -m nemo_skills.training.openrlhf.sft_script "
f"{torchrun_cmd} -m openrlhf.cli.train_sft "
f" {format_train_args(cluster_config, params)} "
f" {format_data_args(cluster_config, params)} "
f" {get_common_arg_overrides(cluster_config, params)} "
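With the in-repo sft_script deleted (see below), SFT now launches OpenRLHF's stock CLI module. Roughly, the command string that get_cmd assembles takes the following shape; the flag values here are illustrative stand-ins for whatever format_train_args and format_data_args actually emit:

# a sketch of the assembled command, not the pipeline's exact output
torchrun_cmd = "torchrun --nproc_per_node=8 --nnodes=1"
cmd = (
    f"cd /nemo_run/code && "
    f"echo 'Starting SFT' && "
    f"{torchrun_cmd} -m openrlhf.cli.train_sft "
    f" --pretrain /hf_models/base-model "    # hypothetical flag values
    f" --dataset /data/sft-data.jsonl "
    f" --save_path /results/checkpoints "
)
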
21 changes: 21 additions & 0 deletions nemo_skills/training/openrlhf/llm_reward.py
@@ -0,0 +1,21 @@
+import torch
+from nemo_skills.inference.server.model import get_model
+from nemo_skills.prompt.utils import get_prompt
+from nemo_skills.code_execution.math_grader import extract_answer
+from nemo_skills.evaluation.metrics.utils import is_correct_judgement
+
+
+def reward_func(queries: list[str], prompts: list[str], prompt_metadata: list[dict]):
+    expected_answers = [data["expected_answer"] for data in prompt_metadata]
+    predicted_answers = [extract_answer(query) for query in queries]
+    problems = [data["problem"] for data in prompt_metadata]
+    llm = get_model(server_type="trtllm")
+    prompt = get_prompt('judge/math', 'qwen-instruct')
+    prompts = [
+        prompt.fill({'problem': problem, 'expected_answer': expected_answer, 'predicted_answer': predicted_answer})
+        for problem, expected_answer, predicted_answer in zip(problems, expected_answers, predicted_answers)
+    ]
+    outputs = llm.generate(prompts=prompts)
+    is_correct_array = [is_correct_judgement(output["generation"]) for output in outputs]
+
+    return torch.tensor(is_correct_array, dtype=torch.float32)
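
The new reward_func scores each rollout 0 or 1 by asking an LLM judge whether the extracted answer matches the expected one. Note that the prompts argument is immediately shadowed by the judge prompts built inside the function, so only queries and prompt_metadata are actually consumed. A toy direct invocation, assuming a TensorRT-LLM judge server is already running for get_model to connect to (all sample data is made up):

from nemo_skills.training.openrlhf.llm_reward import reward_func

queries = ["Adding 2 and 2 gives 4. The final answer is \\boxed{4}."]
prompts = ["What is 2 + 2?"]  # effectively unused, see the note above
metadata = [{"problem": "What is 2 + 2?", "expected_answer": "4"}]

rewards = reward_func(queries, prompts, metadata)
print(rewards)  # tensor([1.]) when the judge rules the answer correct

During PPO training, OpenRLHF would call this function itself, presumably through its custom reward-function hook rather than a direct import like the one above.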
329 changes: 0 additions & 329 deletions nemo_skills/training/openrlhf/sft_script.py

This file was deleted.

