# bleu.py -- forked from Attn-to-FC/Attn-to-FC
import sys
import pickle
import argparse
import re
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu
from myutils import prep, drop, statusout, batch_gen, seq2sent, index2word
def fil(com):
    """Return the tokens of ``com`` that do not contain a '<' character.

    Args:
        com: iterable of string tokens.

    Returns:
        A new list holding, in their original order, the tokens that have
        no '<' anywhere in them.
    """
    # The test is a substring check, so a token is dropped if '<' occurs at
    # any position, not only at its start (presumably to strip markers such
    # as start/end tags -- confirm against the data format).
    return [w for w in com if '<' not in w]
def bleu_so_far(refs, preds):
    """Build a multi-line text report of corpus-level BLEU scores.

    Args:
        refs: passed to ``corpus_bleu`` as its first argument (the list of
            references).
        preds: passed to ``corpus_bleu`` as its second argument (the
            hypotheses); its length is reported in the header line.

    Returns:
        A string with a header line ('for N functions') followed by one line
        each for Ba, B1, B2, B3 and B4. Every score is multiplied by 100 and
        rounded to two decimals.
    """
    # Each row pairs the output template with the weights handed to
    # corpus_bleu. None means "call corpus_bleu without a weights argument"
    # (Ba); the other tuples put all weight on one position (B1..B4).
    rows = (
        ('Ba %s\n', None),
        ('B1 %s\n', (1, 0, 0, 0)),
        ('B2 %s\n', (0, 1, 0, 0)),
        ('B3 %s\n', (0, 0, 1, 0)),
        ('B4 %s\n', (0, 0, 0, 1)),
    )

    report = ['for %s functions\n' % (len(preds))]
    for template, weights in rows:
        if weights is None:
            score = corpus_bleu(refs, preds)
        else:
            score = corpus_bleu(refs, preds, weights=weights)
        report.append(template % (round(score * 100, 2)))
    return ''.join(report)
def re_0002(i):
    """Replacement callback for ``re.sub``: split camel case and blank out
    special characters.

    Args:
        i: a regex match object; only ``i.group(0)`` is read.

    Returns:
        - a single space when the matched text has at most one character
          (the matched character is replaced by a space);
        - the matched text unchanged when it is longer than one character
          and starts with a space;
        - otherwise the first two matched characters separated by a space
          (any characters past the second are not included).
    """
    tmp = i.group(0)
    if len(tmp) <= 1:
        # The original wrote ' '.format(tmp); the template has no
        # placeholder, so the result was always this literal space.
        return ' '
    if tmp.startswith(' '):
        return tmp
    return '{} {}'.format(tmp[0], tmp[1])
if __name__ == '__main__':
    # Script entry point: read a predictions file (one "<fid>\t<tokens>" line
    # per function), pair each prediction with the reference comment of the
    # same fid from <dataprep>/coms.test, and print corpus BLEU scores.
    parser = argparse.ArgumentParser(description='')
    # 'input' is a required positional argument: argparse itself exits with a
    # usage error when it is missing, so no separate None check is needed.
    parser.add_argument('input', type=str, default=None)
    parser.add_argument('--data', dest='dataprep', type=str, default='/nfs/projects/attn-to-fc/data/standard/output')
    # --outdir is accepted for command-line compatibility; nothing below reads it.
    parser.add_argument('--outdir', dest='outdir', type=str, default='/nfs/projects/attn-to-fc/data/outdir')
    parser.add_argument('--challenge', action='store_true', default=False)
    parser.add_argument('--obfuscate', action='store_true', default=False)
    parser.add_argument('--sbt', action='store_true', default=False)
    args = parser.parse_args()

    dataprep = args.dataprep
    input_file = args.input

    # Each flag overrides --data; when several are given the last one
    # checked here wins (sbt over obfuscate over challenge).
    if args.challenge:
        dataprep = '../data/challengeset/output'
    if args.obfuscate:
        dataprep = '../data/obfuscation/output'
    if args.sbt:
        dataprep = '../data/sbt/output'

    # The data directory is put on the import path so that its 'tokenizer'
    # module can be imported. The module is not referenced below --
    # NOTE(review): presumably kept for import side effects; confirm before
    # removing.
    sys.path.append(dataprep)
    import tokenizer

    prep('preparing predictions list... ')
    preds = dict()
    with open(input_file, 'r') as predicts:
        for line in predicts:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            # Split on the first tab only, so a tab inside the prediction
            # text cannot break the two-way unpacking.
            (fid, pred) = line.split('\t', 1)
            preds[int(fid)] = fil(pred.split())
    drop()

    refs = list()
    newpreds = list()
    with open('%s/coms.test' % (dataprep), 'r') as targets:
        for line in targets:
            if not line.strip():
                continue
            # Split on the first comma only: the id comes first and the
            # comment text itself may contain commas.
            (fid, com) = line.split(',', 1)
            fid = int(fid)
            com = fil(com.split())
            if not com:
                # Nothing left of the reference after filtering: skip it.
                continue
            if fid not in preds:
                # No prediction for this function: leave it out of both
                # lists so refs and newpreds stay aligned index by index.
                continue
            newpreds.append(preds[fid])
            # One-element list: a single reference per function.
            refs.append([com])

    print('final status')
    print(bleu_so_far(refs, newpreds))