# train_summary_loop.py
from torch.utils.data import DataLoader, RandomSampler
from utils_dataset import SQLDataset, HDF5Dataset
import torch, os, time, argparse, numpy as np
from transformers.optimization import AdamW
from model_generator import GeneTransformer
import utils_misc, utils_tokenizer
from utils_logplot import LogPlot
from datetime import datetime
from model_coverage import KeywordCoverage
from model_guardrails import PatternPenalty, LengthPenalty, RepeatPenalty
import threading, queue
user = os.getlogin()
parser = argparse.ArgumentParser()
parser.add_argument("--experiment", type=str, required=True, help="Experiment name. Will be used to save a model file and a log file.")
parser.add_argument("--dataset_file", type=str, required=True, help="Which dataset file to use. Can be full path or the root folder will be attached.")
parser.add_argument("--train_batch_size", type=int, default=5, help="Training batch size.")
parser.add_argument("--n_epochs", type=int, default=3, help="Number of epochs to run over the data.")
parser.add_argument("--optim_every", type=int, default=4, help="Optimize every x backprops. A multiplier to the true batch size.")
parser.add_argument("--max_output_length", type=int, default=25, help="Maximum output length. Saves time if the sequences are short.")
parser.add_argument("--save_every", type=int, default=60, help="Number of seconds between any two saves.")
parser.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument("--ckpt_every", type=int, default=600, help="If 0, checkpointing is not used. Otherwise, checkpointing is done very x seconds.")
parser.add_argument("--ckpt_lookback", type=int, default=300, help="When checkpointing, will consider the avg total score of the last x samples.")
args = parser.parse_args()
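
# Example invocation (paths and names are illustrative, not from the repo):
#   python train_summary_loop.py --experiment cnndm_test --dataset_file /path/to/articles.db --fp16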
if args.device == "cuda":
freer_gpu = str(utils_misc.get_freer_gpu())
os.environ["CUDA_VISIBLE_DEVICES"] = ""+str(freer_gpu)
args.experiment += "_"+freer_gpu
models_folder = "/home/phillab/models/"
log_folder = "/home/phillab/logs/"
summarizer_model_start = os.path.join(models_folder, "gpt2_copier23.bin")
ckpt_every = args.ckpt_every
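# Convert the lookback window from a number of samples to a number of batches (ceiling division).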
ckpt_lookback = int((args.ckpt_lookback+args.train_batch_size-1)/args.train_batch_size)
total_score_history = []
best_ckpt_score = None
ckpt_file = os.path.join(models_folder, "summarizer_"+args.experiment+"_ckpt.bin")
ckpt_optimizer_file = os.path.join(models_folder, "summarizer_optimizer_"+args.experiment+"_ckpt.bin")
learning_rate = 2e-5
n_epochs = args.n_epochs
if args.device == "cuda":
print("Training on GPU "+str(freer_gpu))
bert_tokenizer = utils_tokenizer.BERTCacheTokenizer()
print("---------------")
summarizer = GeneTransformer(max_output_length=args.max_output_length, device=args.device, tokenizer_type='gpt2', starter_model=summarizer_model_start)
print("Summarizer loaded")
def collate_func(inps):
    if ".db" in args.dataset_file:
        return [a['body'] for a in inps]
    else:
        return [inp[0].decode() for inp in inps]
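
# Standard practice for transformer fine-tuning: exempt biases and LayerNorm parameters from weight decay.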
param_optimizer = list(summarizer.model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
logplot_file = os.path.join(log_folder, "summary_loop_%s.log" % (args.experiment))
logplot = LogPlot(logplot_file)
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
time_save = time.time()
time_ckpt = time.time()
if args.fp16:
    try:
        from apex import amp
    except ImportError:
        raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
    summarizer.model, optimizer = amp.initialize(summarizer.model, optimizer, opt_level="O1")  # For now O1. See details at https://nvidia.github.io/apex/amp.html
print("Loading scorers")
coverage_model_file = os.path.join(models_folder, "bert_coverage_google_cnndm_length15_1.bin")
coverage_keyword_model_file = os.path.join(models_folder, "keyword_extractor.joblib")
fluency_news_model_file = os.path.join(models_folder, "news_gpt2_bs32.bin")
scorers = [{"name": "coverage", "importance": 10.0, "sign": 1.0, "model": KeywordCoverage(args.device, keyword_model_file=coverage_keyword_model_file, model_file=coverage_model_file)},
{"name": "fluency", "importance": 2.0, "sign": 1.0, "model": GeneTransformer(max_output_length=args.max_output_length, device=args.device, starter_model=fluency_news_model_file)},
{"name": "patpen", "importance": 5.0, "sign": -1.0, "model": PatternPenalty()},
{"name": "lengthpen", "importance": 2.0, "sign": -1.0, "model": LengthPenalty(args.max_output_length)},
{"name": "reppen", "importance": 2.0, "sign": -1.0, "model": RepeatPenalty()}
]
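
# A summary's total reward is the signed, importance-weighted sum of scorer outputs: coverage and
# fluency add to the score; the pattern, length, and repetition guardrails subtract from it.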
def background_tokenizer(bodies, out_queue):
    out_queue.put([bert_tokenizer.encode(body) for body in bodies])
my_queue = queue.Queue()
print("Started training")
if ".db" in args.dataset_file:
all_dataset = SQLDataset(args.dataset_file)
else:
all_dataset = HDF5Dataset(args.dataset_file, collection_name="name")
dataset = all_dataset
print("Dataset size:", len(dataset))
dataloader = DataLoader(dataset=dataset, batch_size=args.train_batch_size, sampler=RandomSampler(dataset), drop_last=True, collate_fn=collate_func)
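
# Training loop: for each batch, sample summaries (exploration) and greedy-decode summaries
# (baseline), score both with all the scorers above, then take a self-critical policy-gradient step.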
for epi in range(n_epochs):
    print("=================== EPOCH", epi, "===================")
    for ib, documents in enumerate(dataloader):
        Timer = {}
        T1 = time.time()
        log_obj = {}
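        # Truncate each document to its first 300 whitespace-separated words to bound the input length.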
bodies = [" ".join(doc.split(" ")[:300]) for doc in documents]
# We run tokenization in the background, as it is BERT tokenization only used after the summarizer has run. Saves about 5% of time.
thread1 = threading.Thread(target=background_tokenizer, args=(bodies, my_queue))
# bodies_bert_tokenized = [bert_tokenizer.enncode(body) for body in bodies] # This is the not background version
thread1.start()
        T2 = time.time()
        Timer["preprocessing_starting"] = T2-T1

        sampled_summaries, sampled_logprobs, sampled_tokens, input_past, sampled_end_idxs = summarizer.decode_batch(bodies, max_output_length=args.max_output_length, return_logprobs=True, sample=True)
        T3 = time.time()
        Timer["generator_sampled"] = T3-T2
        with torch.no_grad():
            argmax_summaries, argmax_end_idxs = summarizer.decode_batch(bodies, max_output_length=args.max_output_length, input_past=input_past)
        T4 = time.time()
        Timer["generator_argmax"] = T4-T3

        selected_logprobs = torch.sum(sampled_logprobs, dim=1)
        batch_size, seq_length = sampled_logprobs.shape

        # Join here: the tokenization that has been running in the background should be done by now.
        thread1.join()
        bodies_bert_tokenized = my_queue.get()

        scores_track = {}
        total_sampled_scores = torch.FloatTensor([0.0] * batch_size).to(args.device)
        total_argmax_scores = torch.FloatTensor([0.0] * batch_size).to(args.device)
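
        # Each scorer rates both the sampled and the greedy (argmax) summaries; the greedy scores
        # serve as the self-critical baseline, so only samples that beat the greedy decode get rewarded.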
        for scorer in scorers:
            T = time.time()
            sampled_scores, extra = scorer['model'].score(sampled_summaries, bodies, bodies_tokenized=bodies_bert_tokenized, extra=None, lengths=sampled_end_idxs)
            sampled_scores = torch.FloatTensor(sampled_scores).to(args.device)

            argmax_scores, _ = scorer['model'].score(argmax_summaries, bodies, bodies_tokenized=bodies_bert_tokenized, extra=extra, lengths=argmax_end_idxs)
            argmax_scores = torch.FloatTensor(argmax_scores).to(args.device)
            Timer["scores_"+scorer['name']] = time.time()-T

            total_sampled_scores += (scorer['sign'])*(scorer['importance'])*sampled_scores
            total_argmax_scores += (scorer['sign'])*(scorer['importance'])*argmax_scores
            log_obj[scorer['name']+"_score"] = sampled_scores.mean().item()
            scores_track[scorer['name']+"_scores"] = sampled_scores

        T5 = time.time()
        Timer['all_scores'] = T5-T4
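
        # Self-critical policy gradient (REINFORCE with the greedy decode as baseline): minimizing
        # (R_argmax - R_sampled) * log p(sample) raises the probability of sampled summaries that
        # outscore the greedy baseline, and lowers it otherwise.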
        Loss = torch.mean((total_argmax_scores - total_sampled_scores) * selected_logprobs)

        if args.fp16:
            with amp.scale_loss(Loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            Loss.backward()
        T6 = time.time()
        Timer['backward'] = T6-T5
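
        # Accumulate gradients over optim_every batches before stepping, which multiplies the
        # effective batch size (train_batch_size * optim_every) without extra GPU memory.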
        if ib % args.optim_every == 0:
            optimizer.step()
            optimizer.zero_grad()
        T7 = time.time()
        Timer['optim'] = T7-T6

        avg_total = total_sampled_scores.mean().item()
        total_score_history.append(avg_total)
        log_obj['summary_nwords'] = int(np.mean(sampled_end_idxs))
        log_obj['loss'] = Loss.item()
        log_obj['total_score'] = avg_total
        log_obj['count'] = batch_size
        logplot.cache(log_obj, prefix="T_")

        Tfinal = time.time()
        Timer['total'] = Tfinal - T1
        # print(Timer)
        if (time.time()-time_save > args.save_every):
            print("==========================================")
            print(bodies[0])
            print("-----------")
            print(sampled_summaries[0])
            print("-----------")
            print("Total score:", total_sampled_scores[0].item())
            for scorer in scorers:
                print(scorer['name']+" score:", scores_track[scorer['name']+"_scores"][0].item())
            print("-----------")

            logplot.save(printing=True)
            time_save = time.time()
            print("==========================================")
        if ckpt_every > 0 and len(total_score_history) > ckpt_lookback:
            current_score = np.mean(total_score_history[-ckpt_lookback:])

            if time.time()-time_ckpt > ckpt_every:
                revert_ckpt = best_ckpt_score is not None and current_score < min(1.2*best_ckpt_score, 0.8*best_ckpt_score)  # min() handles both a negative and a positive best score
                print("================================== CKPT TIME, "+str(datetime.now())+" =================================")
                print("Previous best:", best_ckpt_score)
                print("Current Score:", current_score)
                print("[CKPT] Am I reverting?", ("yes" if revert_ckpt else "no! BEST CKPT"))
                if revert_ckpt:
                    summarizer.model.load_state_dict(torch.load(ckpt_file))
                    optimizer.load_state_dict(torch.load(ckpt_optimizer_file))
                time_ckpt = time.time()
                print("==============================================================================")

            if best_ckpt_score is None or current_score > best_ckpt_score:
                print("[CKPT] Saved new best at: %.3f %s" % (current_score, "["+str(datetime.now())+"]"))
                best_ckpt_score = current_score
                torch.save(summarizer.model.state_dict(), ckpt_file)
                torch.save(optimizer.state_dict(), ckpt_optimizer_file)