Training GPT-2 with Explicit Attention Biases

We provide the code and pretrained checkpoints for the experiments in Section 5.2 on "Explicit attention biases". The code for training GPT-2 is based on the open-source nanoGPT repository.


We propose augmenting the self-attention mechanism with explicit attention biases by inserting auxiliary key and value parameters.

model_attn_bias.py contains the model definition of GPT-2 augmented with explicit attention biases.
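To build intuition for the mechanism, here is a minimal single-head sketch of attention with explicit attention biases. It is written in NumPy rather than the PyTorch used by nanoGPT. It assumes one learnable key vector and one learnable value vector (`k_bias` and `v_bias` are hypothetical names), appended as an extra key/value slot that every query can attend to in addition to the usual causal window. This illustrates the idea only and is not the code in model_attn_bias.py.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax; -inf entries receive exactly zero weight.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_with_bias(q, k, v, k_bias, v_bias):
    """Causal single-head attention with one extra learnable key/value slot.

    q, k, v: arrays of shape (T, d) for a sequence of T tokens.
    k_bias, v_bias: arrays of shape (d,), shared by all positions
    (hypothetical names for the auxiliary key and value parameters).
    Returns the (T, d) output and the (T, T + 1) attention weights.
    """
    T, d = q.shape
    scale = 1.0 / np.sqrt(d)
    scores = (q @ k.T) * scale                       # (T, T) token-to-token scores
    causal = np.tril(np.ones((T, T), dtype=bool))    # query t sees keys 0..t
    scores = np.where(causal, scores, -np.inf)
    bias_scores = (q @ k_bias)[:, None] * scale      # (T, 1) score of the bias slot
    weights = softmax(np.concatenate([scores, bias_scores], axis=1), axis=1)
    values = np.concatenate([v, v_bias[None, :]], axis=0)  # (T + 1, d)
    return weights @ values, weights


rng = np.random.default_rng(0)
T, d = 4, 8
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
k_bias, v_bias = rng.standard_normal(d), np.zeros(d)
out, w = attention_with_bias(q, k, v, k_bias, v_bias)
print(out.shape, w.shape)  # (4, 8) (4, 5)
```

In this sketch, when `v_bias` is zero, any weight a query places on the bias slot scales down its output without contributing a value. The extra slot therefore lets a head put probability mass somewhere other than a real token.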

Setup

  • data: Follow the instructions here to set up the training and validation data from OpenWebText2.

  • pretrained models: We provide checkpoints for the three GPT-2 models we trained, each trained for 50k iterations.

| model name | download path | validation perplexity |
|------------|---------------|-----------------------|
| default    | model         | 3.04                  |
| sink       | model         | 3.04                  |
| attn_bias  | model         | 3.04                  |

Note: In the config files under config/, set out_dir to the directory containing the downloaded pretrained models and data_dir to the directory containing the prepared OpenWebText2 dataset.
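nanoGPT-style config files are plain Python assignments, so the edit amounts to two lines. The excerpt below is a hypothetical sketch for a file such as config/eval_gpt2_default.py. The paths are placeholders to replace with your own locations, and any other variables in the file stay as they are.

```python
# Hypothetical excerpt of a config file such as config/eval_gpt2_default.py.
# Placeholder paths: replace them with your own download and data locations.
out_dir = "/path/to/downloaded/checkpoint"  # directory where the downloaded checkpoint was placed
data_dir = "/path/to/openwebtext2"          # directory containing the prepared train/val data
```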

Evaluate

Running the following commands will evaluate the three GPT-2 checkpoints.

CUDA_VISIBLE_DEVICES=0 python test.py config/eval_gpt2_default.py ### gpt2 default architecture
CUDA_VISIBLE_DEVICES=0 python test.py config/eval_gpt2_sink.py ### gpt2 sink token
CUDA_VISIBLE_DEVICES=0 python test.py config/eval_gpt2_attn_bias.py ### gpt2 attention biases
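When reading the validation numbers, note that perplexity is conventionally defined as the exponential of the mean per-token cross-entropy, measured in nats. The helper below is a minimal sketch of that definition. The function name is hypothetical, and it is not a reproduction of how test.py aggregates losses.

```python
import math


def perplexity(token_losses):
    """Perplexity = exp(mean per-token cross-entropy), with losses in nats."""
    if not token_losses:
        raise ValueError("need at least one token loss")
    return math.exp(sum(token_losses) / len(token_losses))


# A mean cross-entropy of ln(20) nats corresponds to a perplexity of 20.
print(round(perplexity([math.log(20.0)] * 3), 6))  # 20.0
```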

Training

Running the following commands will train the three GPT-2 models from scratch. The number of GPUs can be adjusted to train on multiple GPUs.

CUDA_VISIBLE_DEVICES=0 python train.py config/train_gpt2_default.py ### gpt2 default architecture
CUDA_VISIBLE_DEVICES=0 python train.py config/train_gpt2_sink.py ### gpt2 sink token
CUDA_VISIBLE_DEVICES=0 python train.py config/train_gpt2_attn_bias.py ### gpt2 attention biases

Analysis

We provide commands for visualizing the activation magnitudes of an intermediate feature, as well as the layerwise largest activation magnitudes:

CUDA_VISIBLE_DEVICES=0 python analyze.py config/eval_gpt2_default.py
CUDA_VISIBLE_DEVICES=0 python analyze.py config/eval_gpt2_sink.py
CUDA_VISIBLE_DEVICES=0 python analyze.py config/eval_gpt2_attn_bias.py
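The "layerwise largest activation magnitudes" statistic can be sketched as follows. The sketch assumes you already have one hidden-state array of shape (T, d) per layer, for example collected with forward hooks. The function name and the synthetic input are hypothetical, and analyze.py may gather and plot the activations differently.

```python
import numpy as np


def layerwise_top_activations(hidden_states, k=3):
    """For each layer, return the k largest |activation| values in descending order.

    hidden_states: list of arrays, one per layer, each of shape (T, d).
    """
    tops = []
    for h in hidden_states:
        flat = np.abs(h).ravel()               # magnitudes over all tokens and channels
        tops.append(np.sort(flat)[::-1][:k])   # largest k, descending
    return np.stack(tops)                      # shape (num_layers, k)


# Synthetic example: 3 layers of random features, with one outlier planted in layer 1.
rng = np.random.default_rng(0)
layers = [rng.standard_normal((16, 32)) for _ in range(3)]
layers[1][0, 5] = -500.0  # outlier activation at token 0, channel 5
top = layerwise_top_activations(layers)
print(top.shape)  # (3, 3)
print(top[1, 0])  # 500.0
```

Because magnitudes are taken before sorting, a large negative activation is reported just like a large positive one. To locate an outlier, `np.unravel_index(np.argmax(np.abs(h)), h.shape)` gives its (token, channel) position.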