-
Notifications
You must be signed in to change notification settings - Fork 295
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added code for EMNLP paper on in-context demonstration selection with…
… cross-entropy difference
- Loading branch information
1 parent
29fe35c
commit 64f7644
Showing
20 changed files
with
361 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "ced_icl/t-few"] | ||
path = ced_icl/t-few | ||
url = https://github.com/r-three/t-few.git |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# In-Context Demonstration Selection with Cross Entropy Difference | ||
|
||
[Arxiv Link to Paper](https://arxiv.org/pdf/2305.14726.pdf) | ||
|
||
Large language models (LLMs) can use in-context demonstrations to improve performance on zero-shot tasks. | ||
However, selecting the best in-context examples is challenging because model performance can vary widely depending on the selected examples. | ||
We present a cross-entropy difference (CED) method for selecting in-context demonstrations. | ||
Our method is based on the observation that the effectiveness of in-context demonstrations negatively | ||
correlates with the perplexity of the test example by a language model that was finetuned | ||
on that demonstration. | ||
We utilize parameter efficient finetuning to train small models on training data that are used for computing the cross-entropy difference between a test example and every candidate in-context demonstration. | ||
This metric is used to rank and select in-context demonstrations independently for each | ||
test input. | ||
We evaluate our method on a mix-domain dataset that combines 8 benchmarks, representing 4 text generation tasks, showing that CED for in-context demonstration selection can improve performance for a variety of LLMs. | ||
|
||
## Set up | ||
Create a conda workspace named `cdsicd`. Use the requirements.txt to download dependencies. | ||
|
||
Commands for experimental setup are in runner.sh | ||
|
||
This code base uses the [T-Few Module](https://github.com/r-three/t-few). | ||
|
||
## Citation | ||
|
||
``` | ||
@inproceedings{iter2023ced, | ||
title={In-Context Demonstration Selection with Cross Entropy Difference}, | ||
author={Dan Iter and Reid Pryzant and Ruochen Xu and Shuohang Wang and Yang Liu and Yichong Xu and Chenguang Zhu}, | ||
year={2023}, | ||
booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing Findings", | ||
publisher = "Association for Computational Linguistics", | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#!/bin/bash | ||
|
||
GPU=$1 | ||
DS=$2 | ||
START=$3 | ||
END=$4 | ||
for i in $(seq $START $END) | ||
do | ||
CUDA_VISIBLE_DEVICES=${GPU} conda run -n cdsicd bash bin/unifiedqa_ppl_ft.sh ${DS}_$i | ||
done | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#!/bin/bash | ||
|
||
GPU=$1 | ||
DS=$2 | ||
START=$3 | ||
END=$4 | ||
for i in $(seq $START $END) | ||
do | ||
CUDA_VISIBLE_DEVICES=${GPU} conda run -n cdsicd bash bin/unifiedqa_ppl_ft_clusters.sh ${DS}_$i | ||
done | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#!/bin/bash | ||
|
||
GPU=$1 | ||
DS=$2 | ||
START=$3 | ||
END=$4 | ||
for i in $(seq $START $END) | ||
do | ||
CUDA_VISIBLE_DEVICES=${GPU} conda run -n cdsicd bash bin/unifiedqa_ppl_ft_train.sh ${DS}_$i | ||
done | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/bin/bash | ||
|
||
#$AMLT_CODE_DIR can be used too | ||
|
||
GPU=$1 | ||
DS=$2 | ||
START=$3 | ||
END=$4 | ||
for i in $(seq $START $END) | ||
do | ||
CUDA_VISIBLE_DEVICES=${GPU} conda run -n cdsicd bash bin/unifiedqa_clusterft.sh $DS $i | ||
done | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#!/bin/bash | ||
|
||
GPU=$1 | ||
DS=$2 | ||
START=$3 | ||
END=$4 | ||
for i in $(seq $START $END) | ||
do | ||
CUDA_VISIBLE_DEVICES=${GPU} conda run -n cdsicd bash bin/unifiedqa_oneshot.sh $DS $i | ||
done | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#/bin/bash | ||
|
||
seed=42 | ||
CLUSTER_PATH="${CLUSTER_PATH_ENV}/cluster_assignment.csv" | ||
|
||
DATASET=$1 | ||
CLUSTER_ID="${1}_${2}" | ||
|
||
python -m src.pl_train -c t03b.json+ia3.json+unifiedqa.json+qa_clusterft.json \ | ||
-k load_weight="pretrained_checkpoints/t03b_ia3_finish.pt" \ | ||
exp_name=t03b_unifiedqa_seed${seed}_clusterft_${CLUSTER_ID} \ | ||
few_shot_random_seed=${seed} seed=${seed} few_shot=True \ | ||
eval_epoch_interval=2 max_valid_size=64 dev_qa_subset=$DATASET \ | ||
allow_skip_exp=False lr=3e-2 $ITP cluster_ft_path=$CLUSTER_PATH cluster_id=$CLUSTER_ID |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#/bin/bash | ||
|
||
seed=42 | ||
DATASET=$1 | ||
ONESHOTIDX=$2 | ||
|
||
cd ../../t-few | ||
python -m src.pl_train -c t03b.json+ia3.json+unifiedqa.json+qa_1shot.json \ | ||
-k load_weight="pretrained_checkpoints/t03b_ia3_finish.pt" \ | ||
exp_name=t03b_unifiedqa_seed${seed}_1shot_${DATASET}_${ONESHOTIDX} \ | ||
few_shot_random_seed=${seed} seed=${seed} few_shot=True \ | ||
eval_epoch_interval=2 max_valid_size=64 dev_qa_subset=$DATASET train_qa_subset=$DATASET \ | ||
allow_skip_exp=False lr=3e-2 oneshot_idx=${ONESHOTIDX} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#/bin/bash | ||
|
||
DS=$1 | ||
WEIGHTS="${WEIGHTS_DIR}/t03b_unifiedqa_seed42_1shot_${DS}/" | ||
WEIGHTS="$(echo "$WEIGHTS"chkpt_*.pt)" | ||
|
||
seed=42 | ||
|
||
cd ../ | ||
python -m src.pl_eval -c t03b.json+ia3.json+unifiedqa.json+cds_ppl.json \ | ||
-k load_weight=${WEIGHTS} \ | ||
exp_name=t03b_unifiedqa_seed${seed}_cdsppl_lg_${DS} \ | ||
few_shot_random_seed=${seed} seed=${seed} few_shot=True \ | ||
save_model=False eval_epoch_interval=1 num_shot=100 \ | ||
allow_skip_exp=False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#/bin/bash | ||
|
||
DS=$1 | ||
WEIGHTS="${WEIGHTS_DIR}/t03b_unifiedqa_seed42_clusterft_${DS}/" | ||
WEIGHTS="$(echo "$WEIGHTS"chkpt_*.pt)" | ||
|
||
seed=42 | ||
|
||
cd ../ | ||
python -m src.pl_eval -c t03b.json+ia3.json+unifiedqa.json+cds_ppl.json \ | ||
-k load_weight=${WEIGHTS} \ | ||
exp_name=t03b_unifiedqa_seed${seed}_cdsppl_clusterlg_${DS} \ | ||
few_shot_random_seed=${seed} seed=${seed} few_shot=True \ | ||
save_model=False eval_epoch_interval=1 num_shot=100 \ | ||
allow_skip_exp=False max_valid_size=4000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#/bin/bash | ||
|
||
DS=$1 | ||
WEIGHTS="${WEIGHTS_DIR}/t03b_unifiedqa_seed42_1shot_${DS}/" | ||
WEIGHTS="$(echo "$WEIGHTS"chkpt_*.pt)" | ||
|
||
seed=42 | ||
|
||
cd ../ | ||
python -m src.pl_eval -c t03b.json+ia3.json+unifiedqa.json+cds_ppl.json \ | ||
-k load_weight=${WEIGHTS} \ | ||
exp_name=t03b_unifiedqa_seed${seed}_cdsppl_train_${DS} \ | ||
few_shot_random_seed=${seed} seed=${seed} few_shot=True \ | ||
save_model=False eval_epoch_interval=1 num_shot=256 swap_dev_train=True \ | ||
allow_skip_exp=False max_valid_size=256 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"batch_size": 8, | ||
"eval_batch_size": 32, | ||
"do_generation": false, | ||
"num_shot" : 1, | ||
"max_valid_size":32, | ||
"no_loss_reduce" : true, | ||
"save_model" : false, | ||
"do_icl": false, | ||
"do_ppl": false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
{ | ||
"batch_size": 1, | ||
"eval_batch_size": 16, | ||
"do_generation": true, | ||
"num_shot" : 32, | ||
"no_loss_reduce" : false, | ||
"save_model" : true, | ||
"do_icl": false, | ||
"random_icl": false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
{ | ||
"batch_size": 8, | ||
"eval_batch_size": 16, | ||
"do_generation": true, | ||
"num_shot" : 256, | ||
"no_loss_reduce" : false, | ||
"save_model" : true, | ||
"do_icl": false, | ||
"random_icl": false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"dataset": "unified_qa", | ||
"batch_size": 8, | ||
"eval_batch_size": 16, | ||
"grad_accum_factor": 1, | ||
"num_shot": 5, | ||
"length_norm": 0, | ||
"mc_loss": 0, | ||
"unlikely_loss": 0, | ||
"dataset_path": "/data/unifiedqa", | ||
"do_generation": true, | ||
"max_valid_size": 400, | ||
"num_epochs":20, | ||
"smart_trunc":true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#!/bin/bash | ||
|
||
# All commands to run experiments. Assumes 1 GPUs on 1 node. | ||
# Supporting more nodes / different number of GPUs can be done by changing params. | ||
# Create conda env named cdsicd | ||
|
||
conda run -n cdsicd pip install -r requirements.txt | ||
|
||
# Train 1 shot in-domain models | ||
bash bin/run_unified_oneshot.sh 0 squad2 0 31 | ||
bash bin/run_unified_oneshot.sh 0 boolq 0 31 | ||
bash bin/run_unified_oneshot.sh 0 narrativeqa 0 31 | ||
bash bin/run_unified_oneshot.sh 0 naturalqa 0 31 | ||
bash bin/run_unified_oneshot.sh 0 newsqa 0 31 | ||
bash bin/run_unified_oneshot.sh 0 npboolq 0 31 | ||
bash bin/run_unified_oneshot.sh 0 openbookqa 0 31 | ||
bash bin/run_unified_oneshot.sh 0 race 0 31 | ||
|
||
|
||
# For each model, compute scores | ||
bash bin/run_unified_cds.sh 0 squad2 0 31 | ||
bash bin/run_unified_cds.sh 0 boolq 0 31 | ||
bash bin/run_unified_cds.sh 0 narrativeqa 0 31 | ||
bash bin/run_unified_cds.sh 0 naturalqa 0 31 | ||
bash bin/run_unified_cds.sh 0 newsqa 0 31 | ||
bash bin/run_unified_cds.sh 0 npboolq 0 31 | ||
bash bin/run_unified_cds.sh 0 openbookqa 0 31 | ||
bash bin/run_unified_cds.sh 0 race 0 31 | ||
|
||
# For training a PEFT model with in-domain examples, evaluate CED scores on training data instead of validation | ||
bash bin/run_unified_cds_train.sh 0 squad2 0 31 | ||
bash bin/run_unified_cds_train.sh 0 boolq 0 31 | ||
bash bin/run_unified_cds_train.sh 0 narrativeqa 0 31 | ||
bash bin/run_unified_cds_train.sh 0 naturalqa 0 31 | ||
bash bin/run_unified_cds_train.sh 0 newsqa 0 31 | ||
bash bin/run_unified_cds_train.sh 0 npboolq 0 31 | ||
bash bin/run_unified_cds_train.sh 0 openbookqa 0 31 | ||
bash bin/run_unified_cds_train.sh 0 race 0 31 | ||
|
||
# For building clustered CED ICD scorers | ||
bash bin/run_unified_clusterft.sh 0 squad2 0 31 | ||
bash bin/run_unified_clusterft.sh 0 boolq 0 31 | ||
bash bin/run_unified_clusterft.sh 0 narrativeqa 0 31 | ||
bash bin/run_unified_clusterft.sh 0 naturalqa 0 31 | ||
bash bin/run_unified_clusterft.sh 0 newsqa 0 31 | ||
bash bin/run_unified_clusterft.sh 0 npboolq 0 31 | ||
bash bin/run_unified_clusterft.sh 0 openbookqa 0 31 | ||
bash bin/run_unified_clusterft.sh 0 race 0 31 | ||
|
||
|
||
# For each model, compute scores | ||
bash bin/run_unified_cds_cluster.sh 0 squad2 0 31 | ||
bash bin/run_unified_cds_cluster.sh 0 boolq 0 31 | ||
bash bin/run_unified_cds_cluster.sh 0 narrativeqa 0 31 | ||
bash bin/run_unified_cds_cluster.sh 0 naturalqa 0 31 | ||
bash bin/run_unified_cds_cluster.sh 0 newsqa 0 31 | ||
bash bin/run_unified_cds_cluster.sh 0 npboolq 0 31 | ||
bash bin/run_unified_cds_cluster.sh 0 openbookqa 0 31 | ||
bash bin/run_unified_cds_cluster.sh 0 race 0 31 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import os | ||
import sys | ||
sys.path.insert(0,'../../t-few/') | ||
import torch | ||
import argparse | ||
from datetime import datetime | ||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | ||
from pytorch_lightning import Trainer | ||
from pytorch_lightning.loggers import TensorBoardLogger | ||
|
||
from src.data import FinetuneDataModule, get_dataset_reader, PretrainDataModule | ||
from src.models.EncoderDecoder import EncoderDecoder | ||
from src.models.modify_model import modify_transformer | ||
from src.utils.Config import Config | ||
from src.utils.util import ParseKwargs, set_seeds | ||
|
||
|
||
def get_transformer(config): | ||
tokenizer = AutoTokenizer.from_pretrained(config.origin_model) | ||
model = AutoModelForSeq2SeqLM.from_pretrained(config.origin_model, low_cpu_mem_usage=True) | ||
|
||
tokenizer.model_max_length = config.max_seq_len | ||
model = modify_transformer(model, config) | ||
return tokenizer, model | ||
|
||
|
||
def main(config): | ||
""" | ||
Trains the model | ||
:param config: | ||
:return: | ||
""" | ||
|
||
tokenizer, model = get_transformer(config) | ||
dataset_reader = get_dataset_reader(config) | ||
if config.dataset == "T0Mixture": | ||
datamodule = PretrainDataModule(config, tokenizer, dataset_reader) | ||
else: | ||
datamodule = FinetuneDataModule(config, tokenizer, dataset_reader) | ||
model = EncoderDecoder(config, tokenizer, model, dataset_reader) | ||
logger = TensorBoardLogger(config.exp_dir, name="log") | ||
|
||
trainer = Trainer( | ||
enable_checkpointing=False, | ||
gpus=torch.cuda.device_count(), | ||
#precision=config.compute_precision, | ||
amp_backend="native", | ||
strategy=config.compute_strategy if config.compute_strategy != "none" else None, | ||
logger=logger, | ||
log_every_n_steps=4, | ||
max_steps=1, | ||
min_steps=1, | ||
num_sanity_val_steps=-1 if config.eval_before_training else 0, | ||
check_val_every_n_epoch=config.eval_epoch_interval, | ||
accumulate_grad_batches=config.grad_accum_factor, | ||
gradient_clip_val=config.grad_clip_norm, | ||
) | ||
trainer.validate(model, datamodule) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("-c", "--config_files", required=True) | ||
parser.add_argument("-k", "--kwargs", nargs="*", action=ParseKwargs, default={}) | ||
args = parser.parse_args() | ||
|
||
config = Config(args.config_files, args.kwargs) | ||
print(f"Start experiment {config.exp_name}") | ||
# Setup config | ||
assert config.compute_strategy in ["none", "ddp", "deepspeed_stage_3_offload", "deepspeed_stage_3"] | ||
if config.fishmask_mode == "create": | ||
print("Detecting fishmask_mode=create, override batch_size, num_step, fishmask_path") | ||
config.batch_size = 1 | ||
config.num_steps = config.num_shot | ||
config.eval_before_training = False | ||
config.fishmask_path = None | ||
|
||
print(config.to_json()) | ||
|
||
if config.allow_skip_exp and os.path.exists(config.finish_flag_file): | ||
print(f"Skip finished experiment {config.exp_name}") | ||
else: | ||
print(f"Mark experiment {config.exp_name} as claimed") | ||
with open(config.finish_flag_file, "a+") as f: | ||
f.write(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + "\n") | ||
set_seeds(config.seed) | ||
main(config) |