forked from megagonlabs/watchog
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_finetune_gittables22_colwise_repeat.py
79 lines (72 loc) · 3.31 KB
/
run_finetune_gittables22_colwise_repeat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import subprocess
import time
import pickle
from multiprocessing import Process
from multiprocessing import Semaphore

'''Launch fine-tuning and evaluation runs of supcl_ft_colwise_repeat.py on the
GitTables / SemTab-2022 DBpedia dataset (column-wise, with repeats).

For each (max_num_col, task) combination the script assembles one shell
command and runs it to completion via subprocess.run(check=True), so a
non-zero exit from the training script raises CalledProcessError and stops
the sweep.
'''

# ---- base experiment configuration -------------------------------------
task = 'gt-semtab22-dbpedia-all0'   # alternatives seen in this repo: 'turl', 'turl-re'
ml = 128                            # max sequence length (overridden to 64 below)
bs = 16                             # batch size
base_model = 'bert-base-uncased'    # alternative: 'distilbert-base-uncased'
# Contrastive-learning checkpoint tag and the directory holding pretrained weights.
cl_tag = "wikitables/simclr/bert_None_10_32_256_5e-05_sample_row4,sample_row4_tfidf_entity_column_0.05_0_last.pt"
ckpt_path = "/data/zhihao/TU/Watchog/model/"
dropout_prob = 0.1
from_scratch = True  # True means using Huggingface's pre-trained language model's checkpoint
eval_test = True     # also evaluate on the test split (--eval_test flag)
colpair = False      # when True, adds the --colpair flag
gpus = '0'           # overridden to '1' below
max_num_col = 2
comment = "max-unlabeled@{}".format(max_num_col)

# ---- overrides for this particular sweep --------------------------------
small_tag = 'semi1'
ml = 64              # overrides the 128 set above
n_epochs = 50
gpus = '1'           # overrides the '0' set above
pool = 'v0'
rand = False
use_token_type_ids = True
# NOTE(review): str.format renders None as the literal string "None" on the
# command line — confirm supcl_ft_colwise_repeat.py treats "None" as
# "no sampling method".
sampling_method = None
ctype = "v1.2"
repeat = 5

for max_num_col in [4]:
    comment = "Repeat@{}-pool@{}-context@{}-max_num_col@{}-use_token_type@{}-sampling_method@{}".format(repeat, pool, ctype, max_num_col, use_token_type_ids, sampling_method)
    for task in ['gt-semtab22-dbpedia-all0']:
        # Boolean options become bare flags when enabled, empty strings otherwise.
        cmd = '''CUDA_VISIBLE_DEVICES={} python supcl_ft_colwise_repeat.py --wandb True \
--repeat {} --shortcut_name {} --task {} --max_length {} --max_num_col {} --context_encoding_type {} --pool_version {} --sampling_method {} --batch_size {} --use_token_type_ids {} --epoch {} \
--dropout_prob {} --pretrained_ckpt_path "{}" --cl_tag {} --small_tag "{}" --comment "{}" {} {} {}'''.format(
            gpus, repeat, base_model, task, ml, max_num_col, ctype, pool, sampling_method, bs, use_token_type_ids, n_epochs, dropout_prob,
            ckpt_path, cl_tag, small_tag, comment,
            '--colpair' if colpair else '',
            '--from_scratch' if from_scratch else '',
            '--eval_test' if eval_test else ''
        )
        # shell=True is required for the CUDA_VISIBLE_DEVICES=... env prefix;
        # every interpolated value is a constant defined above, so no
        # untrusted input reaches the shell string.
        subprocess.run(cmd, shell=True, check=True)