-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path eval.py
61 lines (40 loc) · 1.87 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def safe_div(num, denom):
    """Return ``num / denom``, or 0 when ``denom`` is not strictly positive.

    Score computations below divide counts that may be empty (no
    predictions, no gold items, zero precision+recall); this guard turns
    those cases into a score of 0 instead of raising ZeroDivisionError.
    Note that a negative denominator also yields 0.
    """
    return num / denom if denom > 0 else 0
def compute_f1(predicted, gold, matched):
    """Return ``(precision, recall, f1)`` computed from raw counts.

    Args:
        predicted: number of predicted items (precision denominator).
        gold: number of reference items (recall denominator).
        matched: number of predicted items that agree with the reference.

    Any ratio whose denominator is not positive is reported as 0
    (see ``safe_div``).
    """
    precision, recall = safe_div(matched, predicted), safe_div(matched, gold)
    # F1 is the harmonic mean of precision and recall.
    harmonic = safe_div(2 * precision * recall, precision + recall)
    return precision, recall, harmonic
def evaluate(pred_res, gold_res, role_type_idxs, unrelated_label="unrelated object"):
    """Compute micro-averaged and per-role precision/recall/F1.

    Args:
        pred_res: sequence of per-instance prediction lists; each element is a
            role index (a value of ``role_type_idxs``).
        gold_res: sequence of per-instance gold lists, aligned element-by-element
            with ``pred_res``.
        role_type_idxs: mapping from role name to role index. Must contain
            ``unrelated_label``.
        unrelated_label: name of the "no role" class. Items carrying its index
            are counted neither as predictions nor as gold items.

    Returns:
        ``(stats, (p, r, f))``. ``stats`` maps each role name (one per distinct
        index in ``role_type_idxs``, including ``unrelated_label``, whose scores
        are always 0) to ``{"p": ..., "r": ..., "f": ...}``. ``(p, r, f)`` are
        the micro-averaged scores over all non-unrelated items.

    Raises:
        KeyError: if ``unrelated_label`` is not in ``role_type_idxs``, or an
            index in the inputs is not a value of ``role_type_idxs``.
        ValueError: if ``pred_res`` and ``gold_res``, or any aligned pair of
            inner lists, differ in length.
    """
    rev_role_types = {idx: name for name, idx in role_type_idxs.items()}
    role_names = list(rev_role_types.values())
    gold_nums = dict.fromkeys(role_names, 0.0)
    pred_nums = dict.fromkeys(role_names, 0.0)
    correct_nums = dict.fromkeys(role_names, 0.0)
    total_gold_nums = total_pred_nums = total_correct_nums = 0.0
    unrelated_idx = role_type_idxs[unrelated_label]

    # Misaligned inputs would silently drop gold items (deflating the recall
    # denominator) or fail with an opaque IndexError, so reject them up front.
    if len(pred_res) != len(gold_res):
        raise ValueError(
            "pred_res and gold_res differ in length ({} vs {})".format(
                len(pred_res), len(gold_res)))
    for i, (pred_list_i, gold_list_i) in enumerate(zip(pred_res, gold_res)):
        if len(pred_list_i) != len(gold_list_i):
            raise ValueError(
                "instance {}: prediction/gold lengths differ ({} vs {})".format(
                    i, len(pred_list_i), len(gold_list_i)))
        for pred, gold in zip(pred_list_i, gold_list_i):
            if pred != unrelated_idx:
                total_pred_nums += 1
                pred_nums[rev_role_types[pred]] += 1
            if gold != unrelated_idx:
                total_gold_nums += 1
                gold_nums[rev_role_types[gold]] += 1
                # pred == gold here already implies pred is not unrelated.
                if pred == gold:
                    correct_nums[rev_role_types[gold]] += 1
                    total_correct_nums += 1

    p, r, f = compute_f1(total_pred_nums, total_gold_nums, total_correct_nums)
    stats = {}
    for name in role_names:
        pi, ri, fi = compute_f1(pred_nums[name], gold_nums[name], correct_nums[name])
        stats[name] = {"p": pi, "r": ri, "f": fi}
    return stats, (p, r, f)