-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
61 lines (51 loc) · 2.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
import torch
from load_model import load_models
from utils import get_yaml
from metrics import compute_metrics_sentence_similarity, compute_metrics_sentence_similarity_test
from evaluate import evaluate_contrastive_model
from data.data_collators import data_collator
from data.get_data import get_datasets_test_sentence_sim
from train import train_sim_model
def get_args(argv=None):
    """Parse the command-line options understood by this script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which is what a plain ``get_args()`` call has
            always done.

    Returns:
        argparse.Namespace with three attributes:
            ``yaml_path`` (str, or None when omitted),
            ``eval`` (bool, True only when ``--eval`` is given),
            ``token`` (str, or None when omitted).
    """
    parser = argparse.ArgumentParser(description='Settings')
    parser.add_argument('--yaml_path', type=str, help='Path to yaml file settings')
    parser.add_argument('--eval', action='store_true', help='Run model in evaluation mode')
    parser.add_argument('--token', type=str, help='Huggingface token')
    # parse_args(None) falls back to sys.argv[1:], so existing callers are unaffected.
    return parser.parse_args(argv)
def main():
    """Build the run configuration, load the model, then evaluate or train.

    Steps performed:
      1. Parse the CLI options and load the YAML file named by ``--yaml_path``.
      2. Copy every key of the YAML sections ``general_args`` and
         ``training_args`` onto ``args``. The copy happens after CLI parsing,
         so a YAML key with the same name as a CLI option overwrites it.
      3. Store the selected torch device on ``args.device``.
      4. When ``args.wandb`` is truthy, export the Weights & Biases
         environment variables and start a run.
      5. Load model and tokenizer, then call ``evaluate_contrastive_model``
         when ``args.eval`` is truthy, otherwise ``train_sim_model``.

    Raises:
        KeyError: if the YAML lacks ``general_args`` or ``training_args``.
        AttributeError: if ``wandb`` (or, when it is enabled,
            ``wandb_project`` / ``wandb_name``) was not set on ``args``.
    """
    ### Set up args
    args = get_args()
    yargs = get_yaml(args.yaml_path)
    # Copy yaml config into args; training_args is applied last, so it wins
    # over general_args on duplicate keys.
    for section in ('general_args', 'training_args'):
        for key, value in yargs[section].items():
            setattr(args, key, value)
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ### If using wandb
    if args.wandb:
        # Imported lazily so wandb is only needed when logging is enabled.
        import os

        import wandb

        # Reuse a key that is already exported so unattended runs do not block
        # on stdin; prompt only when the variable is missing or empty.
        if not os.environ.get('WANDB_API_KEY'):
            # NOTE(review): input() echoes the key to the terminal.
            os.environ['WANDB_API_KEY'] = input('Wandb api key: ')
        os.environ['WANDB_PROJECT'] = args.wandb_project
        os.environ['WANDB_NAME'] = args.wandb_name
        wandb.login()
        wandb.init()

    print('\n-----Load Model-----\n')
    model, tokenizer = load_models(args)

    # Both entry points receive the raw YAML dict (yargs), not the merged args.
    if args.eval:
        evaluate_contrastive_model(
            yargs,
            tokenizer=tokenizer,
            model=model,
            compute_metrics=compute_metrics_sentence_similarity_test,
            get_dataset=get_datasets_test_sentence_sim,
            data_collator=data_collator,
            token=args.token,
        )
    else:
        train_sim_model(
            yargs,
            model,
            tokenizer,
            compute_metrics=compute_metrics_sentence_similarity,
            token=args.token,
        )
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()