Async generation to save results as soon as they are done (#327)
Signed-off-by: Igor Gitman <[email protected]>
Kipok authored Jan 24, 2025
1 parent 549d2a6 commit 8611a8f
Showing 6 changed files with 366 additions and 103 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/gpu_tests.yml
@@ -38,7 +38,7 @@ jobs:
           python -m pip install --upgrade pip
           pip install -e .
           pip install -r requirements/common-tests.txt
-          python -m nemo_skills.dataset.prepare
+          python -m nemo_skills.dataset.prepare gsm8k human-eval mbpp algebra222 mmlu ifeval
       - name: Run GPU tests
         timeout-minutes: 120
         env:
@@ -68,12 +68,14 @@ jobs:
         with:
           python-version: "3.10"
       - name: Install dependencies
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
         run: |
           cd ${{ github.run_id }}
           python -m pip install --upgrade pip
           pip install -e .
           pip install -r requirements/common-tests.txt
-          python -m nemo_skills.dataset.prepare
+          python -m nemo_skills.dataset.prepare gsm8k human-eval mbpp algebra222 mmlu ifeval
       - name: Run GPU tests
         timeout-minutes: 120
         env:
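The workflow change narrows dataset preparation to just the benchmarks the GPU tests use and passes `HF_TOKEN` for gated downloads. A sketch of the equivalent local invocation, assuming the package is installed (`pip install -e .`) and `HF_TOKEN` is exported in the environment:

```python
import subprocess

# same module and dataset list as the updated workflow step
subprocess.run(
    [
        "python", "-m", "nemo_skills.dataset.prepare",
        "gsm8k", "human-eval", "mbpp", "algebra222", "mmlu", "ifeval",
    ],
    check=True,
)
```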
209 changes: 145 additions & 64 deletions nemo_skills/inference/generate.py
@@ -16,6 +16,7 @@
 import json
 import logging
 import sys
+import time
 from copy import deepcopy
 from dataclasses import asdict, field
 from pathlib import Path
@@ -130,39 +131,8 @@ def combine_stop_phrases(prompt_phrases, extra_phrases):
     return prompt_phrases + extra_phrases


-@hydra.main(version_base=None, config_name='base_generation_config')
-def generate(cfg: GenerateSolutionsConfig):
-    cfg = GenerateSolutionsConfig(_init_nested=True, **cfg)
-
-    LOG.info("Config used: %s", cfg)
-
-    if cfg.prompt_template is None and cfg.server["server_type"] != "openai":
-        # TODO: handle this explicitly in model.py clients
-        # switching to OpenAI client always if prompt template is not provided
-        with open_dict(cfg.server):
-            cfg.server["server_type"] = "openai"
-            cfg.server["model"] = "model"
-        if cfg.code_execution:
-            raise ValueError("Code execution is not supported for OpenAI server")
-
-    if cfg.code_execution:
-        sandbox = get_sandbox(**cfg.sandbox) if cfg.sandbox is not None else None
-        llm = get_code_execution_model(**cfg.server, sandbox=sandbox)
-    else:
-        llm = get_model(**cfg.server)
-
-    # making sure output dir exists
-    Path(cfg.output_file).absolute().parent.mkdir(parents=True, exist_ok=True)
-
-    # we currently assume the dataset is small enough to be loaded into memory
-    data = []
-    with open(cfg.input_file, "rt", encoding="utf-8") as fin:
-        for line in fin:
-            data.append(json.loads(line))
-
-    # skipping based on the offset first
-    data = data[cfg.offset :]
-
+def sync_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params):
+    """Synchronous version that only dumps data when full batch is finished."""
     starting_idx = 0
     if cfg.skip_filled:
         try:
@@ -173,13 +143,6 @@ def generate(cfg: GenerateSolutionsConfig):

     # additionally, skipping whatever is pre-filled, assuming offset didn't change
     data = data[starting_idx:]
-    if cfg.prompt_config is None:
-        # fetching from the default for corresponding dataset
-        dataset_module = importlib.import_module(f"nemo_skills.dataset.{cfg.dataset}")
-        cfg.prompt_config = dataset_module.PROMPT_CONFIG
-
-    prompt = get_prompt(cfg.prompt_config, cfg.prompt_template, examples_type=cfg.examples_type)
-    LOG.info("Prompt used: %s", prompt)

     # need to account for anything that's prefilled
     if 0 <= cfg.max_samples <= starting_idx:
@@ -194,37 +157,15 @@ def generate(cfg: GenerateSolutionsConfig):
     if len(data) == 0:  # we might not have any examples if skip_filled=True
         return

-    if cfg.multi_turn_key is None:
-        LOG.info("Example prompt:\nData dictionary: %s\nPrompt: %s", data[0], prompt.fill(data[0]))
-    else:
-        first_sample = deepcopy(data[0])
-        first_sample[cfg.multi_turn_key] = first_sample[cfg.multi_turn_key][:1]
-        LOG.info(
-            "Example prompt (first turn only):\nData dictionary: %s\nPrompt: %s",
-            first_sample,
-            prompt.fill(first_sample, multi_turn_key=cfg.multi_turn_key),
-        )
-
-    if cfg.dry_run:
-        return
-
-    # if using code execution, we need some extra parameters for generate call
-    if cfg.code_execution:
-        extra_generate_params = prompt.get_code_execution_args()
-    else:
-        extra_generate_params = {}
-
-    extra_stop_phrases = OmegaConf.to_container(cfg.extra_stop_phrases, resolve=True)
+    data = data[: cfg.max_samples]

     # setting buffering=1 to force to dump the output after every line, so that we can see intermediate generations
     with open(cfg.output_file, "at" if cfg.skip_filled else "wt", encoding="utf-8", buffering=1) as fout:
         data_points = []
         for idx, data_point in tqdm(enumerate(data), initial=starting_idx, total=len(data) + starting_idx):
-            if idx >= cfg.max_samples:
-                break
             data_points.append(data_point)

-            if len(data_points) == cfg.batch_size or idx == cfg.max_samples - 1:
+            if len(data_points) == cfg.batch_size or idx == len(data) - 1:
                 if cfg.multi_turn_key is None:
                     outputs = llm.generate(
                         prompts=[prompt.fill(dp) for dp in data_points],
@@ -284,6 +225,146 @@
                     data_points = []


+def async_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params):
+    """Asynchronous version that sends all the data to the server and then dumps outputs as soon as they are finished."""
+    if cfg.max_samples > 0:
+        data = data[: cfg.max_samples]
+    original_positions = [idx for idx in range(len(data))]
+    if cfg.skip_filled:
+        try:
+            filled_positions = set()
+            with open(cfg.output_file + '-async', "rt", encoding="utf-8") as fin:
+                for line in fin:
+                    filled_positions.add(int(json.loads(line)[0]))
+            data = [dp for idx, dp in enumerate(data) if idx not in filled_positions]
+            original_positions = [idx for idx in original_positions if idx not in filled_positions]
+        except FileNotFoundError:
+            LOG.warning(f"File `{cfg.output_file}-async` not found, starting from scratch")
+
+    # if output_file (not async) exists, it means we are fully done or need to remove it and regenerate
+    if Path(cfg.output_file).exists():
+        if not cfg.skip_filled:
+            Path(cfg.output_file).unlink()
+        else:
+            return
+
+    if len(data) == 0:  # we might not have any examples if skip_filled=True
+        return
+
+    # submitting all data at once
+    generation_ids = llm.generate_async(
+        prompts=[prompt.fill(dp) for dp in data],
+        stop_phrases=combine_stop_phrases(prompt.stop_phrases, extra_stop_phrases),
+        **asdict(cfg.inference),
+        **extra_generate_params,
+    )
+
+    # setting buffering=1 to force to dump the output after every line, so that we can see intermediate generations
+    with open(cfg.output_file + '-async', "at" if cfg.skip_filled else "wt", encoding="utf-8", buffering=1) as fout:
+        while True:
+            remaining_ids = [generation_id for generation_id in generation_ids if generation_id is not None]
+            if len(remaining_ids) == 0:
+                break
+            remaining_positions = [
+                idx for idx, generation_id in enumerate(generation_ids) if generation_id is not None
+            ]
+            generations = llm.get_generations(remaining_ids)
+            for gen_pos, gen_dict in zip(remaining_positions, generations):
+                if gen_dict['generation'] is not None:  # will be None until done
+                    generation_ids[gen_pos] = None
+                    gen_dict[cfg.generation_key] = gen_dict.pop("generation")
+                    for key in gen_dict:
+                        data[gen_pos].pop(key, None)
+                    gen_dict.update(data[gen_pos])
+                    fout.write(json.dumps([original_positions[gen_pos], gen_dict]) + "\n")
+
+            time.sleep(1)
+
+    # after we are done, need to restore the order and resave without position ids
+    with open(cfg.output_file + '-async', "rt", encoding="utf-8") as fin:
+        generations = [json.loads(line) for line in fin]
+
+    ordered_generations = [None] * len(generations)
+    for gen_dict in generations:
+        ordered_generations[gen_dict[0]] = gen_dict[1]
+
+    with open(cfg.output_file, "wt", encoding="utf-8") as fout:
+        for gen_dict in ordered_generations:
+            fout.write(json.dumps(gen_dict) + "\n")
+
+    Path(cfg.output_file + '-async').unlink()
+
+
+@hydra.main(version_base=None, config_name='base_generation_config')
+def generate(cfg: GenerateSolutionsConfig):
+    cfg = GenerateSolutionsConfig(_init_nested=True, **cfg)
+
+    LOG.info("Config used: %s", cfg)
+
+    if cfg.prompt_template is None and cfg.server["server_type"] != "openai":
+        # TODO: handle this explicitly in model.py clients
+        # switching to OpenAI client always if prompt template is not provided
+        with open_dict(cfg.server):
+            cfg.server["server_type"] = "openai"
+            cfg.server["model"] = "model"
+        if cfg.code_execution:
+            raise ValueError("Code execution is not supported for OpenAI server")
+
+    if cfg.code_execution:
+        sandbox = get_sandbox(**cfg.sandbox) if cfg.sandbox is not None else None
+        llm = get_code_execution_model(**cfg.server, sandbox=sandbox)
+    else:
+        llm = get_model(**cfg.server)
+
+    # making sure output dir exists
+    Path(cfg.output_file).absolute().parent.mkdir(parents=True, exist_ok=True)
+
+    # we currently assume the dataset is small enough to be loaded into memory
+    data = []
+    with open(cfg.input_file, "rt", encoding="utf-8") as fin:
+        for line in fin:
+            data.append(json.loads(line))
+
+    # skipping based on the offset first
+    data = data[cfg.offset :]
+
+    if cfg.prompt_config is None:
+        # fetching from the default for corresponding dataset
+        dataset_module = importlib.import_module(f"nemo_skills.dataset.{cfg.dataset}")
+        cfg.prompt_config = dataset_module.PROMPT_CONFIG
+
+    prompt = get_prompt(cfg.prompt_config, cfg.prompt_template, examples_type=cfg.examples_type)
+    LOG.info("Prompt used: %s", prompt)
+
+    if cfg.multi_turn_key is None:
+        LOG.info("Example prompt:\nData dictionary: %s\nPrompt: %s", data[0], prompt.fill(data[0]))
+    else:
+        first_sample = deepcopy(data[0])
+        first_sample[cfg.multi_turn_key] = first_sample[cfg.multi_turn_key][:1]
+        LOG.info(
+            "Example prompt (first turn only):\nData dictionary: %s\nPrompt: %s",
+            first_sample,
+            prompt.fill(first_sample, multi_turn_key=cfg.multi_turn_key),
+        )
+
+    if cfg.dry_run:
+        return
+
+    # if using code execution, we need some extra parameters for generate call
+    if cfg.code_execution:
+        extra_generate_params = prompt.get_code_execution_args()
+    else:
+        extra_generate_params = {}
+
+    extra_stop_phrases = OmegaConf.to_container(cfg.extra_stop_phrases, resolve=True)
+
+    # for nemo or multi-turn generation, we don't support async yet
+    if cfg.server["server_type"] == "nemo" or cfg.multi_turn_key is not None:
+        sync_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params)
+    else:
+        async_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params)
+
+
 HELP_MESSAGE = get_help_message(
     GenerateSolutionsConfig,
     server_params=server_params(),
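The new `async_loop` boils down to a poll-until-done pattern over `llm.generate_async` / `llm.get_generations`. Below is a minimal self-contained sketch of that contract; the `ToyAsyncClient` class is invented for illustration and only mimics the behavior the loop relies on: `generate_async` returns one id per prompt, and each polled result's `generation` stays `None` until that request finishes.

```python
import time


class ToyAsyncClient:
    """Invented stand-in for the real client; finishes each request after two polls."""

    def __init__(self, prompts):
        self._prompts = prompts
        self._polls_left = {idx: 2 for idx in range(len(prompts))}

    def generate_async(self, prompts):
        # the real client returns one generation id per submitted prompt
        return list(self._polls_left)

    def get_generations(self, ids):
        outputs = []
        for idx in ids:
            self._polls_left[idx] -= 1
            finished = self._polls_left[idx] <= 0
            # 'generation' is None until the request is done: the exact
            # condition async_loop checks before saving a line
            outputs.append({"generation": f"output for {self._prompts[idx]!r}" if finished else None})
        return outputs


prompts = ["prompt A", "prompt B"]
client = ToyAsyncClient(prompts)
generation_ids = client.generate_async(prompts)

while any(gen_id is not None for gen_id in generation_ids):
    remaining_positions = [pos for pos, gen_id in enumerate(generation_ids) if gen_id is not None]
    remaining_ids = [gen_id for gen_id in generation_ids if gen_id is not None]
    for pos, gen_dict in zip(remaining_positions, client.get_generations(remaining_ids)):
        if gen_dict["generation"] is not None:
            generation_ids[pos] = None  # done: stop polling this id
            print(pos, gen_dict["generation"])
    time.sleep(0.1)
```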
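The crash-resume bookkeeping lives in the intermediate `-async` file, where each line is a `[original_position, generation_dict]` JSON array: positions already present are skipped on restart, and once everything finishes the file is rewritten to the real output file in input order, without the position ids. A minimal sketch of that layout, with made-up sample data (`output.jsonl` stands in for `cfg.output_file`):

```python
import json
from pathlib import Path

output_file = "output.jsonl"  # hypothetical path; the real code uses cfg.output_file
async_file = output_file + "-async"

# results arrive out of order; each line records where the sample came from
finished = [
    [2, {"generation": "answer for sample 2"}],
    [0, {"generation": "answer for sample 0"}],
    [1, {"generation": "answer for sample 1"}],
]
with open(async_file, "wt", encoding="utf-8") as fout:
    for record in finished:
        fout.write(json.dumps(record) + "\n")

# on restart with skip_filled=True, already-saved positions are skipped
with open(async_file, "rt", encoding="utf-8") as fin:
    filled_positions = {int(json.loads(line)[0]) for line in fin}
assert filled_positions == {0, 1, 2}

# once everything is done, restore input order and drop the position ids
with open(async_file, "rt", encoding="utf-8") as fin:
    records = [json.loads(line) for line in fin]
ordered = [None] * len(records)
for position, gen_dict in records:
    ordered[position] = gen_dict
with open(output_file, "wt", encoding="utf-8") as fout:
    for gen_dict in ordered:
        fout.write(json.dumps(gen_dict) + "\n")
Path(async_file).unlink()  # the intermediate -async file is deleted at the end
```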
(Diffs for the remaining 4 changed files did not load and are not shown.)
