From a9b51735f58d0b4b3d73c5b777d3428d5f067242 Mon Sep 17 00:00:00 2001 From: Anke Tang Date: Tue, 3 Dec 2024 21:49:29 +0800 Subject: [PATCH] Merge pull request #45 from tanganke/develop merge develop into main --- config/dataset/llm_sft/alpaca_cleaned.yaml | 6 + config/dataset/llm_sft/ultrachat_200k.yaml | 3 + config/fabric/llama_peft_fsdp.yaml | 16 ++ config/fabric/strategy/llama_peft_fsdp.yaml | 9 + .../method/lm_finetune/peftfinetune_sft.yaml | 2 +- .../CausalLMPool/llama_ultrachat.yaml | 18 ++ .../llama_preference700k.yaml | 4 +- .../single_reward_model.yaml | 14 ++ config/taskpool/reward_model_evaluation.yaml | 18 ++ examples/lm_finetune/llama_fullfinetune.sh | 14 ++ fusion_bench/dataset/llama/collate.py | 16 +- fusion_bench/dataset/llama/preference_700k.py | 49 +++--- fusion_bench/dataset/llama/stanford_shp.py | 88 ++++++++++ fusion_bench/dataset/llama/ultrachat.py | 58 +++++++ fusion_bench/dataset/llama/utils/__init__.py | 0 .../method/lm_finetune/bradley_terry_rm.py | 40 ++--- .../method/lm_finetune/peftfinetune_sft.py | 56 ++++--- fusion_bench/mixins/fabric_training.py | 64 ++++++- fusion_bench/mixins/lightning_fabric.py | 9 + fusion_bench/models/utils.py | 8 + fusion_bench/taskpool/llama/reward_model.py | 157 ++++++++++++++++++ 21 files changed, 568 insertions(+), 81 deletions(-) create mode 100644 config/dataset/llm_sft/alpaca_cleaned.yaml create mode 100644 config/dataset/llm_sft/ultrachat_200k.yaml create mode 100644 config/fabric/llama_peft_fsdp.yaml create mode 100644 config/fabric/strategy/llama_peft_fsdp.yaml create mode 100644 config/modelpool/CausalLMPool/llama_ultrachat.yaml create mode 100644 config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml create mode 100644 config/taskpool/reward_model_evaluation.yaml create mode 100644 fusion_bench/dataset/llama/stanford_shp.py create mode 100644 fusion_bench/dataset/llama/ultrachat.py create mode 100644 fusion_bench/dataset/llama/utils/__init__.py create mode 100644 fusion_bench/taskpool/llama/reward_model.py diff --git a/config/dataset/llm_sft/alpaca_cleaned.yaml b/config/dataset/llm_sft/alpaca_cleaned.yaml new file mode 100644 index 00000000..6a765a9b --- /dev/null +++ b/config/dataset/llm_sft/alpaca_cleaned.yaml @@ -0,0 +1,6 @@ +alpaca-cleaned: + _target_: fusion_bench.dataset.llama.alpaca.load_tokenized_alpaca_dataset + tokenizer: ??? + path: "yahma/alpaca-cleaned" + split: train + cache_path: null diff --git a/config/dataset/llm_sft/ultrachat_200k.yaml b/config/dataset/llm_sft/ultrachat_200k.yaml new file mode 100644 index 00000000..0740d47c --- /dev/null +++ b/config/dataset/llm_sft/ultrachat_200k.yaml @@ -0,0 +1,3 @@ +ultrachat-200k: + _target_: fusion_bench.dataset.ultrachat.load_tokenized_ultrachat_200k + tokenizer: ??? diff --git a/config/fabric/llama_peft_fsdp.yaml b/config/fabric/llama_peft_fsdp.yaml new file mode 100644 index 00000000..8abd2f74 --- /dev/null +++ b/config/fabric/llama_peft_fsdp.yaml @@ -0,0 +1,16 @@ +defaults: + - loggers: tensorboard_logger + - strategy: llama_peft_fsdp + - _self_ + +_target_: lightning.Fabric +_recursive_: true +# Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``. +# The value applies per node. +devices: auto +# The hardware to run on. Possible choices are: +# ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``. 
+# for example: fabric.accelerator=cpu +accelerator: auto +# reference to the precision policy: https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision +precision: bf16-true diff --git a/config/fabric/strategy/llama_peft_fsdp.yaml b/config/fabric/strategy/llama_peft_fsdp.yaml new file mode 100644 index 00000000..a00949a0 --- /dev/null +++ b/config/fabric/strategy/llama_peft_fsdp.yaml @@ -0,0 +1,9 @@ +_target_: lightning.fabric.strategies.FSDPStrategy +sharding_strategy: FULL_SHARD +state_dict_type: full # Save a single, consolidated checkpoint file +cpu_offload: false +auto_wrap_policy: + _target_: fusion_bench.mixins.lightning_fabric.get_size_based_auto_wrap_policy +activation_checkpointing_policy: ${.auto_wrap_policy} +# limit_all_gathers: true + diff --git a/config/method/lm_finetune/peftfinetune_sft.yaml b/config/method/lm_finetune/peftfinetune_sft.yaml index 262921b3..652b3472 100644 --- a/config/method/lm_finetune/peftfinetune_sft.yaml +++ b/config/method/lm_finetune/peftfinetune_sft.yaml @@ -57,7 +57,7 @@ save_optimizer_state: false # save_full_model must be true when using shared FSDP save_full_model: false # save_ckpt_type can be 'peft' or 'lightning' -save_ckpt_type: peft +save_ckpt_type: lightning # Path to checkpoint to load from, used for resuming training ckpt_path: null max_length: 4096 diff --git a/config/modelpool/CausalLMPool/llama_ultrachat.yaml b/config/modelpool/CausalLMPool/llama_ultrachat.yaml new file mode 100644 index 00000000..d0f143a4 --- /dev/null +++ b/config/modelpool/CausalLMPool/llama_ultrachat.yaml @@ -0,0 +1,18 @@ +_target_: fusion_bench.modelpool.CausalLMPool + +pretrained_model_name_or_path: meta-llama/Llama-3-1B-Instruct + +models: + _pretrained_: + _target_: transformers.AutoModelForCausalLM.from_pretrained + pretrained_model_name_or_path: ${...pretrained_model_name_or_path} + torch_dtype: bfloat16 + +tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: ${..pretrained_model_name_or_path} + +train_datasets: + ultrachat-200k: + _target_: fusion_bench.dataset.llama.ultrachat.load_tokenized_ultrachat_200k + tokenizer: ${...tokenizer} diff --git a/config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml b/config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml index a56ac138..234d6afd 100644 --- a/config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +++ b/config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml @@ -1,4 +1,4 @@ -_target_: fusion_bench.modelpool.CausalLMPool +_target_: fusion_bench.modelpool.SeqenceClassificationModelPool pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct @@ -16,7 +16,7 @@ tokenizer: train_datasets: preference_700k: - _target_: fusion_bench.dataset.llama.preference_700k.load_tokenized_preference_700k_for_bradley_terry_rm + _target_: fusion_bench.dataset.llama.preference_700k.load_tokenized_preference_700k_for_rlhf tokenizer: ${...tokenizer} path: hendrydong/preference_700K split: train diff --git a/config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml b/config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml new file mode 100644 index 00000000..47089d6c --- /dev/null +++ b/config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml @@ -0,0 +1,14 @@ +_target_: fusion_bench.modelpool.SeqenceClassificationModelPool + +pretrained_model_name_or_path: fusion-bench/Llama-3.2-1B-Instruct_Bradly-Terry-RM_Preference-700k + +models: + 
_pretrained_: + _target_: transformers.AutoModelForSequenceClassification.from_pretrained + pretrained_model_name_or_path: ${...pretrained_model_name_or_path} + torch_dtype: bfloat16 + +tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: ${..pretrained_model_name_or_path} + pad_token: <|end_of_text|> # do not use eos token (<|eos_id|>) as padding token because it is used as the end of each content diff --git a/config/taskpool/reward_model_evaluation.yaml b/config/taskpool/reward_model_evaluation.yaml new file mode 100644 index 00000000..bedafc6d --- /dev/null +++ b/config/taskpool/reward_model_evaluation.yaml @@ -0,0 +1,18 @@ +_target_: fusion_bench.taskpool.llama.reward_model.RewardModelEvaluationTaskPool + +test_datasets: + preference_700k: + _target_: fusion_bench.dataset.llama.preference_700k.load_tokenized_preference_700k_for_rlhf + tokenizer: ${...tokenizer} + path: hendrydong/preference_700K + split: train + cache_path: null + +dataloader_kwargs: + shuffle: False + batch_size: 16 + +tokenizer: ${..modelpool.tokenizer} + +max_num_samples: 1000 +seed: 42 diff --git a/examples/lm_finetune/llama_fullfinetune.sh b/examples/lm_finetune/llama_fullfinetune.sh index b92c11e4..5f37dbdf 100644 --- a/examples/lm_finetune/llama_fullfinetune.sh +++ b/examples/lm_finetune/llama_fullfinetune.sh @@ -21,3 +21,17 @@ fusion_bench --config-name llama_full_finetune \ method.checkpoint_save_frequency=2000 \ method.max_epochs=1 \ modelpool=CausalLMPool/llama_metamathqa + +# full finetune on ultrachat +fusion_bench --config-name llama_full_finetune \ + fabric=llama_peft_fsdp \ + fabric.loggers.name=llama_lora_finetune \ + method=lm_finetune/peftfinetune_sft \ + method.dataloader_kwargs.batch_size=1 \ + method.max_epochs=1 \ + method.gradient_clip_val=1.0 \ + method.accumulate_grad_batches=16 \ + method.checkpoint_save_interval=step \ + method.checkpoint_save_frequency=2000 \ + modelpool=CausalLMPool/llama_ultrachat \ + modelpool.pretrained_model_name_or_path=meta-llama/Meta-Llama-3-8B-Instruct diff --git a/fusion_bench/dataset/llama/collate.py b/fusion_bench/dataset/llama/collate.py index 389e8fe2..13b9affd 100644 --- a/fusion_bench/dataset/llama/collate.py +++ b/fusion_bench/dataset/llama/collate.py @@ -84,14 +84,14 @@ def bradley_terry_rm_collate( converted_batch = [] for item in batch: new_item = { - "input_ids": item["input_ids_j"], - "attention_mask": item["attention_mask_j"], + "input_ids": item["chosen_input_ids"], + "attention_mask": item["chosen_attention_mask"], } converted_batch.append(new_item) for item in batch: new_item = { - "input_ids": item["input_ids_k"], - "attention_mask": item["attention_mask_k"], + "input_ids": item["rejected_input_ids"], + "attention_mask": item["rejected_attention_mask"], } converted_batch.append(new_item) @@ -111,10 +111,10 @@ def bradley_terry_rm_collate( collated_batch = {"input_ids": input_ids, "attention_mask": attention_mask} for key in batch[0]: if key not in [ - "input_ids_j", - "attention_mask_j", - "input_ids_k", - "attention_mask_k", + "chosen_input_ids", + "chosen_attention_mask", + "rejected_input_ids", + "rejected_attention_mask", ]: collated_batch[key] = [x[key] for x in batch] return collated_batch diff --git a/fusion_bench/dataset/llama/preference_700k.py b/fusion_bench/dataset/llama/preference_700k.py index 54d4cbee..cb24d874 100644 --- a/fusion_bench/dataset/llama/preference_700k.py +++ b/fusion_bench/dataset/llama/preference_700k.py @@ -1,4 +1,5 @@ import os +from copy import deepcopy from typing 
import TYPE_CHECKING, Optional from datasets import Dataset, load_dataset, load_from_disk @@ -6,14 +7,15 @@ from tqdm.auto import tqdm from fusion_bench.utils import timeit_context - -from .alpaca import convert_alpaca_to_conversation +import logging if TYPE_CHECKING: from transformers import PreTrainedTokenizer +log = logging.getLogger(__name__) + -def load_tokenized_preference_700k_for_bradley_terry_rm( +def load_tokenized_preference_700k_for_rlhf( tokenizer: "PreTrainedTokenizer", path: str = "hendrydong/preference_700K", split: str = "train", @@ -25,10 +27,10 @@ def load_tokenized_preference_700k_for_bradley_terry_rm( The returned dataset contains the following fields: - - input_ids_j: The input token ids for the winner. - - attention_mask_j: The attention mask for the winner. - - input_ids_k: The input token ids for the loser. - - attention_mask_k: The attention mask for the loser. + - chosen_input_ids: The input token ids for the winner. + - chosen_attention_mask: The attention mask for the winner. + - rejected_input_ids: The input token ids for the loser. + - rejected_attention_mask: The attention mask for the loser. """ if cache_path is not None and os.path.exists(cache_path): dataset = load_from_disk(cache_path) @@ -37,21 +39,28 @@ def load_tokenized_preference_700k_for_bradley_terry_rm( dataset = load_dataset(path, split=split) def tokenize(sample): - - # ? is it necessary to `.replace(tokenizer.bos_token, "")`? - sample["positive"] = tokenizer.apply_chat_template( + sample["chosen_chat"] = tokenizer.apply_chat_template( sample["chosen"], tokenize=False, add_generation_prompt=False - ).replace(tokenizer.bos_token, "") - sample["negative"] = tokenizer.apply_chat_template( + ) + sample["rejected_chat"] = tokenizer.apply_chat_template( sample["rejected"], tokenize=False, add_generation_prompt=False - ).replace(tokenizer.bos_token, "") - - tokenized_pos = tokenizer(sample["positive"], truncation=True) - tokenized_neg = tokenizer(sample["negative"], truncation=True) - sample["input_ids_j"] = tokenized_pos["input_ids"] - sample["attention_mask_j"] = tokenized_pos["attention_mask"] - sample["input_ids_k"] = tokenized_neg["input_ids"] - sample["attention_mask_k"] = tokenized_neg["attention_mask"] + ) + + tokenized_pos = tokenizer(sample["chosen_chat"], truncation=True) + tokenized_neg = tokenizer(sample["rejected_chat"], truncation=True) + + # Ensure that the chosen response does not contain an PAD token + sample["chosen_input_ids"] = tokenized_pos["input_ids"] + sample["chosen_attention_mask"] = tokenized_pos["attention_mask"] + if tokenizer.pad_token_id in tokenized_pos["input_ids"]: + log.warning(f"Prompt contains PAD token: {sample['chosen_chat']}") + + sample["rejected_input_ids"] = tokenized_neg["input_ids"] + sample["rejected_attention_mask"] = tokenized_neg["attention_mask"] + # Ensure that the rejected response does not contain an PAD token + if tokenizer.pad_token_id in tokenized_neg["input_ids"]: + log.warning(f"Prompt contains PAD token: {sample['rejected_chat']}") + return sample dataset = dataset.map(tokenize, num_proc=num_proc) diff --git a/fusion_bench/dataset/llama/stanford_shp.py b/fusion_bench/dataset/llama/stanford_shp.py new file mode 100644 index 00000000..f829abba --- /dev/null +++ b/fusion_bench/dataset/llama/stanford_shp.py @@ -0,0 +1,88 @@ +import os +from copy import deepcopy +from typing import TYPE_CHECKING, Optional + +from datasets import Dataset, load_dataset, load_from_disk +from lightning.fabric.utilities import rank_zero_only +from tqdm.auto import tqdm 
+
+from fusion_bench.utils import timeit_context
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+
+def load_tokenized_stanford_shp_for_rlhf(
+    tokenizer: "PreTrainedTokenizer",
+    path: str = "stanfordnlp/SHP",
+    split: str = "train",
+    num_proc: int = 8,
+    cache_path: Optional[str] = None,
+):
+    if cache_path is not None and os.path.isdir(cache_path):
+        dataset = load_from_disk(cache_path)
+        return dataset
+
+    dataset = load_dataset(path, split=split)
+
+    def tokenize(sample):
+        """
+        - history: the post title concatenated to the post body (string)
+        - human_ref_A: text of comment A (string)
+        - human_ref_B: text of comment B (string)
+        - labels: the preference label -- it is 1 if A is preferred to B; 0 if B is preferred to A. This was randomized such that the label distribution is roughly 50/50. (integer)
+        """
+        # Create a conversation with the post title and body, followed by comments
+        conversation = [{"role": "user", "content": sample["history"]}]
+        if sample["labels"] == 0:
+            sample["chosen"] = deepcopy(conversation) + [
+                {"role": "assistant", "content": sample["human_ref_B"]}
+            ]
+            sample["rejected"] = deepcopy(conversation) + [
+                {"role": "assistant", "content": sample["human_ref_A"]}
+            ]
+        else:
+            sample["chosen"] = deepcopy(conversation) + [
+                {"role": "assistant", "content": sample["human_ref_A"]}
+            ]
+            sample["rejected"] = deepcopy(conversation) + [
+                {"role": "assistant", "content": sample["human_ref_B"]}
+            ]
+
+        # apply chat template
+        sample["chosen_chat"] = tokenizer.apply_chat_template(
+            sample["chosen"], tokenize=False, add_generation_prompt=False
+        )
+        sample["rejected_chat"] = tokenizer.apply_chat_template(
+            sample["rejected"], tokenize=False, add_generation_prompt=False
+        )
+
+        # tokenize the conversation
+        tokenized_pos = tokenizer(sample["chosen_chat"], truncation=True)
+        tokenized_neg = tokenizer(sample["rejected_chat"], truncation=True)
+
+        # Ensure the chosen sequence has no EOS token except at the end, then append EOS if missing
+        sample["chosen_input_ids"] = tokenized_pos["input_ids"]
+        sample["chosen_attention_mask"] = tokenized_pos["attention_mask"]
+        assert (
+            tokenizer.eos_token_id not in tokenized_pos["input_ids"][:-1]
+        ), f"Prompt contains EOS token: {sample['chosen_chat']}"
+        if sample["chosen_input_ids"][-1] != tokenizer.eos_token_id:
+            sample["chosen_input_ids"].append(tokenizer.eos_token_id)
+            sample["chosen_attention_mask"].append(1)
+
+        sample["rejected_input_ids"] = tokenized_neg["input_ids"]
+        sample["rejected_attention_mask"] = tokenized_neg["attention_mask"]
+        # Ensure the rejected sequence has no EOS token except at the end, then append EOS if missing
+        assert (
+            tokenizer.eos_token_id not in tokenized_neg["input_ids"][:-1]
+        ), f"Prompt contains EOS token: {sample['rejected_chat']}"
+        if sample["rejected_input_ids"][-1] != tokenizer.eos_token_id:
+            sample["rejected_input_ids"].append(tokenizer.eos_token_id)
+            sample["rejected_attention_mask"].append(1)
+        return sample
+    dataset = dataset.map(tokenize, num_proc=num_proc)
+
+    if cache_path is not None and rank_zero_only.rank == 0:
+        dataset.save_to_disk(cache_path)
+    return dataset
diff --git a/fusion_bench/dataset/llama/ultrachat.py b/fusion_bench/dataset/llama/ultrachat.py
new file mode 100644
index 00000000..680f9284
--- /dev/null
+++ b/fusion_bench/dataset/llama/ultrachat.py
@@ -0,0 +1,58 @@
+import os
+from typing import TYPE_CHECKING, Optional
+
+from datasets import Dataset, load_dataset, load_from_disk
+from lightning.fabric.utilities import rank_zero_only
+from tqdm.auto import tqdm
+
+from fusion_bench.utils import 
timeit_context
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+
+def load_tokenized_ultrachat_200k(
+    tokenizer: "PreTrainedTokenizer",
+    path: str = "HuggingFaceH4/ultrachat_200k",
+    split: str = "train_sft",
+    num_proc: int = 8,
+    cache_path: Optional[str] = None,
+):
+    R"""
+    Load and tokenize the UltraChat 200k dataset for supervised fine-tuning (SFT).
+
+    The returned dataset contains the following fields:
+
+    - input_ids: The token ids of the chat-templated conversation.
+    - attention_mask: The attention mask for the conversation.
+    """
+    if cache_path is not None and os.path.exists(cache_path):
+        dataset = load_from_disk(cache_path)
+        return dataset
+
+    dataset = load_dataset(path, split=split)
+
+    def tokenize(sample):
+
+        # ? is it necessary to `.replace(tokenizer.bos_token, "")`?
+        sample["input_ids"] = tokenizer.apply_chat_template(
+            sample["messages"], tokenize=True, add_generation_prompt=False
+        )
+        sample["attention_mask"] = [1] * len(sample["input_ids"])
+
+        return sample
+
+    dataset = dataset.map(tokenize, num_proc=num_proc)
+
+    if cache_path is not None and rank_zero_only.rank == 0:
+        dataset.save_to_disk(cache_path)
+    return dataset
+
+
+if __name__ == "__main__":
+    # Example usage and testing
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
+    dataset = load_tokenized_ultrachat_200k(tokenizer)
+    print(dataset)
diff --git a/fusion_bench/dataset/llama/utils/__init__.py b/fusion_bench/dataset/llama/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/fusion_bench/method/lm_finetune/bradley_terry_rm.py b/fusion_bench/method/lm_finetune/bradley_terry_rm.py
index 39779276..0b9fc981 100644
--- a/fusion_bench/method/lm_finetune/bradley_terry_rm.py
+++ b/fusion_bench/method/lm_finetune/bradley_terry_rm.py
@@ -3,10 +3,10 @@
 The dataset contains the following fields:
 
-- input_ids_j: The input token ids for the winner.
-- attention_mask_j: The attention mask for the winner.
-- input_ids_k: The input token ids for the loser.
-- attention_mask_k: The attention mask for the loser.
+- chosen_input_ids: The input token ids for the winner.
+- chosen_attention_mask: The attention mask for the winner.
+- rejected_input_ids: The input token ids for the loser.
+- rejected_attention_mask: The attention mask for the loser.
""" @@ -232,13 +232,13 @@ def compute_loss(self, batch: Dict[str, Union[Tensor, Any]]) -> Dict[str, Tensor ) rewards = outputs[0] - rewards_j = rewards[: batch_size // 2] - rewards_k = rewards[batch_size // 2 :] - loss = -torch.log(torch.sigmoid(rewards_j - rewards_k)).mean() + chosen_reward = rewards[: batch_size // 2] + rejected_rewards = rewards[batch_size // 2 :] + loss = -torch.log(torch.sigmoid(chosen_reward - rejected_rewards)).mean() return { - "reward_j": rewards_j, - "reward_k": rewards_k, + "chosen_reward": chosen_reward, + "rejected_reward": rejected_rewards, "loss": loss, } @@ -247,8 +247,8 @@ def train_epoch(self, *args, **kwargs): fabric = self.fabric accumulated_loss = 0 - accumulated_reward_j = 0 - accumulated_reward_k = 0 + accumulated_chosen_reward = 0 + accumulated_rejected_reward = 0 for step_idx, batch in enumerate( pbar := tqdm( self.train_dataloader, @@ -276,8 +276,8 @@ def train_epoch(self, *args, **kwargs): fabric.backward(loss) accumulated_loss += loss.item() - accumulated_reward_j += output["reward_j"].mean().item() - accumulated_reward_k += output["reward_k"].mean().item() + accumulated_chosen_reward += output["chosen_reward"].mean().item() + accumulated_rejected_reward += output["rejected_reward"].mean().item() # 1. update the model parameters if not accumulating gradients # 2. step the lr_scheduler if interval is set to "step" and frequency is met @@ -300,15 +300,15 @@ def train_epoch(self, *args, **kwargs): metrics = { "train/loss": accumulated_loss, - "train/reward_j": accumulated_reward_j + "train/chosen_reward": accumulated_chosen_reward / self.accumulate_grad_batches, - "train/reward_k": accumulated_reward_k + "train/rejected_reward": accumulated_rejected_reward / self.accumulate_grad_batches, "train/epoch_idx": self.epoch_idx, "train/lr": self.optimizer.param_groups[0]["lr"], } - metrics["train/reward_j-k"] = ( - metrics["train/reward_j"] - metrics["train/reward_k"] + metrics["train/chosen_reward-rejected_reward"] = ( + metrics["train/chosen_reward"] - metrics["train/rejected_reward"] ) fabric.log_dict(metrics, step=self.global_step_idx) @@ -330,8 +330,8 @@ def train_epoch(self, *args, **kwargs): self.global_step_idx += 1 accumulated_loss = 0 - accumulated_reward_j = 0 - accumulated_reward_k = 0 + accumulated_chosen_reward = 0 + accumulated_rejected_reward = 0 def save_checkpoint( self, @@ -425,7 +425,7 @@ def load_checkpoint( tokenizer.save_pretrained(args.output_path) model = AutoModelForSequenceClassification.from_pretrained( - args.base_model_path, torch_dtype=torch.bfloat16 + args.base_model_path, num_labels=1, torch_dtype=torch.bfloat16 ) model = fabric.setup_module(model) load_checkpoint(fabric, args.ckpt_path, model=model, strict=True) diff --git a/fusion_bench/method/lm_finetune/peftfinetune_sft.py b/fusion_bench/method/lm_finetune/peftfinetune_sft.py index d0fa8e02..78e4efa2 100644 --- a/fusion_bench/method/lm_finetune/peftfinetune_sft.py +++ b/fusion_bench/method/lm_finetune/peftfinetune_sft.py @@ -12,7 +12,7 @@ from lightning.fabric.strategies.fsdp import FSDPStrategy from lightning.fabric.utilities import rank_zero_only from omegaconf import DictConfig, OmegaConf -from peft import PeftModel, get_peft_config, get_peft_model +from peft import LoraConfig, PeftModel, get_peft_config, get_peft_model from torch import nn from torch.utils.data import ConcatDataset, DataLoader from tqdm.auto import tqdm @@ -129,6 +129,7 @@ def run(self, modelpool: CausalLMPool): return self.model def setup_model(self): + # 
https://github.com/Lightning-AI/litgpt/blob/main/litgpt/finetune/lora.py self.tokenizer = self.modelpool.load_tokenizer() if self.tokenizer.pad_token_id is None: self.tokenizer.pad_token_id = self.tokenizer.eos_token_id @@ -137,6 +138,7 @@ def setup_model(self): # get the PEFT model peft_config = instantiate(self._peft_config, _convert_="all") + peft_config.save_pretrained(os.path.join(self.log_dir, "peft_config")) peft_model = get_peft_model(model, peft_config, self.adapter_name) peft_model.print_trainable_parameters() @@ -154,6 +156,7 @@ def setup_model(self): self.use_cache = True self.model_dtype = get_dtype(self.model) + self.model = self.model.to(dtype=self.model_dtype) def configure_optimizer(self): # compute expected total steps @@ -211,11 +214,12 @@ def setup(self): optimizer = self.configure_optimizer() optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"] - self.model, self.optimizer = fabric.setup(self.model, optimizer) + self.model = self.fabric.setup_module(self.model) + self.optimizer = self.fabric.setup_optimizers(optimizer) self.lr_scheduler = lr_scheduler @override - def train_epoch(self): + def train_epoch(self, *args, **kwargs): fabric = self.fabric accumulated_loss = 0 @@ -252,14 +256,6 @@ def train_epoch(self): fabric.backward(loss) accumulated_loss += loss.item() - metrics = { - "train/loss": accumulated_loss, - "train/epoch_idx": self.epoch_idx, - "train/lr": self.optimizer.param_groups[0]["lr"], - } - fabric.log_dict(metrics, step=self.global_step_idx) - pbar.set_postfix(metrics) - if not is_accumulating: self.clip_gradients_if_needed(self.model, self.optimizer) @@ -274,22 +270,30 @@ def train_epoch(self): self.optimizer.step() self.optimizer.zero_grad() - # save the model at the end of the step if interval is set to "step" and frequency is met - self.conditional_checkpoint_save(stage="end_of_step") + metrics = { + "train/loss": accumulated_loss, + "train/epoch_idx": self.epoch_idx, + "train/lr": self.optimizer.param_groups[0]["lr"], + } + fabric.log_dict(metrics, step=self.global_step_idx) + pbar.set_postfix(metrics) + + # save the model at the end of the step if interval is set to "step" and frequency is met + self.conditional_checkpoint_save(stage="end_of_step") - # break if max_steps_per_epoch is set, and exit epoch - if ( - self.max_steps_per_epoch > 0 - and step_idx + 1 >= self.max_steps_per_epoch - ): - break - # break if max_steps is set, and exit training - if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1: - self.is_training = False - break + # break if max_steps_per_epoch is set, and exit epoch + if ( + self.max_steps_per_epoch > 0 + and step_idx + 1 >= self.max_steps_per_epoch + ): + break + # break if max_steps is set, and exit training + if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1: + self.is_training = False + break - self.global_step_idx += 1 - accumulated_loss = 0 + self.global_step_idx += 1 + accumulated_loss = 0 def save_checkpoint( self, @@ -324,7 +328,7 @@ def save_checkpoint( if self.save_full_model else {"model": lambda k, p: k in trainable_param_names} ) - + os.makedirs(os.path.dirname(path), exist_ok=True) fabric.save(path, state=state, filter=filter) elif self.save_ckpt_type == "peft": self.model.save_pretrained(path, is_main_process=fabric.is_global_zero) diff --git a/fusion_bench/mixins/fabric_training.py b/fusion_bench/mixins/fabric_training.py index d96c739a..d7521aa3 100644 --- a/fusion_bench/mixins/fabric_training.py +++ b/fusion_bench/mixins/fabric_training.py @@ -21,6 
+21,9 @@ class FabricTrainingMixin(LightningFabricMixin): + """ + This is a general purpose mixin for training a model with PyTorch Lightning. + """ _latest_saved_checkpoint_global_step: int = -1 """The global step index of the latest saved checkpoint.""" @@ -54,6 +57,13 @@ class FabricTrainingMixin(LightningFabricMixin): """The frequency to save the model checkpoint.""" def clip_gradients_if_needed(self, model, optimizer): + """ + Clips gradients if the gradient clipping value is set. + + Args: + model (nn.Module): The model whose gradients need to be clipped. + optimizer (torch.optim.Optimizer): The optimizer used for training. + """ fabric = self.fabric if self.gradient_clip_val is not None: @@ -69,6 +79,12 @@ def clip_gradients_if_needed(self, model, optimizer): def compute_expected_total_steps( self, train_dataloader: torch.utils.data.DataLoader ): + """ + Computes the expected total number of steps for the entire training. + + Args: + train_dataloader (torch.utils.data.DataLoader): The dataloader for the training data. + """ # compute expected total steps self._expected_total_steps = [] if self.max_steps > 0: @@ -86,7 +102,12 @@ def compute_expected_total_steps( @property def expected_total_steps(self): - """The expected total number of steps of the entire training. You need to run `compute_expected_total_steps` method to compute this value before accessing it.""" + """ + The expected total number of steps of the entire training. You need to run `compute_expected_total_steps` method to compute this value before accessing it. + + Raises: + ValueError: If the expected total steps have not been computed. + """ if self._expected_total_steps is None: raise ValueError( "The expected total steps have not been computed. Run `compute_expected_total_steps` method." @@ -100,6 +121,12 @@ def conditional_checkpoint_save( *args, **kwargs, ): + """ + Conditionally saves a checkpoint based on the current training stage. + + Args: + stage (Literal["end_of_step", "end_of_epoch", "end_of_training"]): The current stage of training. + """ if stage == "end_of_step": if ( self.checkpoint_save_interval == "step" @@ -129,9 +156,11 @@ def conditional_checkpoint_save( self.save_checkpoint(save_path, *args, **kwargs) try: os.symlink( - save_path, - os.path.join(self.log_dir, "checkpoints", "latest_model.ckpt"), - os.path.isdir(save_path), + src=save_path, + dst=os.path.join( + self.log_dir, "checkpoints", "latest_model.ckpt" + ), + target_is_directory=os.path.isdir(save_path), ) except Exception as e: log.error(f"Failed to create symlink: {e}") @@ -142,6 +171,15 @@ def conditional_checkpoint_save( @abstractmethod def save_checkpoint(self, path, **kwargs): + """ + Saves a checkpoint of the model. + + Args: + path (str): The path where the checkpoint will be saved. + + Raises: + NotImplementedError: If the method is not implemented. + """ raise NotImplementedError("save_checkpoint method is not implemented") def train( @@ -151,7 +189,14 @@ def train( lr_scheduler: torch.optim.lr_scheduler.LRScheduler, ): """ + Trains the model. + The global batch size is `the batch size per device` x `the number of devices` x `the number of gradient accumulation steps`. + + Args: + model (Union[nn.Module, "_FabricModule"]): The model to be trained. + optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training. + lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler. 
""" fabric = self.fabric self.is_training = True @@ -193,6 +238,17 @@ def train_epoch( optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"], lr_scheduler: torch.optim.lr_scheduler.LRScheduler, ): + """ + Trains the model for one epoch. + + Args: + model (Union[nn.Module, "_FabricModule"]): The model to be trained. + optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training. + lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler. + + Raises: + NotImplementedError: If the method is not implemented. + """ raise NotImplementedError( "Copy this as a template and implement your own train_epoch method" ) diff --git a/fusion_bench/mixins/lightning_fabric.py b/fusion_bench/mixins/lightning_fabric.py index f1feb763..5127ce5b 100644 --- a/fusion_bench/mixins/lightning_fabric.py +++ b/fusion_bench/mixins/lightning_fabric.py @@ -1,3 +1,4 @@ +import functools import logging import os from typing import TYPE_CHECKING, Any, List, Optional, TypeVar @@ -13,6 +14,7 @@ if TYPE_CHECKING: import lightning.fabric.loggers.tensorboard + from lightning.fabric.strategies import FSDPStrategy log = logging.getLogger(__name__) @@ -32,6 +34,13 @@ def get_policy(*args: str) -> set: return {import_object(arg) for arg in args} +def get_size_based_auto_wrap_policy(*args, **kwargs): + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy + + policy = functools.partial(size_based_auto_wrap_policy, *args, **kwargs) + return policy + + class LightningFabricMixin: """ A mixin class for integrating Lightning Fabric into a project. diff --git a/fusion_bench/models/utils.py b/fusion_bench/models/utils.py index 8263aaf2..08bab2b4 100644 --- a/fusion_bench/models/utils.py +++ b/fusion_bench/models/utils.py @@ -1,5 +1,6 @@ from typing import List +import torch from torch import nn @@ -70,3 +71,10 @@ def find_layers_with_type( if isinstance(submodule, tuple(layer_types)): res[name] = submodule return res + + +def disable_dropout(model: torch.nn.Module): + """Disable dropout in a model.""" + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + module.p = 0 diff --git a/fusion_bench/taskpool/llama/reward_model.py b/fusion_bench/taskpool/llama/reward_model.py new file mode 100644 index 00000000..13dfefab --- /dev/null +++ b/fusion_bench/taskpool/llama/reward_model.py @@ -0,0 +1,157 @@ +""" +The dataset contains the following fields: + +- chosen_input_ids: The input token ids for the winner. +- chosen_attention_mask: The attention mask for the winner. +- rejected_input_ids: The input token ids for the loser. +- rejected_attention_mask: The attention mask for the loser. +""" + +import functools +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast + +import lightning as L +import torch +from omegaconf import DictConfig +from torch.utils.data import Subset +import numpy as np +from tqdm.auto import tqdm + +from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate +from fusion_bench.mixins import LightningFabricMixin +from fusion_bench.taskpool import BaseTaskPool +from fusion_bench.utils import instantiate + +if TYPE_CHECKING: + from transformers import LlamaForSequenceClassification + + +def evaluate_batch(model: "LlamaForSequenceClassification", batch): + batch_size = batch["input_ids"].size(0) + assert batch_size % 2 == 0, "Batch size must be even." 
+ + outputs = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + ) + + rewards = outputs[0] + chosen_reward = rewards[: batch_size // 2] + rejected_rewards = rewards[batch_size // 2 :] + + loss = -torch.log(torch.sigmoid(chosen_reward - rejected_rewards)).mean() + correct = (chosen_reward > rejected_rewards).sum().item() + total = batch_size // 2 + + return { + "loss": loss.item(), + "correct": correct, + "total": total, + } + + +def evaluate_dataloader(model: "LlamaForSequenceClassification", dataloader): + """ + Compute the accuracy of the reward model on the given dataloader. + + Args: + model: The reward model + dataloader: The dataloader for the dataset + + Returns: + float: The accuracy of the reward model on the dataset + """ + metrics = { + "loss": 0.0, + "correct": 0, + "total": 0, + } + with torch.no_grad(): + for batch in (pbar := tqdm(dataloader)): + batch_result = evaluate_batch(model, batch) + new_total = metrics["total"] + batch_result["total"] + metrics["loss"] = ( + metrics["loss"] * metrics["total"] / new_total + + batch_result["loss"] * batch_result["total"] / new_total + ) + metrics["correct"] += batch_result["correct"] + metrics["total"] += batch_result["total"] + pbar.set_postfix(metrics) + + metrics["accuracy"] = metrics["correct"] / metrics["total"] + return metrics + + +class RewardModelEvaluationTaskPool( + BaseTaskPool, + LightningFabricMixin, +): + def __init__( + self, + test_datasets: List[DictConfig], + dataloader_kwargs: DictConfig, + tokenizer: Optional[DictConfig], + max_num_samples: int = -1, + seed: int = 0, + **kwargs, + ): + self.seed = seed + L.seed_everything(seed) + self._test_datasets = test_datasets + self.dataloader_kwargs = dataloader_kwargs + self._tokenizer = tokenizer + self.max_num_samples = max_num_samples + super().__init__(**kwargs) + + def setup(self): + if self._tokenizer is None: + # try to load the tokenizer from the model pool + tokenizer = self._program.modelpool.load_tokenizer() + else: + tokenizer = instantiate(self._tokenizer) + self.tokenizer = tokenizer + + test_datasets = { + dataset_name: instantiate(self._test_datasets[dataset_name]) + for dataset_name in self._test_datasets + } + if self.max_num_samples > 0: + test_datasets = { + dataset_name: Subset( + test_dataset, + np.random.permutation(len(test_dataset))[: self.max_num_samples], + ) + for dataset_name, test_dataset in test_datasets.items() + } + test_dataloaders = { + dataset_name: torch.utils.data.DataLoader( + test_dataset, + collate_fn=functools.partial( + bradley_terry_rm_collate, + pad_token_id=tokenizer.pad_token_id, + ), + **self.dataloader_kwargs, + ) + for dataset_name, test_dataset in test_datasets.items() + } + + self.test_dataloaders = { + dataset_name: self.fabric.setup_dataloaders(test_dataloader) + for dataset_name, test_dataloader in test_dataloaders.items() + } + + @torch.no_grad() + def evaluate(self, model: "LlamaForSequenceClassification"): + self.setup() + + model = self.fabric.setup_module(model) + if model.config.pad_token_id is None: + model.config.pad_token_id = self.tokenizer.pad_token_id + + model.eval() + report = {} + for dataset_name, test_dataloader in self.test_dataloaders.items(): + report[dataset_name] = evaluate_dataloader(model, test_dataloader) + + print(report) + return report
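
Note on the Bradley-Terry batch layout: `bradley_terry_rm_collate`, `compute_loss` in `bradley_terry_rm.py`, and `evaluate_batch` in `reward_model.py` all assume that the chosen sequences occupy the first half of the collated batch and the rejected sequences the second half. Below is a minimal sketch of the pairwise loss under that assumption; the helper name `bradley_terry_loss` and the toy reward values are illustrative only, and `F.logsigmoid(x)` is a numerically stable equivalent of the `torch.log(torch.sigmoid(x))` used in the patch.

import torch
import torch.nn.functional as F


def bradley_terry_loss(rewards: torch.Tensor) -> torch.Tensor:
    # rewards has shape (2 * num_pairs, 1): the first half scores the chosen
    # responses, the second half the rejected ones, matching the collate layout.
    num_pairs = rewards.size(0) // 2
    chosen_reward = rewards[:num_pairs]
    rejected_reward = rewards[num_pairs:]
    # -log sigmoid(r_chosen - r_rejected), averaged over the pairs
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()


# toy example with 3 pairs; each chosen response scores above its rejected one,
# so the loss is below log(2) and the pairwise accuracy is 100%
rewards = torch.tensor([[1.2], [0.7], [0.3], [0.1], [-0.5], [0.2]])
print(bradley_terry_loss(rewards))                 # ~0.40
print((rewards[:3] > rewards[3:]).float().mean())  # 1.0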