forked from nlpyang/geval
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmeta_eval_summeval.py
77 lines (60 loc) · 2.47 KB
/
meta_eval_summeval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from prettytable import PrettyTable
from scipy.stats import spearmanr, pearsonr, kendalltau
import json
import re
import argparse
def calculate_correlation(pred_score, human_score, result):
    """Add the Pearson, Spearman and Kendall correlations of two score lists to a running sum.

    Args:
        pred_score: predicted scores, aligned element-wise with ``human_score``.
        human_score: reference (human) scores.
        result: accumulator dict with the keys 'pearson', 'spearman' and
            'kendalltau'. If it is empty, a fresh zeroed dict is created and the
            passed object is left untouched; otherwise it is updated in place.

    Returns:
        The accumulator with each correlation coefficient added to its key.

    Raises:
        AssertionError: if the two score lists differ in length.
    """
    assert len(pred_score) == len(human_score)
    if not len(result):
        result = dict.fromkeys(('pearson', 'spearman', 'kendalltau'), 0)
    # Each scipy routine returns (statistic, p-value); only the statistic is summed.
    for key, corr_fn in (('pearson', pearsonr), ('spearman', spearmanr), ('kendalltau', kendalltau)):
        result[key] += corr_fn(pred_score, human_score)[0]
    return result
def print_correlations(result, n):
    """Print the averaged correlations as a one-row Pearson/Spearman/Kendall table.

    Args:
        result: dict of summed coefficients keyed by 'pearson', 'spearman' and
            'kendalltau'.
        n: number of summands; each sum is divided by it. A value of 0 is
            treated as 1 so that an empty accumulation prints the raw sums.
    """
    table = PrettyTable(['Pearson', 'Spearman', 'Kendall'])
    denominator = 1 if n == 0 else n
    averaged = [round(result[key] / denominator, 4) for key in ('pearson', 'spearman', 'kendalltau')]
    table.add_row(averaged)
    print(table)
# Score at the very start of a model response: at most one leading space, then
# a run of digits/dots (e.g. "4", " 3.5"). Raw string avoids the invalid "\d"
# escape warning; compiled once instead of on every call.
_SCORE_RE = re.compile(r"^ ?([\d\.]+)")


def parse_output(output):
    """Extract the leading numeric score from one model response.

    Args:
        output: the raw response text.

    Returns:
        The score as a float when the response starts with a parseable number
        (optionally preceded by a single space). Otherwise it returns 0.

    NOTE(review): a 0 returned for an unparseable response is averaged in with
    real scores by the caller, which biases the mean downward. Confirm that this
    is intended.
    """
    matched = _SCORE_RE.search(output)
    if matched is None:
        return 0
    try:
        return float(matched.group(1))
    except ValueError:
        # The digit/dot run is not a valid float, e.g. "." or "1.2.3".
        return 0
if __name__ == '__main__':
    # Script entry point: average per-document correlations between G-Eval
    # scores and human scores for one evaluation dimension.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_fp', type=str, default='results/gpt4_rel_detailed.json')
    parser.add_argument('--dimension', type=str, default='relevance')
    args = parser.parse_args()

    # Context manager closes the file; JSON is UTF-8, so don't depend on the
    # platform's default locale encoding.
    with open(args.input_fp, encoding='utf-8') as f:
        jobj = json.load(f)

    # Per-document lists of predicted and human scores, aligned by position.
    # Each item is expected to carry "doc_id", "all_responses" (model outputs)
    # and a "scores" dict keyed by dimension name.
    pred_scores, human_scores = {}, {}
    n_skipped = 0

    print("Calculating correlation for G-Eval")
    for item in jobj:
        all_responses = item["all_responses"]
        if not all_responses:
            # Nothing to average; previously this raised ZeroDivisionError.
            n_skipped += 1
            continue
        doc_id = item["doc_id"]
        all_scores = [parse_output(x) for x in all_responses]
        # The prediction is the mean score over all sampled responses.
        pred_scores.setdefault(doc_id, []).append(sum(all_scores) / len(all_scores))
        human_scores.setdefault(doc_id, []).append(item['scores'][args.dimension])

    if n_skipped:
        print('skipped {} item(s) with no responses'.format(n_skipped))
    print('len(pred_scores): {}'.format(len(pred_scores)))
    print('len(human_scores): {}'.format(len(human_scores)))

    results = {'pearson': 0, 'spearman': 0, 'kendalltau': 0}
    d_ctr = 0
    for doc_id, pred_scores_doc in pred_scores.items():
        human_scores_doc = human_scores[doc_id]
        # Correlation is undefined when either side is constant, so such
        # documents are excluded from the average instead of adding NaN.
        if len(set(human_scores_doc)) <= 1 or len(set(pred_scores_doc)) <= 1:
            continue
        results = calculate_correlation(pred_scores_doc, human_scores_doc, results)
        d_ctr += 1
    print_correlations(results, n=d_ctr)