-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval.py
125 lines (100 loc) · 3.51 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python
# !-*-coding:utf-8 -*-
import json
import numpy as np
import sys
sys.path.append("metric")
from metric.smooth_bleu import codenn_smooth_bleu
from metric.meteor.meteor import Meteor
from metric.rouge.rouge import Rouge
from metric.cider.cider import Cider
import warnings
import argparse
import logging
import prettytable as pt
warnings.filterwarnings('ignore')
logging.basicConfig(format='[%(asctime)s - %(levelname)s - %(name)s] %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
def calc_recall(pred, tgt):
count = 0
for prd in pred:
if prd in tgt:
count += 1
return count / len(tgt)
def calc_prec(pred, tgt):
count = 0
for prd in pred:
if prd in tgt:
count += 1
return count / len(pred)
def em_prec_recall(refs, preds):
EM = []
prec = []
recall = []
for r, p in zip(refs, preds):
EM.append(r[0] == p)
prec.append(calc_prec(p, r[0]))
recall.append(calc_recall(p, r[0]))
print("EM: ", round(np.mean(EM)*100, 2))
print("Precision: ", round(np.mean(prec)*100, 2))
print("Recall: ", round(np.mean(recall) * 100, 2))
def Commitbleus(refs, preds):
r_str_list = []
p_str_list = []
for r, p in zip(refs, preds):
if len(r[0]) == 0 or len(p) == 0:
continue
r_str_list.append([" ".join([str(token_id) for token_id in r[0]])])
p_str_list.append(" ".join([str(token_id) for token_id in p]))
try:
bleu_list = codenn_smooth_bleu(r_str_list, p_str_list)
except:
bleu_list = [0, 0, 0, 0]
codenn_bleu = bleu_list[0]
B_Norm = round(codenn_bleu, 4)
return B_Norm
def read_to_list(filename, index):
f = open(filename, 'r',encoding="utf-8")
res = []
for row in f:
if index:
(rid, text) = row.split('\t')
res.append(text.lower().split())
else:
res.append(row.lower().split())
return res
def metetor_rouge_cider(refs, preds):
refs_dict = {}
preds_dict = {}
for i in range(len(preds)):
preds_dict[i] = [" ".join(preds[i])]
refs_dict[i] = [" ".join(refs[i][0])]
score_Meteor, scores_Meteor = Meteor().compute_score(refs_dict, preds_dict)
print("Meteor: ", round(score_Meteor*100,2))
score_Rouge, scores_Rouge = Rouge().compute_score(refs_dict, preds_dict)
print("Rouge-L: ", round(score_Rouge*100,2))
score_Cider, scores_Cider = Cider().compute_score(refs_dict, preds_dict)
print("Cider: ",round(score_Cider,2))
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--prd_dir", default=None, type=str,
help="File dir to read predict msg")
parser.add_argument("--gold_dir", default=None, type=str,
help="File dir to read gold msg")
parser.add_argument("--prd_index", action='store_true',
help="Contain row id in file predict lines")
parser.add_argument("--gold_index", action='store_true',
help="Contain row id in file gold lines")
args = parser.parse_args()
refs = read_to_list(args.gold_dir, args.gold_index)
refs = [[t] for t in refs]
preds = read_to_list(args.prd_dir, args.prd_index)
bleus_score = Commitbleus(refs, preds)
print("Refs: ", len(refs))
print("Preds: ", len(preds))
print("BLEU: %.2f"%bleus_score)
metetor_rouge_cider(refs, preds)
em_prec_recall(refs, preds)
if __name__ == '__main__':
main()