Showing 84 changed files with 5,176 additions and 4 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,91 @@ | ||
# Learning to Retrieve In-Context Examples for Large Language Models | ||
|
||
## News | ||
- [Paper Release] July, 2023: [Learning to Retrieve In-Context Examples for Large Language Models](https://arxiv.org/abs/2307.07164) | ||
# llm-retriever | ||
|
||
This repository contains the code for our paper [Learning to Retrieve In-Context Examples for Large Language Models](https://arxiv.org/abs/2307.07164). | ||
|
||
Large language models (LLMs) have demonstrated their ability to learn in-context, allowing them to perform various tasks based on a few input-output examples. However, the effectiveness of in-context learning is heavily reliant on the quality of the selected examples. In this paper, we propose a novel framework to iteratively train dense retrievers that can identify high-quality in-context examples for LLMs. Our framework initially trains a reward model based on LLM feedback to evaluate the quality of candidate examples, followed by knowledge distillation to train a bi-encoder based dense retriever. Our experiments on a suite of 30 tasks demonstrate that our framework significantly enhances in-context learning performance. Furthermore, we show that our framework generalizes to tasks unseen during training. An in-depth analysis reveals that our model improves performance by retrieving examples with similar patterns, and the gains are consistent across LLMs of varying sizes. | ||
|
||
 | ||
|
||
## Prerequisites | ||
|
||
### Download Data | ||
|
||
Please run the following command to download our preprocessed data from HuggingFace Datasets. | ||
|
||
```shell | ||
bash scripts/download_data.sh | ||
``` | ||
|
||
You can browse the dataset at `https://huggingface.co/datasets/intfloat/llm-retriever-tasks`. | ||
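
The downloaded files are gzip-compressed JSON Lines. A minimal sketch for inspecting one of them locally follows; it assumes one JSON object per line and makes no assumption about field names. The helper name `iter_jsonl_gz` is hypothetical and not part of the repository.

```python
import gzip
import json
from typing import Dict, Iterator


def iter_jsonl_gz(path: str) -> Iterator[Dict]:
    """Yield one parsed JSON object per non-empty line of a gzip-compressed JSON Lines file."""
    with gzip.open(path, "rt", encoding="utf-8") as reader:
        for line in reader:
            line = line.strip()
            if line:
                yield json.loads(line)
```

For example, `next(iter_jsonl_gz("data/tasks/test.jsonl.gz")).keys()` lists the fields of the first record once the download script has been run from the repository root.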
|
||
We also provide a script at `misc/format_all_tasks.py` to convert the original data format to the one used in our codebase. | ||
|
||
### Install Dependencies | ||
|
||
```shell | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Evaluate Our Released Checkpoint | ||
|
||
Specify the output directory with `OUTPUT_DIR` and run the following command to evaluate our released checkpoint `intfloat/llm-retriever-base`. | ||
|
||
```shell | ||
OUTPUT_DIR=outputs/llm-retriever-base/ bash scripts/eval_retriever.sh intfloat/llm-retriever-base | ||
``` | ||
|
||
## Train from Our Released Reward Model Scores | ||
|
||
To reproduce our best reported results in the paper, | ||
you can use the following command to train a retriever from our released reward model scores. | ||
|
||
```shell | ||
OUTPUT_DIR=outputs/repro_llmr_it2/ bash scripts/train_kd_biencoder.sh | ||
``` | ||
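
To give a rough idea of what distilling reward model scores into a retriever can look like, here is a minimal sketch. It assumes the distillation term is a KL divergence between the softmax-normalized reward scores (teacher) and the retriever's similarity scores (student) over one query's candidates. The function names and the temperature parameter are hypothetical, and the loss used by `scripts/train_kd_biencoder.sh` may contain additional terms.

```python
import math
from typing import List


def log_softmax(scores: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax over a list of scores."""
    scaled = [s / temperature for s in scores]
    max_s = max(scaled)
    log_total = math.log(sum(math.exp(s - max_s) for s in scaled))
    return [s - max_s - log_total for s in scaled]


def kd_kl_loss(teacher_scores: List[float], student_scores: List[float]) -> float:
    """KL(teacher || student) over the candidate list of a single query."""
    log_p = log_softmax(teacher_scores)  # reward model distribution
    log_q = log_softmax(student_scores)  # retriever distribution
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
```

The loss is zero when the retriever ranks candidates with the same relative scores as the reward model and grows as the two distributions diverge.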
|
||
## Generate Training Data with LLaMA-7B | ||
|
||
To generate training data from LLM feedback, | ||
you can use the following command to generate a scoring file with LLaMA-7B. | ||
It works by scoring the top candidates from BM25 retrieval results. | ||
|
||
```shell | ||
bash scripts/gen_llm_score.sh huggyllama/llama-7b bm25_train | ||
``` | ||
|
||
It should produce a new scoring file at `data/tasks/llama-7b_bm25_train.jsonl.gz`. | ||
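
The file name in this example appears to combine the last path component of the model name with the split name. The sketch below reproduces that pattern; the naming rule is inferred from this single example rather than confirmed from the source code, and `scoring_file_path` is a hypothetical helper.

```python
import os


def scoring_file_path(llm_name_or_path: str, split: str, data_dir: str = "data/tasks") -> str:
    """Compose the expected scoring file path (naming scheme inferred from the README example)."""
    model_short_name = os.path.basename(llm_name_or_path.rstrip("/"))
    return "{}/{}_{}.jsonl.gz".format(data_dir.rstrip("/"), model_short_name, split)
```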
|
||
## Train Reward Model from Scratch | ||
|
||
After you have generated training data with LLaMA-7B, | ||
you can use the following command to train the reward model. | ||
|
||
```shell | ||
OUTPUT_DIR=outputs/repro_reward_it1/ bash scripts/train_reward.sh | ||
``` | ||
|
||
## Iterative Training | ||
|
||
Use the command below to retrieve a new set of top-k candidates with a trained retriever, | ||
and then repeat the training data generation process on these candidates for iterative training. | ||
|
||
```shell | ||
OUTPUT_DIR=outputs/search_topk/ bash scripts/search_topk.sh /path/to/trained/retriever | ||
``` | ||
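
Conceptually, top-k retrieval with a bi-encoder ranks passages by the similarity between a query embedding and each passage embedding. The following is a small sketch of that step using plain Python lists and inner-product scores. The actual script presumably relies on a vector index such as `faiss` (listed in `requirements.txt`), and the function names here are hypothetical.

```python
import heapq
from typing import List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def search_topk(query: Sequence[float], passages: List[Sequence[float]], k: int) -> List[Tuple[int, float]]:
    """Return (passage_index, score) for the k highest inner-product scores, best first."""
    scored = ((idx, dot(query, vec)) for idx, vec in enumerate(passages))
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])
```

The returned indices would then identify the candidates to be re-scored by the LLM in the next iteration.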
|
||
## Citation | ||
|
||
If you find this repository useful, please cite our paper: | ||
|
||
```bibtex | ||
@article{wang2023learning, | ||
title={Learning to Retrieve In-Context Examples for Large Language Models}, | ||
author={Wang, Liang and Yang, Nan and Wei, Furu}, | ||
journal={arXiv preprint arXiv:2307.07164}, | ||
year={2023} | ||
} | ||
``` | ||
|
||
## Issues | ||
|
||
Please feel free to open a GitHub issue if you have any questions or find any bugs. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
{ | ||
"fp16": { | ||
"enabled": "auto", | ||
"loss_scale": 0, | ||
"loss_scale_window": 1000, | ||
"initial_scale_power": 10, | ||
"hysteresis": 2, | ||
"min_loss_scale": 1 | ||
}, | ||
|
||
"optimizer": { | ||
"type": "AdamW", | ||
"params": { | ||
"lr": "auto", | ||
"betas": "auto", | ||
"eps": "auto", | ||
"weight_decay": "auto" | ||
} | ||
}, | ||
|
||
"scheduler": { | ||
"type": "WarmupDecayLR", | ||
"params": { | ||
"warmup_min_lr": "auto", | ||
"warmup_max_lr": "auto", | ||
"warmup_num_steps": "auto", | ||
"total_num_steps": "auto" | ||
} | ||
}, | ||
|
||
"zero_optimization": { | ||
"stage": 2, | ||
"allgather_partitions": true, | ||
"allgather_bucket_size": 2e8, | ||
"overlap_comm": true, | ||
"reduce_scatter": true, | ||
"reduce_bucket_size": 2e8, | ||
"contiguous_gradients": true | ||
}, | ||
|
||
"gradient_accumulation_steps": "auto", | ||
"gradient_clipping": "auto", | ||
"steps_per_print": 2000, | ||
"train_batch_size": "auto", | ||
"train_micro_batch_size_per_gpu": "auto", | ||
"wall_clock_breakdown": false | ||
} |
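
Many values in the config above are the placeholder string `"auto"`, which the Hugging Face `Trainer` DeepSpeed integration is expected to fill in from its own training arguments. As a quick aid for seeing which settings are delegated this way, here is a sketch that lists every key set to `"auto"`. The helper `find_auto_keys` is hypothetical, and the inlined JSON is only a fragment of the config.

```python
import json
from typing import Any, List


def find_auto_keys(config: Any, prefix: str = "") -> List[str]:
    """Recursively collect dotted key paths whose value is the string "auto"."""
    paths: List[str] = []
    if isinstance(config, dict):
        for key, value in config.items():
            path = "{}.{}".format(prefix, key) if prefix else key
            if value == "auto":
                paths.append(path)
            else:
                paths.extend(find_auto_keys(value, path))
    return paths


# A fragment of the config above, inlined so the sketch runs on its own.
example = json.loads(
    '{"fp16": {"enabled": "auto", "loss_scale": 0},'
    ' "zero_optimization": {"stage": 2},'
    ' "train_batch_size": "auto"}'
)
print(find_auto_keys(example))
```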
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import os | ||
import sys | ||
import json | ||
import argparse | ||
|
||
sys.path.insert(0, 'src/') | ||
|
||
from typing import List | ||
from datasets import Dataset, concatenate_datasets | ||
|
||
from utils import save_dataset | ||
from tasks import task_map, BaseTask | ||
from logger_config import logger | ||
|
||
parser = argparse.ArgumentParser(description='data preprocessing for all tasks') | ||
parser.add_argument('--output-dir', default='./data/tasks/', | ||
                    type=str, metavar='DIR', help='output directory') | ||
parser.add_argument('--template-idx', default=0, type=int, metavar='N', | ||
help='template index for the task') | ||
parser.add_argument('--max-train-examples', default=30_000, type=int, metavar='N', | ||
help='maximum number of training examples per task') | ||
|
||
args = parser.parse_args() | ||
os.makedirs(args.output_dir, exist_ok=True) | ||
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4))) | ||
|
||
|
||
def format_and_save_corpus(): | ||
    # Build one retrieval corpus by concatenating the corpora of all registered tasks | ||
    corpus_list: List[Dataset] = [] | ||
    for task_name, task_cls in task_map.cls_dic.items(): | ||
        task: BaseTask = task_cls(template_idx=args.template_idx) | ||
        logger.info('Task: {}'.format(task_name)) | ||
        task_corpus: Dataset = task.get_corpus() | ||
        if task_corpus is None: | ||
            continue | ||
|
||
        logger.info('Task: {}, corpus size: {}'.format(task_name, len(task_corpus))) | ||
        corpus_list.append(task_corpus) | ||
|
||
    corpus: Dataset = concatenate_datasets(corpus_list) | ||
    # Assign sequential string ids across the merged corpus | ||
    corpus = corpus.add_column('id', [str(i) for i in range(len(corpus))]) | ||
|
||
    out_path: str = os.path.join(args.output_dir, 'passages.jsonl.gz') | ||
    save_dataset(corpus, out_path=out_path) | ||
    logger.info('Save {} lines to {}'.format(len(corpus), out_path)) | ||
|
||
|
||
def prepare_split(split: str = 'test'): | ||
    dataset_list: List[Dataset] = [] | ||
    for task_name, task_cls in task_map.cls_dic.items(): | ||
        task: BaseTask = task_cls(template_idx=args.template_idx) | ||
        logger.info('Task: {}'.format(task_name)) | ||
        task_ds: Dataset = task.get_task_data(split=split) | ||
        if task_ds is None: | ||
            continue | ||
|
||
        logger.info('Task: {}, size: {}'.format(task_name, len(task_ds))) | ||
        if split == 'train' and len(task_ds) > args.max_train_examples: | ||
            task_ds = task_ds.shuffle().select(range(args.max_train_examples)) | ||
            logger.info('Random sample to {} examples'.format(len(task_ds))) | ||
        dataset_list.append(task_ds) | ||
|
||
    dataset: Dataset = concatenate_datasets(dataset_list) | ||
|
||
    out_path: str = os.path.join(args.output_dir, '{}.jsonl.gz'.format(split)) | ||
    save_dataset(dataset, out_path) | ||
    logger.info('Save {} examples to {}'.format(len(dataset), out_path)) | ||
|
||
|
||
def main(): | ||
    format_and_save_corpus() | ||
    for split in ['train', 'test']: | ||
        prepare_split(split) | ||
|
||
|
||
if __name__ == '__main__': | ||
    main() |
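
The per-task cap applied in `prepare_split` (shuffle, then keep the first `max_train_examples` rows) can be illustrated without the `datasets` library. The sketch below is a plain-Python analogue. The `seed` parameter is an addition for reproducibility that the script itself does not use, and `cap_examples` is a hypothetical name.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def cap_examples(examples: Sequence[T], max_examples: int, seed: Optional[int] = None) -> List[T]:
    """Return all examples if under the cap, otherwise a uniform random subset of size max_examples."""
    if len(examples) <= max_examples:
        return list(examples)
    rng = random.Random(seed)
    return rng.sample(list(examples), max_examples)
```

Passing a fixed seed makes the sampled subset repeatable across runs, which may be useful when comparing preprocessing outputs.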
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
torch>=1.12 | ||
transformers==4.28 | ||
datasets==2.3.0 | ||
deepspeed==0.8.3 | ||
tqdm | ||
rouge | ||
faiss-cpu |
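
Since several of these requirements are pinned to exact versions, it can help to check what is actually installed before training. The following is a small sketch using the standard library's `importlib.metadata`; `installed_versions` is a hypothetical helper, not part of the repository.

```python
from importlib import metadata
from typing import Dict, List, Optional


def installed_versions(packages: List[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is not installed."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions
```

Calling `installed_versions(["torch", "transformers", "datasets", "deepspeed"])` returns a dictionary that can be compared by eye against the pins above.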
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#!/usr/bin/env bash | ||
|
||
set -x | ||
set -e | ||
|
||
DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )" | ||
echo "working directory: ${DIR}" | ||
|
||
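mkdir -p "${DIR}/data/tasks/" | ||
|
||
for DATA_FILE in "train.jsonl.gz" "test.jsonl.gz" "passages.jsonl.gz" "bm25_train.jsonl.gz" "kd_bm25_train.jsonl.gz" "kd_it2_train.jsonl.gz"; do | ||
    if [ ! -e "${DIR}/data/tasks/${DATA_FILE}" ]; then | ||
        # Download to a temporary name first so that an interrupted transfer is not mistaken for a complete file | ||
        wget -O "${DIR}/data/tasks/${DATA_FILE}.tmp" "https://huggingface.co/datasets/intfloat/llm-retriever-tasks/resolve/main/${DATA_FILE}" | ||
        mv "${DIR}/data/tasks/${DATA_FILE}.tmp" "${DIR}/data/tasks/${DATA_FILE}" | ||
    fi | ||
done | ||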
|
||
echo "data downloaded successfully!" |
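
For environments without `wget`, the same skip-if-present download loop can be sketched in Python with the standard library. The file list and base URL are taken from the script above. The function names are hypothetical, and writing to a temporary name before renaming is a choice made in this sketch so that an interrupted transfer is not mistaken for a finished file.

```python
import os
import urllib.request
from typing import List

BASE_URL = "https://huggingface.co/datasets/intfloat/llm-retriever-tasks/resolve/main"
DATA_FILES = [
    "train.jsonl.gz", "test.jsonl.gz", "passages.jsonl.gz",
    "bm25_train.jsonl.gz", "kd_bm25_train.jsonl.gz", "kd_it2_train.jsonl.gz",
]


def missing_files(data_dir: str, names: List[str]) -> List[str]:
    """Return the file names that are not yet present in data_dir."""
    return [name for name in names if not os.path.exists(os.path.join(data_dir, name))]


def download_missing(data_dir: str = "data/tasks") -> None:
    """Fetch each missing file to a temporary name, then rename it into place."""
    os.makedirs(data_dir, exist_ok=True)
    for name in missing_files(data_dir, DATA_FILES):
        target = os.path.join(data_dir, name)
        urllib.request.urlretrieve("{}/{}".format(BASE_URL, name), target + ".tmp")
        os.replace(target + ".tmp", target)
```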
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#!/usr/bin/env bash | ||
|
||
set -x | ||
set -e | ||
|
||
DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )" | ||
echo "working directory: ${DIR}" | ||
|
||
MODEL_NAME_OR_PATH="random" | ||
if [[ $# -ge 1 && ! "$1" == "--"* ]]; then | ||
MODEL_NAME_OR_PATH=$1 | ||
shift | ||
fi | ||
|
||
LLM_MODEL_NAME_OR_PATH="huggyllama/llama-7b" | ||
if [[ $# -ge 1 && ! "$1" == "--"* ]]; then | ||
LLM_MODEL_NAME_OR_PATH=$1 | ||
shift | ||
fi | ||
|
||
if [ -z "$OUTPUT_DIR" ]; then | ||
OUTPUT_DIR="${MODEL_NAME_OR_PATH}" | ||
fi | ||
if [ -z "$DATA_DIR" ]; then | ||
DATA_DIR="${DIR}/data/tasks/" | ||
fi | ||
|
||
N_SHOTS=8 | ||
EVAL_TASKS=("all") | ||
|
||
PYTHONPATH=src/ python -u src/inference/generate_few_shot_prompt.py \ | ||
--model_name_or_path "${MODEL_NAME_OR_PATH}" \ | ||
--seed 1234 \ | ||
--fp16 \ | ||
--llm_eval_tasks "${EVAL_TASKS[@]}" \ | ||
--llm_eval_split test \ | ||
--llm_k_shot "${N_SHOTS}" \ | ||
--output_dir "${OUTPUT_DIR}" \ | ||
--data_dir "${DATA_DIR}" | ||
|
||
|
||
PROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l) | ||
# EleutherAI/gpt-neo-2.7B # huggyllama/llama-7b | ||
python -u -m torch.distributed.launch --nproc_per_node "${PROC_PER_NODE}" src/main_eval.py \ | ||
--model_name_or_path "${MODEL_NAME_OR_PATH}" \ | ||
--seed 1234 \ | ||
--fp16 \ | ||
--do_llm_eval \ | ||
--llm_model_name_or_path "${LLM_MODEL_NAME_OR_PATH}" \ | ||
--llm_batch_size_per_device 4 \ | ||
--llm_eval_tasks "${EVAL_TASKS[@]}" \ | ||
--llm_eval_split test \ | ||
--llm_k_shot "${N_SHOTS}" \ | ||
--llm_max_input_length 1024 \ | ||
--llm_max_decode_length 64 \ | ||
--output_dir "${OUTPUT_DIR}" \ | ||
--data_dir "${DATA_DIR}" \ | ||
--overwrite_output_dir \ | ||
--disable_tqdm True \ | ||
--report_to none "$@" |
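
The evaluation script first builds few-shot prompts (`N_SHOTS=8`) and then runs the LLM on them. As a rough illustration of the first stage, the sketch below concatenates retrieved demonstrations ahead of the test input. The separator, the ordering of demonstrations, and the function name are assumptions; the actual format is defined in `src/inference/generate_few_shot_prompt.py`, which is not shown here.

```python
from typing import List, Tuple


def build_few_shot_prompt(examples: List[Tuple[str, str]], query: str, k: int = 8) -> str:
    """Concatenate up to k (input, output) demonstrations, then the test input."""
    blocks = ["{} {}".format(inp, out) for inp, out in examples[:k]]
    blocks.append(query)
    return "\n\n".join(blocks)
```

With `k=0` the function returns the bare query, which corresponds to zero-shot prompting.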
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#!/usr/bin/env bash | ||
|
||
set -x | ||
set -e | ||
|
||
DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )" | ||
echo "working directory: ${DIR}" | ||
|
||
# EleutherAI/gpt-neo-2.7B | ||
MODEL_NAME_OR_PATH="huggyllama/llama-7b" | ||
if [[ $# -ge 1 && ! "$1" == "--"* ]]; then | ||
MODEL_NAME_OR_PATH=$1 | ||
shift | ||
fi | ||
|
||
SPLIT="bm25_train" | ||
if [[ $# -ge 1 && ! "$1" == "--"* ]]; then | ||
SPLIT=$1 | ||
shift | ||
fi | ||
|
||
if [ -z "$DATA_DIR" ]; then | ||
DATA_DIR="${DIR}/data/tasks/" | ||
fi | ||
|
||
PROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l) | ||
# For gpt-neo-2.7B, set --llm_batch_size_per_device to 16 | ||
# For llama-7b, set --llm_batch_size_per_device to 8 | ||
PYTHONPATH=src/ python -u -m torch.distributed.launch --nproc_per_node "${PROC_PER_NODE}" src/inference/gen_llm_scores.py \ | ||
--llm_model_name_or_path "${MODEL_NAME_OR_PATH}" \ | ||
--fp16 \ | ||
--search_split "${SPLIT}" \ | ||
--search_topk 32 \ | ||
--llm_batch_size_per_device 8 \ | ||
--max_train_samples 200000 \ | ||
--output_dir "/tmp/" \ | ||
--data_dir "${DATA_DIR}" \ | ||
--report_to none "$@" |
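
The script above scores the top 32 retrieved candidates per training example with an LLM. One plausible reading of "LLM feedback" is sketched below: each candidate is ranked by how likely the LLM finds the ground-truth answer when that candidate is placed in front of the query. The prompt layout, the function names, and the `log_likelihood` callable are all assumptions; the real logic lives in `src/inference/gen_llm_scores.py`, which is not shown here.

```python
from typing import Callable, List, Tuple


def rank_candidates(
    query: str,
    answer: str,
    candidates: List[str],
    log_likelihood: Callable[[str, str], float],
) -> List[Tuple[str, float]]:
    """Score each candidate by the log-likelihood of `answer` given candidate + query, best first."""
    scored = []
    for candidate in candidates:
        # The candidate acts as a one-shot demonstration placed before the query.
        prompt = "{}\n\n{}".format(candidate, query)
        scored.append((candidate, log_likelihood(prompt, answer)))
    return sorted(scored, key=lambda pair: pair[1], reverse=True)
```

In practice `log_likelihood` would wrap a forward pass of the LLM; any callable with the same signature can stand in for testing.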
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#!/usr/bin/env bash | ||
|
||
set -x | ||
set -e | ||
|
||
DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )" | ||
echo "working directory: ${DIR}" | ||
|
||
MODEL_NAME_OR_PATH="" | ||
if [[ $# -ge 1 && ! "$1" == "--"* ]]; then | ||
MODEL_NAME_OR_PATH=$1 | ||
shift | ||
fi | ||
|
||
SPLIT="bm25_train" | ||
if [[ $# -ge 1 && ! "$1" == "--"* ]]; then | ||
SPLIT=$1 | ||
shift | ||
fi | ||
|
||
if [ -z "$OUTPUT_DIR" ]; then | ||
OUTPUT_DIR="${MODEL_NAME_OR_PATH}" | ||
fi | ||
if [ -z "$DATA_DIR" ]; then | ||
DATA_DIR="${DIR}/data/tasks/" | ||
fi | ||
|
||
PROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l) | ||
PYTHONPATH=src/ python -u -m torch.distributed.launch --nproc_per_node "${PROC_PER_NODE}" src/inference/gen_reward_scores.py \ | ||
--model_name_or_path "${MODEL_NAME_OR_PATH}" \ | ||
--do_kd_gen_score \ | ||
--fp16 \ | ||
--data_dir "${DATA_DIR}" \ | ||
--kd_gen_score_split "${SPLIT}" \ | ||
--kd_gen_score_batch_size 128 \ | ||
--dataloader_num_workers 1 \ | ||
--output_dir "${OUTPUT_DIR}" \ | ||
--report_to none "$@" |
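
All of these launcher scripts share one argument convention: leading arguments that do not start with `--` override positional defaults, and everything else is forwarded unchanged through `"$@"`. The sketch below mirrors that logic in Python for illustration; `split_leading_positionals` is a hypothetical name.

```python
from typing import List, Tuple


def split_leading_positionals(argv: List[str], defaults: List[str]) -> Tuple[List[str], List[str]]:
    """Consume up to len(defaults) leading arguments that do not start with "--".

    Returns the resolved positional values and the remaining arguments,
    which the shell scripts forward unchanged via "$@".
    """
    values = list(defaults)
    rest = list(argv)
    for i in range(len(values)):
        if rest and not rest[0].startswith("--"):
            values[i] = rest.pop(0)
    return values, rest
```

For the script above, the defaults would be an empty model path and the split `bm25_train`. Passing only a model path keeps the default split, and any trailing `--flag value` pairs override the hard-coded options because they are appended last.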