gen_candidate.py
from transformers import BartForConditionalGeneration, BartTokenizer, PegasusTokenizer, PegasusForConditionalGeneration
import torch
import argparse
from typing import List
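

# Generate candidate summaries with diverse beam search: BART-large-CNN for
# CNN/DM and PEGASUS for XSum. Each source article yields 16 candidates,
# written to the target file one candidate per line.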
def generate_summaries_cnndm(args):
    device = f"cuda:{args.gpuid}"
    mname = "facebook/bart-large-cnn"
    model = BartForConditionalGeneration.from_pretrained(mname).to(device)
    model.eval()
    tokenizer = BartTokenizer.from_pretrained(mname)
    max_length = 140
    min_length = 55
    count = 1
    bsz = 8
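    # Stream the source file line by line, generating candidates for one
    # batch of bsz articles at a time.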
    with open(args.src_dir) as source, open(args.tgt_dir, 'w') as fout:
        sline = source.readline().strip().lower()
        slines = [sline]
        for sline in source:
            if count % 100 == 0:
                print(count, flush=True)
            if count % bsz == 0:
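                # Diverse beam search: 16 beams split into 16 groups (one beam
                # per group), so the 16 returned sequences are mutually diverse.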
                with torch.no_grad():
                    dct = tokenizer.batch_encode_plus(slines, max_length=1024, return_tensors="pt", pad_to_max_length=True, truncation=True)
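                    # Note: batch_encode_plus / pad_to_max_length are deprecated in
                    # newer transformers; tokenizer(slines, max_length=1024,
                    # padding="max_length", truncation=True, return_tensors="pt")
                    # is the modern equivalent.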
                    summaries = model.generate(
                        input_ids=dct["input_ids"].to(device),
                        attention_mask=dct["attention_mask"].to(device),
                        num_return_sequences=16, num_beam_groups=16, diversity_penalty=1.0, num_beams=16,
                        max_length=max_length + 2,  # +2 from original because we start at step=1 and stop before max_length
                        min_length=min_length + 1,  # +1 from original because we start at step=1
                        no_repeat_ngram_size=3,
                        length_penalty=2.0,
                        early_stopping=True,
                    )
                    dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
                for hypothesis in dec:
                    hypothesis = hypothesis.replace("\n", " ")
                    fout.write(hypothesis + '\n')
                fout.flush()
                slines = []
            sline = sline.strip().lower()
            if len(sline) == 0:
                sline = " "
            slines.append(sline)
            count += 1
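        # Generate candidates for the final partial batch, if any lines remain.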
        if slines != []:
            with torch.no_grad():
                dct = tokenizer.batch_encode_plus(slines, max_length=1024, return_tensors="pt", pad_to_max_length=True, truncation=True)
                summaries = model.generate(
                    input_ids=dct["input_ids"].to(device),
                    attention_mask=dct["attention_mask"].to(device),
                    num_return_sequences=16, num_beam_groups=16, diversity_penalty=1.0, num_beams=16,
                    max_length=max_length + 2,  # +2 from original because we start at step=1 and stop before max_length
                    min_length=min_length + 1,  # +1 from original because we start at step=1
                    no_repeat_ngram_size=3,
                    length_penalty=2.0,
                    early_stopping=True,
                )
                dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
            for hypothesis in dec:
                hypothesis = hypothesis.replace("\n", " ")
                fout.write(hypothesis + '\n')
            fout.flush()


def generate_summaries_xsum(args):
    device = f"cuda:{args.gpuid}"
    mname = "google/pegasus-xsum"
    model = PegasusForConditionalGeneration.from_pretrained(mname).to(device)
    model.eval()
    tok = PegasusTokenizer.from_pretrained(mname)
    count = 1
    bsz = 2
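    # Same streaming loop as above, with PEGASUS and a smaller batch size.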
    with open(args.src_dir) as source, open(args.tgt_dir, 'w') as fout:
        sline = source.readline().strip()
        slines = [sline]
        for sline in source:
            if count % 10 == 0:
                print(count, flush=True)
            if count % bsz == 0:
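                # Diverse beam search with 128 beams in 16 groups (8 beams per
                # group); all 128 sequences are returned and thinned below.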
                with torch.no_grad():
                    batch = tok.prepare_seq2seq_batch(src_texts=slines, return_tensors="pt").to(device)
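                    # Note: prepare_seq2seq_batch is deprecated in newer transformers;
                    # tok(slines, truncation=True, padding=True, return_tensors="pt")
                    # is the modern equivalent.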
                    gen = model.generate(**batch, num_return_sequences=128, num_beam_groups=16, diversity_penalty=0.1, num_beams=128, length_penalty=0.6)
                    dec: List[str] = tok.batch_decode(gen, skip_special_tokens=True)
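                # Keep every 8th decoded sequence, thinning the 128 beams down
                # to 16 candidates per source article.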
                dec = [dec[i] for i in range(len(dec)) if i % 8 == 0]
                for hypothesis in dec:
                    fout.write(hypothesis + '\n')
                fout.flush()
                slines = []
            sline = sline.strip()
            if len(sline) == 0:
                sline = " "
            slines.append(sline)
            count += 1
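        # Generate candidates for the final partial batch, if any lines remain.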
        if slines != []:
            with torch.no_grad():
                batch = tok.prepare_seq2seq_batch(src_texts=slines, return_tensors="pt").to(device)
                gen = model.generate(**batch, num_return_sequences=128, num_beam_groups=16, diversity_penalty=0.1, num_beams=128, length_penalty=0.6)
                dec: List[str] = tok.batch_decode(gen, skip_special_tokens=True)
            dec = [dec[i] for i in range(len(dec)) if i % 8 == 0]
            for hypothesis in dec:
                fout.write(hypothesis + '\n')
            fout.flush()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument("--gpuid", type=int, default=0, help="gpu id")
    parser.add_argument("--src_dir", type=str, help="source file")
    parser.add_argument("--tgt_dir", type=str, help="target file")
    parser.add_argument("--dataset", type=str, default="cnndm", help="dataset (cnndm or xsum)")
    args = parser.parse_args()
    if args.dataset == "cnndm":
        generate_summaries_cnndm(args)
    elif args.dataset == "xsum":
        generate_summaries_xsum(args)
    else:
        raise ValueError(f"unknown dataset: {args.dataset}")
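
# Example invocation (file paths are illustrative):
#   python gen_candidate.py --gpuid 0 --src_dir test.source --tgt_dir test.out --dataset cnndm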