[WIP] Add PRM and refactor MCTS #6119

Closed · wants to merge 22 commits
3 changes: 2 additions & 1 deletion applications/ColossalChat/coati/dataset/__init__.py
@@ -7,7 +7,7 @@
StatefulDistributedSampler,
load_tokenized_dataset,
)
from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
from .tokenization_utils import tokenize_kto, tokenize_process_reward, tokenize_prompt, tokenize_rlhf, tokenize_sft

__all__ = [
"tokenize_prompt",
@@ -23,4 +23,5 @@
"tokenize_kto",
"setup_conversation_template",
"Conversation",
"tokenize_process_reward",
]
59 changes: 34 additions & 25 deletions applications/ColossalChat/coati/dataset/conversation.py
@@ -1,43 +1,53 @@
import dataclasses
import json
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List

import torch.distributed as dist
from transformers import AutoTokenizer, PreTrainedTokenizer

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


@dataclasses.dataclass
@dataclass
class Conversation:
tokenizer: PreTrainedTokenizer
system_message: str
chat_template: str
stop_ids: List[int]
end_of_assistant: str
roles = ["user", "assistant"]
messages: List[Dict[str, str]] = field(default_factory=list)
roles: List[str] = field(default_factory=lambda: ["user", "assistant"])
step_score_signal: str = None
reward_signal: List[str] = None

@classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
"""
Setup the conversation template from config
"""
tokenizer.chat_template = config["chat_template"]
conv = cls(
tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"]
)
conv.clear()
return conv
conversation = cls(tokenizer, **config)

special_tokens = []
if conversation.step_score_signal is not None:
special_tokens.extend(conversation.step_score_signal)

if conversation.reward_signal is not None:
special_tokens.extend(conversation.reward_signal)

if special_tokens:
conversation.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

return conversation

def clear(self):
self.messages = []

@classmethod
def get_conversation_template_keys(cls):
return ["system_message", "chat_template"]
return ["system_message", "chat_template", "end_of_assistant"]

def __str__(self):
return json.dumps(
@@ -46,35 +56,32 @@ def __str__(self):
indent=4,
)

def get_prompt(self, length: int = None, add_generation_prompt=False) -> Any:
def get_prompt(self, num_messages: int = None, add_generation_prompt=False) -> Any:
"""
Retrieves the prompt for the conversation.
Args:
length (int, optional): The number of messages to include in the prompt. Defaults to None.
num_messages (int, optional): The number of messages to include in the prompt. Defaults to None.
add_generation_prompt (bool, optional): Whether to add the assistant line start token in generation (for generation only). Defaults to False.
Returns:
str: The prompt string.
"""

if length is None:
length = len(self.messages)
if num_messages is None:
num_messages = len(self.messages)

assert length <= len(self.messages)
assert num_messages <= len(self.messages)
if self.system_message is not None:
messages = [{"role": "system", "content": self.system_message}] + self.messages[:length]
messages = [{"role": "system", "content": self.system_message}] + self.messages[:num_messages]
else:
messages = self.messages[:length]
messages = self.messages[:num_messages]
prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=add_generation_prompt
)
return prompt

def save_prompt(self):
return self.get_prompt()

def append_message(self, role: str, message: str):
"""
Append a message to the conversation.
@@ -141,9 +148,11 @@ def setup_conversation_template(
pass
except ValueError as e:
raise ValueError(e)
if not dist.is_initialized() or dist.get_rank() == 0:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "w", encoding="utf8") as f:
logger.info(f"Successfully generated a conversation template config, saved to {save_path}.")
json.dump(chat_template_config, f, indent=4, ensure_ascii=False)

os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "w", encoding="utf8") as f:
logger.info(f"Successfully generated a conversation template config, saved to {save_path}.")
json.dump(chat_template_config, f, indent=4, ensure_ascii=False)
f.write("\n")

return Conversation.from_config(tokenizer, chat_template_config)
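For illustration, a minimal sketch (not taken from this PR) of how the reworked from_config registers the PRM signal tokens. The tokenizer name, system message, and template string are placeholders; the config keys mirror the conversation template JSON changed further below.

from transformers import AutoTokenizer
from coati.dataset.conversation import Conversation

# Placeholder tokenizer; any HF tokenizer works the same way here.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=False)
config = {
    "system_message": "You are a helpful assistant.",
    "chat_template": "{% for message in messages %}{{ message['content'] }}{% endfor %}",  # placeholder template
    "stop_ids": [2],
    "end_of_assistant": "</s>",
    "step_score_signal": "ки",
    "reward_signal": ["+", "-"],
}
conversation = Conversation.from_config(tokenizer, config)
# The reward signals (and the characters of the step-score signal) are now
# additional special tokens, so convert_tokens_to_ids yields dedicated ids
# that the PRM loss can mask on.
print(tokenizer.convert_tokens_to_ids(conversation.reward_signal))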
46 changes: 45 additions & 1 deletion applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -1,13 +1,14 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
tokenization utils for constructing dataset for ppo, dpo, sft, rm
Tokenization Utils for Constructing Dataset for RL.
"""

import warnings
from copy import deepcopy
from typing import Any, Dict, List, Union

import torch
from coati.dataset.conversation import Conversation
from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
from datasets import dataset_dict
@@ -393,3 +394,46 @@ def tokenize_kto(
"input_id_decode": decoded_full_prompt,
"completion_decode": decoded_completion,
}


def tokenize_process_reward(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
Tokenize function designed for tokenizing Math-Shepherd dataset.
The datapoint has the following format:
{
"input": problem + step-by-step solution,
"label": problem + step-by-step solution with automatic label,
"task": GSM8K or MATH
}
"""
input = data_point["input"]
label = data_point["label"]

template = deepcopy(conversation_template)
template.append_message("user", input)
template.append_message("assistant", label)
prompt = template.get_prompt(add_generation_prompt=False)
reward_signal_id = tokenizer.convert_tokens_to_ids(template.reward_signal)
tokenized = tokenizer(prompt, add_special_tokens=False)["input_ids"]

tokenized_tensor = torch.tensor(tokenized)
loss_mask = torch.isin(tokenized_tensor, torch.tensor(reward_signal_id))

label = (tokenized_tensor * loss_mask).tolist()
decoded_input = tokenizer.decode(tokenized, skip_special_tokens=False)
decoded_label = tokenizer.decode(label, skip_special_tokens=False)

return {
"input_ids": tokenized,
"labels": label,
"loss_mask": loss_mask,
"decoded_input": decoded_input,
"decoded_label": decoded_label,
}
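To make the expected data shape concrete, a hedged example of a Math-Shepherd-style datapoint (the problem and solution text are invented) run through tokenize_process_reward, reusing the tokenizer and conversation from the sketch above:

data_point = {
    "input": "Sam reads 2 books a week. How many books in 3 weeks? "
    "Step 1: 2 * 3 = 6. ки Step 2: The answer is 6. ки",
    "label": "Sam reads 2 books a week. How many books in 3 weeks? "
    "Step 1: 2 * 3 = 6. + Step 2: The answer is 6. +",
    "task": "GSM8K",
}
sample = tokenize_process_reward(data_point, tokenizer, conversation_template=conversation)
# loss_mask is True only at the "+"/"-" token positions of the assistant turn;
# labels keeps the reward-token ids at those positions and zeros everywhere else.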
6 changes: 4 additions & 2 deletions applications/ColossalChat/coati/models/__init__.py
@@ -2,7 +2,7 @@
from .critic import Critic
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
from .lora import LoraConfig, convert_to_lora_module, lora_manager
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, PRMLoss, ValueLoss
from .reward_model import RewardModel
from .utils import disable_dropout

@@ -18,9 +18,11 @@
"lora_manager",
"convert_to_lora_module",
"DpoLoss",
"KTOLoss" "generate",
"KTOLoss",
"generate",
"generate_streaming",
"disable_dropout",
"update_model_kwargs_fn",
"prepare_inputs_fn",
"PRMLoss",
]
21 changes: 21 additions & 0 deletions applications/ColossalChat/coati/models/loss.py
@@ -280,3 +280,24 @@ def forward(
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()

return losses, chosen_rewards, rejected_rewards, kl


class PRMLoss(nn.Module):
def __init__(self, reward_signal_id: Optional[list[int]] = None):
super().__init__()
self.IGNORE_INDEX = -100
self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
self.reward_signal_id = reward_signal_id

def forward(self, labels: torch.Tensor, logits: torch.Tensor):
loss_mask = torch.isin(labels, torch.tensor(self.reward_signal_id).to(labels.device))

logits = logits[loss_mask]
labels = labels[loss_mask]
logits = logits[..., self.reward_signal_id]

label_mapping = {token: i for i, token in enumerate(self.reward_signal_id)}
labels = torch.tensor([label_mapping.get(label.item(), label.item()) for label in labels], device=labels.device)
loss = self.loss(logits, labels)

return loss
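A tiny sanity check for PRMLoss (illustrative only; the vocabulary size and the ids standing in for "+" and "-" are made up):

import torch
from coati.models import PRMLoss

reward_ids = [31001, 31002]          # hypothetical ids of "+" and "-"
loss_fn = PRMLoss(reward_signal_id=reward_ids)
logits = torch.randn(1, 6, 32000)    # (batch, seq_len, vocab_size)
labels = torch.tensor([[0, 0, 31001, 0, 31002, 0]])
# Only positions 2 and 4 contribute: their logits are restricted to the two
# reward-token ids and the labels are remapped to class indices 0 and 1.
loss = loss_fn(labels, logits)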
12 changes: 5 additions & 7 deletions applications/ColossalChat/coati/reasoner/guided_search/mcts.py
@@ -1,10 +1,9 @@
"""
Implementation of MCTS + Self-refine algorithm.
Reference:
Structure is adapted from https://github.com/BrendanGraham14/mcts-llm/ with the following references:
1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
2. https://github.com/BrendanGraham14/mcts-llm/
3. https://github.com/trotsky1997/MathBlackBox/
4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
"""
@@ -121,16 +120,15 @@ def simulate(self):

return self.get_best_answer()

def get_best_answer(self):
def _iter_nodes(self):
to_visit = deque([self.root])
best_node = self.root

while to_visit:
current_node = to_visit.popleft()
if current_node.Q > best_node.Q:
best_node = current_node
yield current_node
to_visit.extend(current_node.children)

def get_best_answer(self):
best_node = max(self._iter_nodes(), key=lambda node: node.Q, default=self.root)
return best_node.answer

def self_refine(self, node: MCTSNode):
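The refactor separates traversal from selection: _iter_nodes yields every node breadth-first and get_best_answer reduces over it with max on Q. A stand-alone illustration of the same pattern with a stand-in node type (not the MCTSNode from this file):

from collections import deque
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyNode:
    answer: str
    Q: float
    children: List["ToyNode"] = field(default_factory=list)

def iter_nodes(root):
    # Breadth-first walk, mirroring _iter_nodes.
    to_visit = deque([root])
    while to_visit:
        node = to_visit.popleft()
        yield node
        to_visit.extend(node.children)

root = ToyNode("draft", 0.1, [ToyNode("first refinement", 0.7), ToyNode("second refinement", 0.9)])
best = max(iter_nodes(root), key=lambda n: n.Q, default=root)
print(best.answer)  # second refinement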
2 changes: 2 additions & 0 deletions applications/ColossalChat/coati/trainer/__init__.py
@@ -3,6 +3,7 @@
from .kto import KTOTrainer
from .orpo import ORPOTrainer
from .ppo import PPOTrainer
from .prm import ProcessRewardModelTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer

@@ -15,4 +16,5 @@
"DPOTrainer",
"ORPOTrainer",
"KTOTrainer",
"ProcessRewardModelTrainer",
]
159 changes: 159 additions & 0 deletions applications/ColossalChat/coati/trainer/prm.py
@@ -0,0 +1,159 @@
"""
Trainer for Process Reward Model.
"""

import os
import time
from typing import Any, Callable, List, Optional

import torch
import tqdm
from coati.models import PRMLoss
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device

from .base import SLTrainer
from .utils import is_rank_0, to_device


class ProcessRewardModelTrainer(SLTrainer):
"""
Trainer for Process Reward Model.
"""

def __init__(
self,
model: Any,
booster: Booster,
optimizer: Optimizer,
plugin: Plugin,
lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
loss_fn: Optional[Callable] = None,
max_epochs: int = 1,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
save_dir: str = None,
coordinator: DistCoordinator = None,
reward_signal_ids: List[int] = [],
) -> None:
super().__init__(
booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch
)
self.lr_scheduler = lr_scheduler
self.tokenizer = tokenizer
self.reward_signal_ids = reward_signal_ids
self.loss_fn = loss_fn if loss_fn is not None else PRMLoss(self.reward_signal_ids)
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
self.num_train_step = 0
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()

def _before_fit(
self,
train_dataloader: DataLoader = None,
eval_dataloader: DataLoader = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.writer = None
if log_dir is not None and is_rank_0():
from torch.utils.tensorboard import SummaryWriter

log_dir = os.path.join(log_dir, "PRM", time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)

if use_wandb:
import wandb

self.wandb_run = wandb.init(project="Coati-PRM", sync_tensorboard=True)

def _train(self, epoch: int):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
batch_size = batch["input_ids"].size(0)
logits = self.model(batch["input_ids"])["logits"]
loss = self.loss_fn(batch["labels"], logits)
self.booster.backward(loss=loss, optimizer=self.optimizer)
loss_mean = all_reduce_mean(tensor=loss)
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())

if (i + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.lr_scheduler.step()
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.num_train_step += 1
step_bar.update()

# Save checkpoint
if (
self.save_dir is not None
and self.save_interval is not None
and (self.num_train_step + 1) % self.save_interval == 0
):
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
epoch=epoch,
step=self.num_train_step + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
)

def _eval(self, epoch: int):
self.model.eval()

step_bar = tqdm.trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for batch in self.eval_dataloader:
batch = to_device(batch, self.device)
logits = self.model(batch["input_ids"])["logits"]
loss = self.loss_fn(batch["labels"], logits)
loss_mean = all_reduce_mean(tensor=loss)
self.accumulative_meter.add(
"loss", loss_mean.to(torch.float16).item(), count_update=batch["input_ids"].size(0)
)
step_bar.update()

loss_mean = self.accumulative_meter.get("loss")
msg = "Evaluation Result:\n"
for tag in ["loss"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
if self.save_dir is not None:
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()
(conversation template config JSON; file path not shown)
@@ -4,5 +4,10 @@
"stop_ids": [
2
],
"end_of_assistant": "</s>"
"end_of_assistant": "</s>",
"step_score_signal": "ки",
"reward_signal": [
"+",
"-"
]
}
(prepare_dataset.py; full path not shown)
@@ -1,35 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prepare dataset scripts
Usage:
- For SFT dataset preparation (SFT)
python prepare_dataset.py --type sft \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
- For prompt dataset preparation (PPO)
python prepare_dataset.py --type prompt \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
- For Preference dataset preparation (DPO and Reward model training)
python prepare_dataset.py --type preference \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
Prepare Dataset for RL Algorithm.
"""

import argparse
@@ -40,7 +12,14 @@
import time
from multiprocessing import cpu_count

from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
from coati.dataset import (
setup_conversation_template,
tokenize_kto,
tokenize_process_reward,
tokenize_prompt,
tokenize_rlhf,
tokenize_sft,
)
from datasets import dataset_dict, load_dataset
from transformers import AutoTokenizer

@@ -56,7 +35,7 @@ def main():
type=str,
required=True,
default=None,
choices=["sft", "prompt", "preference", "kto"],
choices=["sft", "prompt", "preference", "kto", "prm"],
help="Type of dataset, choose from 'sft', 'prompt', 'preference', 'kto', 'prm'",
)
parser.add_argument(
@@ -205,8 +184,10 @@ def main():
preparation_function = tokenize_rlhf
elif args.type == "kto":
preparation_function = tokenize_kto
elif args.type == "prm":
preparation_function = tokenize_process_reward
else:
raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']")
raise ValueError("Unknown dataset type. Please choose one from ['sft', 'prompt', 'preference', 'kto', 'prm']")

for index, dataset in enumerate(list_dataset):
assert isinstance(dataset, dataset_dict.Dataset)
@@ -218,6 +199,7 @@ def main():
dataset = dataset.select(
random.sample(range(len(dataset)), min(args.num_samples_per_datafile, len(dataset)))
)

logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
dataset = dataset.map(
function=preparation_function,
@@ -229,13 +211,6 @@ def main():
keep_in_memory=False,
num_proc=min(len(dataset), cpu_count()),
)
if args.type == "kto":
filter_by = "completion"
elif args.type == "preference":
filter_by = "chosen_input_ids"
else:
filter_by = "input_ids"
dataset = dataset.filter(lambda data: data[filter_by] is not None)

# Save each jsonl spliced dataset.
output_index = "0" * (5 - len(str(index))) + str(index)
@@ -249,22 +224,15 @@ def main():
logger.info(f"processing {count} spliced data points for {fp_writer.name}")
count += 1
fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n")

logger.info(
f"Current file {fp_writer.name}; "
f"Data size: {len(dataset)}; "
f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
)

# Save each arrow spliced dataset
output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
logger.info(f"Start to save {output_arrow_path}")
dataset = load_dataset(
path="json",
data_files=[output_jsonl_path],
cache_dir=os.path.join(args.data_cache_dir, "tokenized"),
keep_in_memory=False,
num_proc=cpu_count(),
split="train",
)
dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(dataset), cpu_count()))
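Following the pattern of the usage notes removed from the docstring above, preparing a PRM dataset would presumably be invoked like this (all paths are placeholders):

python prepare_dataset.py --type prm \
    --data_input_dirs /PATH/TO/MATH_SHEPHERD/DATASET \
    --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
    --tokenizer_dir "" \
    --data_cache_dir $SAVE_DIR/cache \
    --data_jsonl_output_dir $SAVE_DIR/jsonl \
    --data_arrow_output_dir $SAVE_DIR/arrow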


288 changes: 288 additions & 0 deletions applications/ColossalChat/examples/training_scripts/train_prm.py
@@ -0,0 +1,288 @@
"""
Train Process Reward Model.
"""

import argparse
import json
import math
import os
import resource
from contextlib import nullcontext

import torch
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.trainer import ProcessRewardModelTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam


def load_dataset(args, plugin, tokenizer):
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)

train_dataloader = plugin.prepare_dataloader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)

eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)

eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)

return train_dataloader, eval_dataloader


def initialize_plugin(args):
if args.plugin == "ddp":
"""
Default torch ddp plugin without any acceleration, for debugging purposes.
"""
plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")

return plugin


def train(args):

colossalai.launch_from_torch()
coordinator = DistCoordinator()

init_ctx = nullcontext()
with init_ctx:
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
)

plugin = initialize_plugin(args=args)
booster = Booster(plugin=plugin)
optimizer = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)

model.train()
if args.grad_checkpoint:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

coordinator.print_on_master(
f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
train_dataloader, eval_dataloader = load_dataset(args=args, plugin=plugin, tokenizer=tokenizer)

coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)

num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
math.ceil(args.max_epochs * num_update_steps_per_epoch)

if args.warmup_steps is None:
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))

lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.max_epochs * num_update_steps_per_epoch,
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)

default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optimizer, _, _, _ = booster.boost(model=model, optimizer=optimizer)
torch.set_default_dtype(torch.float)

coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)

start_epoch = 0
sampler_start_idx = 0
start_step = 0
if args.checkpoint_path is not None:
if "modeling" in args.checkpoint_path:
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
booster.load_model(model, args.checkpoint_path)
else:
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
start_epoch, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.checkpoint_path,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)

coordinator.print_on_master(
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")

coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)

trainer = ProcessRewardModelTrainer(
model=model,
booster=booster,
optimizer=optimizer,
plugin=plugin,
lr_scheduler=lr_scheduler,
tokenizer=tokenizer,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
start_epoch=start_epoch,
save_interval=args.save_interval,
save_dir=args.save_path,
coordinator=coordinator,
reward_signal_ids=tokenizer.convert_tokens_to_ids(args.reward_signal),
)

trainer.fit(
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)


if __name__ == "__main__":
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if needed to resume training from a checkpoint"
)
parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_length", type=int, default=512)
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of steps between two checkpoints")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
parser.add_argument("--microbatch_size", type=int, default=1)
parser.add_argument("--reward_signal", nargs="+", default=["+", "-"])
args = parser.parse_args()
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)
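Since the script launches via colossalai.launch_from_torch(), a hedged example invocation with torchrun (model, dataset, and save paths are placeholders; the flags match the parser above):

torchrun --standalone --nproc_per_node 4 train_prm.py \
    --plugin zero2 \
    --pretrain /PATH/TO/BASE/MODEL \
    --dataset /PATH/TO/PRM/ARROW/DATASET \
    --save_path /PATH/TO/SAVE/DIR \
    --batch_size 4 \
    --max_epochs 1 \
    --lr 5e-6 \
    --reward_signal "+" "-"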