update code for cc3m lora training llm
Showing 12 changed files with 2,033 additions and 0 deletions.
@@ -0,0 +1,34 @@

## 💻 How to Install
You can also refer to [llm2vec](https://github.com/McGill-NLP/llm2vec) for installation details.
```
conda create -n llm2vec python=3.10 -y
conda activate llm2vec
pip install llm2vec
pip install flash-attn --no-build-isolation
pip install deepspeed
pip install accelerate==0.34.2  # the scripts provided by llm2vec cannot be run directly with the newest accelerate
```
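After installing, a quick sanity check along these lines (a minimal sketch, not part of the repo) confirms that the key packages import and the GPUs are visible:
```
# Minimal environment sanity check (illustrative only, not part of this repo).
import torch
import accelerate
import llm2vec  # noqa: F401  -- just confirm the package imports

print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())
print("accelerate version:", accelerate.__version__)
```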
### 🔥 Training
We train all models on 8×80GB H100 GPUs.

For MNTP with CC3M:
1. Prepare CC3M as a CSV file with `short_caption` and `long_caption` columns (see the sketch after the launch command below).
2. Set the path to `cc3m.csv` in `MetaLlama3_cc3m.json`.

Then you can train with the following script:
```
cd llm2vec
HF_TOKEN=xxxx accelerate launch --config_file ./ac_zero2.yaml run_mntp.py train_configs/mntp/MetaLlama3_cc3m.json
```
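The dataset code added in this commit only reads the `short_caption` and `long_caption` columns, so a minimal sketch of producing such a `cc3m.csv` (the captions and output path below are placeholders) looks like this:
```
# Hypothetical sketch of writing cc3m.csv with the two columns the dataset code expects.
import pandas as pd

rows = [
    {"short_caption": "a dog on a beach",
     "long_caption": "A brown dog runs along a sandy beach next to the ocean at sunset."},
    # ... one row per CC3M image-caption pair ...
]
pd.DataFrame(rows, columns=["short_caption", "long_caption"]).to_csv("cc3m.csv", index=False)
```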
For supervised training with CC3M:
1. First prepare the E5 data used in [llm2vec](https://github.com/McGill-NLP/llm2vec), and also prepare the same CC3M CSV.
2. Add the LoRA weights pretrained with MNTP to `train_configs/supervised/MetaLlama3_cc3m.json` (see the sketch after the launch command below).

Then you can train with the following script:
```
HF_TOKEN=xxxx accelerate launch --config_file ./ac_zero2.yaml run_supervised.py train_configs/supervised/MetaLlama3_cc3m.json
```
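As a hedged sketch of step 2, assuming the supervised config follows llm2vec's convention of a `peft_model_name_or_path` field for the MNTP LoRA checkpoint (both the key name and the checkpoint path below are assumptions):
```
# Hedged sketch: point the supervised config at the MNTP LoRA weights.
# The key name and checkpoint path are assumptions; adjust them to the actual
# contents of train_configs/supervised/MetaLlama3_cc3m.json.
import json

cfg_path = "train_configs/supervised/MetaLlama3_cc3m.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["peft_model_name_or_path"] = "output/mntp/MetaLlama3_cc3m/checkpoint-1000"  # placeholder

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)
```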

## ❤️ Acknowledgement

This code is built on top of [llm2vec](https://github.com/McGill-NLP/llm2vec). Thanks for their nice work!
@@ -0,0 +1,24 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_config:
  dynamo_backend: INDUCTOR
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
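This is presumably the `ac_zero2.yaml` referenced by the launch commands above: a single-node Accelerate configuration that runs 8 processes with DeepSpeed ZeRO stage 2, bf16 mixed precision, and no optimizer or parameter offloading.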
@@ -0,0 +1,208 @@
import json
import random
import os

from llm2vec.dataset.dataset import DataSample, TrainSample, Dataset
from accelerate.logging import get_logger
import pandas as pd

logger = get_logger(__name__, log_level="INFO")

E5_EMBEDDING_PROMPTS = {
    "allnli": [
        "Given a premise, retrieve a hypothesis that is entailed by the premise",
        "Retrieve semantically similar text",
    ],
    "dureader": "Given a Chinese search query, retrieve web passages that answer the question",
    "eli5_question_answer": "Provided a user question, retrieve the highest voted answers on Reddit ELI5 forum",
    "fever": "Given a claim, retrieve documents that support or refute the claim",
    "hotpot_qa": "Given a multi-hop question, retrieve documents that can help answer the question",
    "miracl": "Given a question, retrieve Wikipedia passages that answer the question",
    "mrtydi": "Given a question, retrieve Wikipedia passages that answer the question",
    "msmarco_passage": "Given a web search query, retrieve relevant passages that answer the query",
    "msmarco_document": "Given a web search query, retrieve relevant documents that answer the query",
    "nq": "Given a question, retrieve Wikipedia passages that answer the question",
    "quora_duplicates": [
        "Given a question, retrieve questions that are semantically equivalent to the given question",
        "Find questions that have the same meaning as the input question",
    ],
    "squad": "Retrieve Wikipedia passages that answer the question",
    "t2ranking": "Given a Chinese search query, retrieve web passages that answer the question",
    "trivia_qa": "Retrieve Wikipedia passages that answer the question",
    "cc3m": [
        "Given a sentence, retrieve a detailed relevant sentence",
        "Given a detailed sentence, retrieve a short relevant sentence",
    ],
}


class E5Data(Dataset):
    def __init__(
        self,
        dataset_name: str = "E5",
        split: str = "validation",
        file_path: str = "cache/echo-data",
        effective_batch_size: int = 32,
        shuffle_individual_datasets: bool = True,
        separator: str = "!@#$%^&*()",
        extra_cc3m: str = None,
    ):
        self.dataset_name = dataset_name
        self.split = split
        self.effective_batch_size = effective_batch_size
        self.shuffle_individual_datasets = shuffle_individual_datasets
        self.separator = separator
        self.extra_cc3m = extra_cc3m

        self.data = []
        self.load_data(file_path)

    def __len__(self):
        return len(self.data)

    def get_cc3m(self):
        cc3m = pd.read_csv(self.extra_cc3m)
        data = cc3m[['short_caption', 'long_caption']].values
        # build a list of dicts
        list_of_dict = []

        # iterate over every row
        for i in range(len(data)):
            short_caption = data[i][0]
            long_caption = data[i][1]

            # draw a random 0 or 1
            rand_num = random.randint(0, 1)

            # set query and positive according to the random number
            if rand_num == 0:
                query = short_caption
                positive = long_caption
            else:
                query = long_caption
                positive = short_caption

            # randomly pick a negative from a different row
            while True:
                rand_index = random.randint(0, len(data) - 1)
                if rand_index != i:  # make sure the negative comes from another row
                    neg = data[rand_index][1] if rand_num == 0 else data[rand_index][0]
                    break

            # build the dict and append it to the list
            list_of_dict.append({
                'query': query,
                'positive': positive,
                'negative': neg,
                'random_num': rand_num
            })
        logger.info(f"Loaded {len(list_of_dict)} samples.")
        return list_of_dict
    def load_data(self, file_path: str = None):
        logger.info(f"Loading E5 data from {file_path}...")
        # file path is actually a directory

        data_map = {}
        all_samples = []
        id_ = 0
        for dataset in E5_EMBEDDING_PROMPTS:
            logger.info(f"Loading dataset {dataset}...")
            if dataset not in data_map:
                data_map[dataset] = []
            if dataset == "cc3m":
                if self.extra_cc3m is not None:
                    dataset_samples = self.get_cc3m()
                else:
                    continue
            else:
                with open(os.path.join(file_path, f"{dataset}.jsonl"), "r") as f:
                    dataset_samples = f.readlines()
                dataset_samples = [json.loads(d) for d in dataset_samples]

            for i, sample in enumerate(dataset_samples):
                if dataset == "cc3m":
                    instruction = E5_EMBEDDING_PROMPTS[dataset][sample["random_num"]]
                else:
                    instruction = (
                        E5_EMBEDDING_PROMPTS[dataset]
                        if isinstance(E5_EMBEDDING_PROMPTS[dataset], str)
                        else E5_EMBEDDING_PROMPTS[dataset][i % 2]
                    )
                query = f"{instruction}; " + self.separator + sample["query"]
                if dataset in [
                    "allnli_split2",
                    "quora_duplicates_split1",
                    "quora_duplicates_split2",
                ]:
                    pos = (
                        f"{E5_EMBEDDING_PROMPTS[dataset]}; "
                        + self.separator
                        + sample["positive"]
                    )
                    neg = (
                        f"{E5_EMBEDDING_PROMPTS[dataset]}; "
                        + self.separator
                        + sample["negative"]
                    )
                else:
                    pos = self.separator + sample["positive"]
                    neg = self.separator + sample["negative"]

                data_map[dataset].append(id_)

                all_samples.append(
                    DataSample(
                        id_=id_,
                        query=query,
                        positive=pos,
                        negative=neg,
                        task_name=dataset,
                    )
                )
                id_ += 1

        # combine split1 and split2
        new_data_map = {}
        for dataset in data_map:
            new_dataset = dataset.replace("_split1", "").replace("_split2", "")
            if new_dataset not in new_data_map:
                new_data_map[new_dataset] = []
            new_data_map[new_dataset] += data_map[dataset]
        data_map = new_data_map

        if self.shuffle_individual_datasets:
            for task, samples in data_map.items():
                random.shuffle(samples)

        datasets = list(data_map.keys())

        logger.info(
            f"Batching Echo data properly for effective batch size of {self.effective_batch_size}..."
        )
        all_batches = []
        for dataset in datasets:
            dataset_samples = data_map[dataset]
            for i in range(0, len(dataset_samples), self.effective_batch_size):
                batch = dataset_samples[i : i + self.effective_batch_size]
                if len(batch) == self.effective_batch_size:
                    all_batches.append(batch)
                else:
                    logger.info(f"Skip 1 batch for dataset {dataset}.")
        random.shuffle(all_batches)

        final_idx_order = []
        for batch in all_batches:
            for idx in batch:
                final_idx_order.append(idx)

        self.data = [all_samples[idx] for idx in final_idx_order]
        logger.info(f"Loaded {len(self.data)} samples.")

    def __getitem__(self, index):
        sample = self.data[index]
        if self.split == "train":
            return TrainSample(
                texts=[sample.query, sample.positive, sample.negative], label=1.0
            )
        elif self.split == "validation":
            assert False, "E5Data does not have a validation split."
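As a hedged usage sketch (the paths are placeholders), the class above can be instantiated directly to check how CC3M triplets are mixed into the E5 training batches:
```
# Hypothetical usage; adjust the import to wherever E5_cc3m.py lives in the package.
from E5_cc3m import E5Data

dataset = E5Data(
    split="train",
    file_path="cache/echo-data",  # directory with the llm2vec E5 .jsonl files
    extra_cc3m="cc3m.csv",        # CSV with short_caption / long_caption columns
    effective_batch_size=32,
)
sample = dataset[0]               # TrainSample with texts = [query, positive, negative]
print(len(dataset), sample.texts[0][:80])
```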
@@ -0,0 +1,67 @@
from llm2vec.dataset.dataset import DataSample, TrainSample, Dataset
import numpy as np
from accelerate.logging import get_logger
import pandas as pd
import datasets

logger = get_logger(__name__, log_level="INFO")


class Captions(Dataset):
    def __init__(
        self,
        dataset_name: str = "captions",
        split: str = "validation",
        file_path: str = 'cc3m.csv',
        wiki1m=None,
    ):
        self.dataset_name = dataset_name
        self.split = split
        self.data = []
        if wiki1m is not None:
            self.data = wiki1m.data
        self.load_data(file_path)

    def __len__(self):
        return len(self.data)

    def load_data(self, file_path: str = None):
        logger.info(f"Loading captions data from {file_path}...")
        id_ = len(self.data)
        cc3m = pd.read_csv(file_path)
        data = cc3m[['short_caption', 'long_caption']].values

        # use NumPy to randomly pick one of the two caption strings in each row
        selected_values = np.choose(np.random.randint(2, size=len(data)), data.T)
        for line in selected_values:
            line = line.strip()
            self.data.append(
                DataSample(
                    id_=id_,
                    query=line,
                    positive=line,
                )
            )
            id_ += 1
        logger.info(f"Loaded {len(self.data)} samples.")

    def __getitem__(self, index):
        sample = self.data[index]
        if self.split == "train":
            return TrainSample(texts=[sample.query, sample.positive], label=1.0)
        elif self.split == "validation":
            assert False, "Captions does not have a validation split."
def get_cc3m_captions(file_path: str = 'cc3m.csv'):
    cc3m = pd.read_csv(file_path)
    data = cc3m[['short_caption', 'long_caption']].values
    selected_values = np.choose(np.random.randint(2, size=len(data)), data.T)
    df = pd.DataFrame(selected_values, columns=['text'])
    cc3m = datasets.Dataset.from_pandas(df)
    return cc3m


def merge_cc3m_wikiraw103(cc3m=None, wiki1m: datasets.Dataset = None):
    train_ratio = wiki1m['train'].num_rows / sum(
        [wiki1m['train'].num_rows, wiki1m['validation'].num_rows, wiki1m['test'].num_rows]
    )
    cc3m = cc3m.train_test_split(train_size=train_ratio, seed=42)
    new_train = datasets.concatenate_datasets([cc3m['train'], wiki1m['train']])
    new_val = datasets.concatenate_datasets([cc3m['test'], wiki1m['validation']])
    new_dataset = datasets.DatasetDict({'train': new_train, 'validation': new_val})
    logger.info(
        f"New dataset created with {new_dataset['train'].num_rows} training samples "
        f"and {new_dataset['validation'].num_rows} validation samples."
    )
    return new_dataset
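A hedged sketch of how these helpers might be combined for MNTP data preparation, assuming the raw-text corpus is WikiText-103 loaded via 🤗 `datasets` (the dataset identifier and module name below are assumptions, not taken from this commit):
```
# Hypothetical usage; the module name is a guess, adjust to the actual file name.
import datasets
from cc3m_captions import get_cc3m_captions, merge_cc3m_wikiraw103

wiki = datasets.load_dataset("wikitext", "wikitext-103-raw-v1")  # train/validation/test splits
cc3m = get_cc3m_captions("cc3m.csv")                             # single 'text' column dataset
mixed = merge_cc3m_wikiraw103(cc3m=cc3m, wiki1m=wiki)
print(mixed)
```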
@@ -0,0 +1,23 @@
from .E5_cc3m import E5Data


def load_dataset(dataset_name, split="validation", file_path=None, **kwargs):
    """
    Loads a dataset by name.

    Args:
        dataset_name (str): Name of the dataset to load.
        split (str): Split of the dataset to load.
        file_path (str): Path to the dataset file.
    """
    dataset_mapping = {
        "E5": E5Data,
    }

    if dataset_name not in dataset_mapping:
        raise NotImplementedError(f"Dataset name {dataset_name} not supported.")

    if split not in ["train", "validation", "test"]:
        raise NotImplementedError(f"Split {split} not supported.")

    return dataset_mapping[dataset_name](split=split, file_path=file_path, **kwargs)
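As a hedged usage sketch (paths are placeholders), the loader forwards extra keyword arguments such as `extra_cc3m` straight to `E5Data`:
```
# Hypothetical call; run where the load_dataset function above is importable.
train_set = load_dataset(
    "E5",
    split="train",
    file_path="cache/echo-data",  # directory holding the llm2vec E5 .jsonl files
    extra_cc3m="cc3m.csv",        # placeholder path to the prepared caption CSV
)
print(f"{len(train_set)} training samples")
```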