This codebase implements a vanilla RLAIF pipeline, using GPT-2-Large (774M) for summarization on the TL;DR dataset. Llama-3.3-70B is employed to construct the AI preference dataset, and GPT-4o is used to compute the win rate. The pipeline largely follows the approach described in the paper "RLAIF vs. RLHF: Scaling Reinforcement Learning from Human Feedback with AI Feedback" (https://arxiv.org/abs/2309.00267).
This implementation is intended mainly for research and educational purposes. Note that our evaluation results are not directly comparable to those reported in the paper due to differences in the base models and the potentially different size of the AI-annotated preference dataset.
git clone [email protected]:mengdi-li/vanilla-RLAIF-pipeline.git
cd vanilla-RLAIF-pipeline
conda create -n rlaif python=3.12.7
conda activate rlaif
pip install -r requirements.txt
Use default parameters in each script to reproduce our reported results.
Implemented in the Python script ./rlaif/gpt2/summarization/tldr/sft.py. See detailed running commands in the script.
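For orientation, the sketch below shows what an SFT step of this kind can look like with Hugging Face TRL's `SFTTrainer`. The dataset name, prompt/completion format, and hyperparameters are assumptions for illustration and are not taken from sft.py.

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Illustrative SFT run: fine-tune GPT-2-Large on TL;DR (post, summary) pairs.
# Assumption: a Hub dataset with "prompt"/"completion" columns, e.g. trl-lib/tldr.
dataset = load_dataset("trl-lib/tldr", split="train")

args = SFTConfig(
    output_dir="gpt2-large-tldr-sft",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    learning_rate=1e-5,
)

# Recent TRL versions accept a model id string and load the model/tokenizer internally.
trainer = SFTTrainer(model="gpt2-large", args=args, train_dataset=dataset)
trainer.train()
```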
The following two steps could be integrated into one. We use separate steps for easier and more efficient debugging.
- Filter out unique posts from the original OpenAI preference dataset
  - Implemented in the Python script ./rlaif/gpt2/summarization/tldr/preprocess_openai_tldr_human_feedback_dataset.py. See detailed running commands in the script.
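Roughly, this step keeps one entry per Reddit post so that the SFT policy can later produce fresh summary pairs for each post. A minimal sketch, assuming the `openai/summarize_from_feedback` comparisons data on the Hugging Face Hub with its nested `info` fields:

```python
from datasets import load_dataset

# Assumption: the Hub dataset "openai/summarize_from_feedback" ("comparisons" config),
# where each example carries the Reddit post under example["info"]["post"] and an id under example["info"]["id"].
comparisons = load_dataset("openai/summarize_from_feedback", "comparisons", split="train")

seen_ids = set()

def is_first_occurrence(example):
    post_id = example["info"]["id"]
    if post_id in seen_ids:
        return False
    seen_ids.add(post_id)
    return True

# Keep one comparison per unique post; the posts will later be re-summarized by the SFT model.
unique_posts = comparisons.filter(is_first_occurrence)
unique_posts.save_to_disk("tldr_unique_posts")
print(f"{len(comparisons)} comparisons -> {len(unique_posts)} unique posts")
```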
- Generate a summary dataset using the SFT model
  - Implemented in the Python script ./rlaif/gpt2/summarization/tldr/build_ai_feedback_preference_dataset_vllm.py. See detailed running commands in the script.
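Conceptually, the SFT checkpoint samples two candidate summaries per post, which the AI labeler ranks in the next step. A minimal vLLM sketch; the checkpoint path, prompt template, and sampling parameters are illustrative assumptions:

```python
from datasets import load_from_disk
from vllm import LLM, SamplingParams

# Assumptions: the SFT checkpoint produced earlier and the unique-post dataset from the previous step.
llm = LLM(model="gpt2-large-tldr-sft")
params = SamplingParams(n=2, temperature=0.9, top_p=0.95, max_tokens=64)  # two candidates per post

posts = load_from_disk("tldr_unique_posts")
prompts = [f"SUBREDDIT: r/{ex['info']['subreddit']}\nPOST: {ex['info']['post']}\nTL;DR:" for ex in posts]

outputs = llm.generate(prompts, params)
candidate_pairs = [[o.text.strip() for o in out.outputs] for out in outputs]  # two summaries per prompt
```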
- Generate the preference dataset using an off-the-shelf model, e.g., Llama-3.3-70B-Instruct or GPT-4o
  - Implemented in the Python script ./rlaif/gpt2/summarization/tldr/build_ai_feedback_preference_dataset_vllm.py. See detailed running commands in the script.
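The AI labeler is shown the post together with the two candidate summaries and asked which one is better; its answer becomes the preference label. A rough sketch of such a labeling call through vLLM's chat interface (recent vLLM versions); the prompt wording and parallelism settings are assumptions, not the exact template used in the script:

```python
from vllm import LLM, SamplingParams

labeler = LLM(model="meta-llama/Llama-3.3-70B-Instruct", tensor_parallel_size=4)  # assumption: 4 GPUs
params = SamplingParams(temperature=0.0, max_tokens=4)

def build_messages(post, summary_a, summary_b):
    # Hypothetical labeling prompt; the actual template lives in build_ai_feedback_preference_dataset_vllm.py.
    user = (
        "A good summary is a short piece of text that captures the most important points of the post.\n\n"
        f"Post:\n{post}\n\nSummary A:\n{summary_a}\n\nSummary B:\n{summary_b}\n\n"
        "Which summary is better? Answer with exactly one letter: A or B."
    )
    return [{"role": "user", "content": user}]

def label_preference(post, summary_a, summary_b):
    out = labeler.chat(build_messages(post, summary_a, summary_b), params)[0]
    answer = out.outputs[0].text.strip().upper()
    return 0 if answer.startswith("A") else 1  # index of the preferred summary
```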
- Analyse position bias in the constructed preference dataset
  - Implemented in the Python script ./rlaif/gpt2/summarization/tldr/analyse_aif_positional_bias.py. See detailed running commands in the script.
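One common way to quantify position bias is to label every pair twice, once in each candidate order, and check how often the choice depends on the order. A minimal sketch of that consistency check, reusing the hypothetical `label_preference` helper from the previous sketch:

```python
# Minimal position-bias check: label each pair in both orders and measure agreement.
# Assumes label_preference(post, a, b) returns 0 if the first candidate is preferred, 1 otherwise.
def position_bias_stats(examples):
    consistent = prefer_first = 0
    for post, summ_a, summ_b in examples:
        choice_ab = label_preference(post, summ_a, summ_b)
        choice_ba = label_preference(post, summ_b, summ_a)
        # Consistent = the labeler prefers the same summary in both orders (the index flips with the order).
        if choice_ab != choice_ba:
            consistent += 1
        prefer_first += (choice_ab == 0) + (choice_ba == 0)
    n = len(examples)
    return {
        "consistency_rate": consistent / n,             # fraction of pairs with order-independent labels
        "first_position_rate": prefer_first / (2 * n),  # >0.5 indicates bias toward the first candidate
    }
```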
- Preprocess the preference data
  - Implemented in the Python script rlaif/gpt2/summarization/tldr/preprocess_preference_dataset_for_rm_training.py. See detailed running commands in the script.
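The reward-model trainer expects each example as a pair of full texts, one preferred ("chosen") and one dispreferred ("rejected"). A rough sketch of such a conversion, assuming the labeled records carry the post, the two candidate summaries, and the preferred index:

```python
from datasets import Dataset

def to_rm_format(records):
    """Convert AI-labeled records into the chosen/rejected format used for reward-model training.

    Assumption: each record is a dict with keys "post", "summaries" (list of two strings),
    and "preferred" (0 or 1), as produced by the preference-labeling step.
    """
    rows = []
    for r in records:
        prompt = f"POST: {r['post']}\nTL;DR:"
        chosen = r["summaries"][r["preferred"]]
        rejected = r["summaries"][1 - r["preferred"]]
        rows.append({"chosen": prompt + " " + chosen, "rejected": prompt + " " + rejected})
    return Dataset.from_list(rows)
```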
- Train the reward model
  - Implemented in the Python script rlaif/gpt2/summarization/tldr/train_rm.py. See detailed running commands in the script.
  - Note: Training the reward model with soft labels requires modifying the TRL library. We observed that training with soft labels yields smoother training curves than training with hard labels, although the final accuracies are similar.
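For context on the soft-label note above: with hard labels the reward model is trained with the standard pairwise loss -log sigmoid(r_chosen - r_rejected); with soft labels, the labeler's preference probability replaces the hard 0/1 target, giving a binary cross-entropy over the reward margin. A sketch of the loss that such a TRL patch would implement (names are illustrative):

```python
import torch
import torch.nn.functional as F

def pairwise_reward_loss(chosen_rewards, rejected_rewards, soft_labels=None):
    """Pairwise loss for reward-model training.

    chosen_rewards / rejected_rewards: tensors of shape (batch,) with scalar rewards.
    soft_labels: optional tensor of shape (batch,) with the AI labeler's probability that
    the "chosen" summary is better; if None, fall back to the usual hard-label loss.
    """
    delta = chosen_rewards - rejected_rewards
    if soft_labels is None:
        # Hard labels: -log sigmoid(r_chosen - r_rejected)
        return -F.logsigmoid(delta).mean()
    # Soft labels: binary cross-entropy against the labeler's preference probability.
    return -(soft_labels * F.logsigmoid(delta) + (1.0 - soft_labels) * F.logsigmoid(-delta)).mean()
```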
- Evaluate the reward model
  - Implemented in the Python script rlaif/gpt2/summarization/tldr/evaluate_rm_acc.py. See detailed running commands in the script.
  - Accuracy on the AI-annotated preference dataset: 72.9%
  - Accuracy on the human-annotated preference dataset (with a confidence threshold of 8): 61.0%
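Reward-model accuracy here is the fraction of held-out preference pairs for which the model scores the preferred summary higher than the rejected one. A minimal sketch, assuming the reward model was saved as a single-logit sequence-classification checkpoint (the path is hypothetical):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumption: the reward model checkpoint produced by the training step above.
rm = AutoModelForSequenceClassification.from_pretrained("gpt2-large-tldr-rm", num_labels=1).eval()
tok = AutoTokenizer.from_pretrained("gpt2-large-tldr-rm")
if rm.config.pad_token_id is None:
    rm.config.pad_token_id = tok.eos_token_id  # GPT-2 has no pad token by default

@torch.no_grad()
def reward(text):
    inputs = tok(text, return_tensors="pt", truncation=True, max_length=1024)
    return rm(**inputs).logits[0, 0].item()

def rm_accuracy(pairs):
    """pairs: list of (chosen_text, rejected_text) tuples from a held-out preference set."""
    correct = sum(reward(chosen) > reward(rejected) for chosen, rejected in pairs)
    return correct / len(pairs)
```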
Implemented in the Python script rlaif/gpt2/summarization/tldr/train_policy.py. See detailed running commands in the script.
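The policy is then optimized against the reward model with reinforcement learning. The condensed sketch below uses the legacy TRL PPO interface (TRL < 0.12) purely for illustration; whether train_policy.py uses this exact API, as well as all paths and hyperparameters, are assumptions.

```python
import torch
from datasets import load_from_disk
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

sft_path = "gpt2-large-tldr-sft"  # assumption: SFT checkpoint from the earlier stage
tokenizer = AutoTokenizer.from_pretrained(sft_path)
tokenizer.pad_token = tokenizer.eos_token

policy = AutoModelForCausalLMWithValueHead.from_pretrained(sft_path)
ref_policy = AutoModelForCausalLMWithValueHead.from_pretrained(sft_path)  # frozen reference for the KL penalty

# Prompts: the unique TL;DR posts, tokenized into "input_ids".
posts = load_from_disk("tldr_unique_posts")
prompt_dataset = posts.map(
    lambda ex: {"input_ids": tokenizer(f"POST: {ex['info']['post']}\nTL;DR:", truncation=True, max_length=512)["input_ids"]}
)
prompt_dataset.set_format(type="torch")

config = PPOConfig(batch_size=64, mini_batch_size=8, learning_rate=1.4e-5)
ppo_trainer = PPOTrainer(config, policy, ref_policy, tokenizer, dataset=prompt_dataset)

for batch in ppo_trainer.dataloader:
    queries = batch["input_ids"]
    responses = ppo_trainer.generate(queries, return_prompt=False, max_new_tokens=64, do_sample=True)
    texts = [tokenizer.decode(torch.cat([q, r])) for q, r in zip(queries, responses)]
    rewards = [torch.tensor(reward(t)) for t in texts]  # reward() is the hypothetical RM helper sketched above
    ppo_trainer.step(queries, responses, rewards)
```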
- Evaluate the win rate of the policy
  - Implemented in the Python script rlaif/gpt2/summarization/tldr/evaluate_win_rate.py. See detailed running commands in the script.
  - Win rate
    - RLAIF policy vs. SFT policy: 68.1%
    - RLAIF policy vs. human annotations: 47.6%
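The win rate is obtained by showing a judge model (GPT-4o) the post together with two summaries and counting how often it prefers the policy's summary; alternating the presentation order helps mitigate the judge's position bias. A minimal sketch with the OpenAI Python client; the judging prompt is an assumption, not the one used in evaluate_win_rate.py:

```python
from openai import OpenAI

client = OpenAI()  # requires OPENAI_API_KEY

def judge(post, summary_1, summary_2):
    """Ask GPT-4o which summary is better; returns 1 or 2. The prompt wording is illustrative."""
    prompt = (
        f"Post:\n{post}\n\nSummary 1:\n{summary_1}\n\nSummary 2:\n{summary_2}\n\n"
        "Which summary better captures the important points of the post? Answer with '1' or '2' only."
    )
    resp = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=2,
    )
    return 1 if resp.choices[0].message.content.strip().startswith("1") else 2

def win_rate(examples):
    """examples: list of (post, policy_summary, baseline_summary); alternates order to reduce position bias."""
    wins = 0
    for i, (post, policy_sum, baseline_sum) in enumerate(examples):
        if i % 2 == 0:
            wins += judge(post, policy_sum, baseline_sum) == 1
        else:
            wins += judge(post, baseline_sum, policy_sum) == 2
    return wins / len(examples)
```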
In this part, we fine-tune the model obtained from SFT using RLHF, i.e., with the human-annotated preference dataset.
Implemented in the Python script ./rlaif/gpt2/summarization/tldr/train_rm_rlhf.py. See detailed running commands in the script.
- Evaluate the reward model
  - Implemented in the Python script rlaif/gpt2/summarization/tldr/evaluate_rm_acc.py. See detailed running commands in the script.
  - Accuracy on the human-annotated preference dataset (with a confidence threshold of 8): 69.3%
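Regarding the confidence threshold mentioned above: the human comparisons come with an annotator-confidence field, and only comparisons at or above the threshold (8 here) are kept. A rough sketch, assuming the `extra.confidence` field of the openai/summarize_from_feedback comparisons data (field layout is an assumption):

```python
from datasets import load_dataset

# Assumption: example["extra"]["confidence"] stores the annotator's confidence (higher = more confident);
# entries without a confidence value are dropped.
comparisons = load_dataset("openai/summarize_from_feedback", "comparisons", split="validation")

def confident(example, threshold=8):
    conf = example["extra"]["confidence"]
    return conf is not None and conf >= threshold

high_confidence = comparisons.filter(confident)
print(f"kept {len(high_confidence)} of {len(comparisons)} comparisons with confidence >= 8")
```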
Implemented in the Python script ./rlaif/gpt2/summarization/tldr/train_policy_rlhf.py. See detailed running commands in the script.
- Evaluate the win rate of the policy
  - Implemented in the Python script rlaif/gpt2/summarization/tldr/evaluate_win_rate.py. See detailed running commands in the script.
  - Win rate
    - RLHF policy vs. SFT policy: 54.9% (needs further tuning)
    - RLHF policy vs. human annotations: 35.1% (needs further tuning)