support rlhf (#184)
Co-authored-by: qidanrui <[email protected]>
Co-authored-by: junewgl <[email protected]>
Co-authored-by: wangzaistone <[email protected]>
Co-authored-by: csunny <[email protected]>
5 people authored Dec 22, 2023
1 parent 4d48845 commit c1e1d53
Showing 8 changed files with 559 additions and 51 deletions.
8 changes: 8 additions & 0 deletions dbgpt_hub/configs/data_args.py
@@ -81,6 +81,14 @@ class DataArguments:
default="dbgpt_hub/data/",
metadata={"help": "The name of the folder containing datasets."},
)
cutoff_len: Optional[int] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."},
)
reserved_label_len: Optional[int] = field(
default=1,
metadata={"help": "The maximum length reserved for label after tokenization."},
)
split: Optional[str] = field(
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."},
12 changes: 12 additions & 0 deletions dbgpt_hub/configs/model_args.py
@@ -94,6 +94,15 @@ class ModelArguments:
"help": "Used in rope scaling. Do not specify this argument manually."
},
)
hf_hub_token: Optional[str] = field(
default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
split_special_tokens: Optional[bool] = field(
default=False,
metadata={
"help": "Whether or not the special tokens should be split during the tokenization process."
},
)

def __post_init__(self):
if self.compute_dtype is not None or self.model_max_length is not None:
@@ -182,6 +191,9 @@ class FinetuningArguments:
r"""
Arguments pertaining to which techniques we are going to fine-tune with.
"""
stage: Optional[Literal["sft", "rm"]] = field(
default="sft", metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
default="lora", metadata={"help": "Which fine-tuning method to use."}
)
9 changes: 9 additions & 0 deletions dbgpt_hub/data/dataset_info.json
@@ -7,5 +7,14 @@
"response": "output",
"history": "history"
}
},
"example_rm_train": {
"file_name": "oaast_rm_zh.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
}
}
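
For reference, a sketch of one pairwise record that "example_rm_train" expects (the field values below are illustrative; the actual contents of oaast_rm_zh.json are not shown in this commit). The "output" column has to be a list with the chosen answer first and the rejected answer second, which is what preprocess_pairwise_dataset checks for:

# A hypothetical reward-model record, as it would look after json.load() (values made up):
rm_example = {
    "instruction": "Explain what an index is in a database.",  # mapped to "prompt"
    "input": "",                                                # mapped to "query"
    "output": [                                                 # mapped to "response"
        "An index is a data structure that speeds up lookups on a table.",  # chosen
        "An index is a kind of table backup.",                              # rejected
    ],
    "history": [],                                              # earlier (query, response) turns
}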
101 changes: 97 additions & 4 deletions dbgpt_hub/data_process/data_utils.py
@@ -4,7 +4,17 @@
import pandas as pd
import tiktoken
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING, Generator
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
Union,
TYPE_CHECKING,
Generator,
Literal,
)
from datasets import (
Dataset,
DatasetDict,
@@ -64,6 +74,17 @@ def extract_sql_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
return {"input": prompt_format.format(**example)}


def infer_max_len(
source_len: int, target_len: int, data_args: "DataArguments"
) -> Tuple[int, int]:
max_target_len = int(
data_args.cutoff_len * (target_len / (source_len + target_len))
)
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
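
For illustration, a minimal sketch of the budget this helper computes, assuming a DataArguments instance with the defaults above (cutoff_len=1024, reserved_label_len=1); the token counts are made up:

# data_args is assumed to be a DataArguments with cutoff_len=1024 and reserved_label_len=1.
# A 900-token prompt and a 300-token response share the 1024-token budget proportionally:
#   max_target_len = max(int(1024 * 300 / (900 + 300)), 1) = 256
#   max_source_len = 1024 - 256 = 768
max_source_len, max_target_len = infer_max_len(source_len=900, target_len=300, data_args=data_args)
assert (max_source_len, max_target_len) == (768, 256)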


def local_dataset(
dataset_path: str, eval_dataset_size: float = 0.1
) -> Tuple[Dataset, Dataset]:
@@ -579,6 +600,7 @@ def preprocess_dataset(
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
) -> Union["Dataset", "IterableDataset"]:
column_names = list(next(iter(dataset)).keys())
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
@@ -670,6 +692,69 @@ def preprocess_unsupervised_dataset(

return model_inputs

def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]]
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>` for rm stage
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, system in construct_example(examples):
if not (
isinstance(query, str)
and isinstance(response, list)
and query != ""
and len(response) > 1
):
continue

prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, query, response[0], history, system
)
_, rejected_ids = template.encode_oneturn(
tokenizer, query, response[1], history, system
)

# append eos to both candidates so chosen and rejected sequences terminate consistently
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]

source_len, target_len = len(prompt_ids), max(
len(chosen_ids), len(rejected_ids)
)
max_source_len, max_target_len = infer_max_len(
source_len, target_len, data_args
)
if source_len > max_source_len:
prompt_ids = prompt_ids[:max_source_len]
if target_len > max_target_len:
chosen_ids = chosen_ids[:max_target_len]
rejected_ids = rejected_ids[:max_target_len]

model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)

return model_inputs

def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print(
"prompt:\n{}".format(
tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)
)
)
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print(
"chosen:\n{}".format(
tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)
)
)
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print(
"rejected:\n{}".format(
tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)
)
)

def print_supervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print(
@@ -690,9 +775,17 @@ def print_supervised_dataset_example(example):
)
)

dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
preprocess_function = preprocess_supervised_dataset
print_function = print_supervised_dataset_example
if stage == "pt":
pass
elif stage == "sft" and not training_args.predict_with_generate:
preprocess_function = preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
preprocess_function = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
pass

with training_args.main_process_first(desc="dataset map pre-processing"):
kwargs = {}
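
A hypothetical call site for the reward-model stage (the trainer entry points are not part of the hunks shown here; dataset, tokenizer, data_args and training_args are assumed to be already-built objects):

rm_dataset = preprocess_dataset(
    dataset, tokenizer, data_args, training_args, stage="rm"
)
# Each row of rm_dataset now carries "prompt_ids", "chosen_ids" and "rejected_ids".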
119 changes: 73 additions & 46 deletions dbgpt_hub/llm_base/load_tokenizer.py
@@ -1,5 +1,6 @@
import os
import torch
import inspect
import math
from typing import Optional, Tuple, Dict, TYPE_CHECKING, Literal, List
from types import MethodType
@@ -10,8 +11,9 @@
from dbgpt_hub.configs.config import LAYERNORM_NAMES, VALUE_HEAD_FILE_NAME

from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.utils import check_min_version
from transformers.utils import check_min_version, cached_file
from transformers.utils.versions import require_version
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers import (
AutoConfig,
@@ -103,32 +105,54 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return model


def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
if not os.path.exists(valuehead_file):
logger.warning(
"Provided path ({}) does not contain valuehead weights.".format(
checkpoint_dir
)
)
return False
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
model.register_buffer(
"default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"])
)
model.register_buffer(
"default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"])
def load_valuehead_params(
path_or_repo_id: str, model_args: "ModelArguments"
) -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir}

if "token" in inspect.signature(cached_file).parameters:
kwargs["token"] = model_args.hf_hub_token
elif (
"use_auth_token" in inspect.signature(cached_file).parameters
): # for transformers==4.31.0
kwargs["use_auth_token"] = model_args.hf_hub_token
else:
logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")

try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))

try:
from safetensors import safe_open

vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {
"v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
"v_head.summary.bias": f.get_tensor("v_head.summary.bias"),
}
except Exception as err:
logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))

logger.warning(
"Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id)
)
return True
return None
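
A hypothetical call (the repo id is made up) showing how the helper degrades gracefully when no value head can be found; model_args is assumed to be a parsed ModelArguments:

vhead_params = load_valuehead_params("someone/some-reward-model", model_args)
if vhead_params is None:
    print("no value-head weights found at that location")
else:
    print(sorted(vhead_params))  # documented keys: 'v_head.summary.bias', 'v_head.summary.weight'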


def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft",
add_valuehead: Optional[bool] = False,
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
@@ -151,7 +175,8 @@ def load_model_and_tokenizer(
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
padding_side=model_args.padding_side,
split_special_tokens=model_args.split_special_tokens,
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
**config_kwargs
)

@@ -171,6 +196,15 @@ def load_model_and_tokenizer(
else:
setattr(config, "fp16", True)

# Fix config (for Qwen)
if getattr(config, "model_type", None) == "qwen":
for dtype_name, dtype in [
("fp16", torch.float16),
("bf16", torch.bfloat16),
("fp32", torch.float32),
]:
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)

# Set RoPE scaling
if model_args.rope_scaling is not None:
if hasattr(config, "use_dynamic_ntk"): # for Qwen models
@@ -294,33 +328,26 @@ def load_model_and_tokenizer(
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
model: AutoModelForCausalLMWithValueHead = (
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = (
AutoModelForCausalLMWithValueHead.from_pretrained(model)
)
reset_logging()
if (
stage == "rm" and model_args.checkpoint_dir is not None
): # load valuehead weights to evaluate reward model
logger.warning(
"Only the last checkpoint containing valuehead will be loaded as the valuehead."
)
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
model.v_head.load_state_dict(
{
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias"),
}
)

if stage == "ppo": # load reward model
logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(
model_args.reward_model, "reward", is_trainable=False
)
assert load_valuehead_params(
model, model_args.reward_model
), "Reward model is not correctly loaded."
ignore_modules = [
name for name, _ in model.named_parameters() if "pretrained_model" in name
]
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(
model, "tie_weights", MethodType(lambda _: None, model)
) # use empty method
vhead_path = (
model_args.checkpoint_dir[-1]
if model_args.checkpoint_dir is not None
else model_args.model_name_or_path
)
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))

# Prepare model for inference
if not is_trainable:
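
A hypothetical reward-model training setup using the new keyword (model_args and finetuning_args are assumed to be parsed argument dataclasses):

model, tokenizer = load_model_and_tokenizer(
    model_args,
    finetuning_args,
    is_trainable=True,
    add_valuehead=True,  # wraps the base model with AutoModelForCausalLMWithValueHead
)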