-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_labels_from_annotations.py
91 lines (80 loc) · 3.09 KB
/
get_labels_from_annotations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python
__author__ = 'jesse'
import argparse
import os
import pickle
import numpy as np
import sys
from sklearn.metrics import cohen_kappa_score
from scipy.stats import mode
def _read_annotations(infile, predicates):
    """Parse the annotations CSV into per-annotator binary predicate vectors.

    The first line is a comma-separated header that must contain 'sidx' and
    'oidx' columns. Every other non-empty cell of a data row is taken as the
    name of a predicate that annotator ``sidx`` applied to ``oidx``.

    Args:
        infile: path to the annotations CSV.
        predicates: sequence of predicate names; fixes the vector order.

    Returns:
        (labels, annotators): ``labels`` maps oidx -> {sidx: list of 0/1, one
        entry per predicate in ``predicates`` order}; ``annotators`` lists each
        sidx in order of first appearance.

    Exits the program (sys.exit) if the file is empty or a row names a
    predicate that is not in ``predicates``.
    """
    with open(infile, 'r') as f:
        lines = f.readlines()
    if not lines:
        sys.exit("ERROR: empty annotations file: " + infile)
    header = lines[0].strip().split(',')
    # Look the id columns up once; raises ValueError if either is missing.
    sidx_col = header.index('sidx')
    oidx_col = header.index('oidx')
    id_cols = (sidx_col, oidx_col)
    known_preds = set(predicates)

    labels = {}
    annotators = []
    for line in lines[1:]:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # crashing on int('').
            continue
        parts = line.split(',')
        sidx = int(parts[sidx_col])
        oidx = int(parts[oidx_col])
        if sidx not in annotators:
            annotators.append(sidx)
        # Predicates are every cell except the id columns, wherever the header
        # puts them; empty cells (trailing/padding commas) carry no predicate.
        row_preds = [p for col, p in enumerate(parts)
                     if col not in id_cols and p]
        unrecognized_preds = [p for p in row_preds if p not in known_preds]
        if unrecognized_preds:
            sys.exit("ERROR: unrecognized predicates: " + str(unrecognized_preds))
        present = set(row_preds)
        labels.setdefault(oidx, {})[sidx] = [1 if p in present else 0
                                             for p in predicates]
    return labels, annotators


def _pairwise_kappas(labels, annotators):
    """Compute Cohen's kappa for each annotator pair that labeled a common oidx.

    For every pair, the predicate vectors of all oidx labeled by both
    annotators are concatenated and compared with ``cohen_kappa_score``.

    Returns:
        List of (aidx, ajdx, kappa) tuples in annotator order; pairs that
        share no labeled oidx are omitted.
    """
    kappas = []
    for idx, aidx in enumerate(annotators):
        for ajdx in annotators[idx + 1:]:
            fi = []
            fj = []
            for by_annotator in labels.values():
                if aidx in by_annotator and ajdx in by_annotator:
                    fi.extend(by_annotator[aidx])
                    fj.extend(by_annotator[ajdx])
            if fi:
                kappas.append((aidx, ajdx, cohen_kappa_score(fi, fj)))
    return kappas


def _majority_vote(votes):
    """Return the most common value in a list of 0/1 votes, ties going to 0."""
    # Ties resolve to 0, matching scipy.stats.mode's "smallest value wins"
    # rule that this script previously relied on, without depending on the
    # SciPy-version-specific shape of mode()'s return value.
    return 1 if 2 * sum(votes) > len(votes) else 0


def main():
    """Aggregate multi-annotator predicate labels into majority-vote labels.

    Reads its configuration from the module-level FLAGS_infile, FLAGS_indir
    and FLAGS_outfile globals (set by the command-line entry point). Loads the
    predicate list from ``<indir>/predicates.pickle``, parses the annotations
    CSV, prints Cohen's kappa for each annotator pair and their average,
    prints every (oidx, predicate) on which annotators disagree, and pickles
    {oidx: [0/1 majority label per predicate]} to FLAGS_outfile.
    """
    # Convert flags to local variables.
    infile = FLAGS_infile
    indir = FLAGS_indir
    outfile = FLAGS_outfile

    # Read in predicate pickle.
    # NOTE(review): pickle.load can execute arbitrary code; only point --indir
    # at trusted data.
    with open(os.path.join(indir, "predicates.pickle"), 'rb') as f:
        predicates = pickle.load(f)

    # Read in annotations.
    labels, annotators = _read_annotations(infile, predicates)

    # Report pairwise and average inter-annotator agreement.
    kappas = _pairwise_kappas(labels, annotators)
    for aidx, ajdx, ka in kappas:
        print("annotators " + str(aidx) + ", " + str(ajdx) + ": k=" + str(ka))
    if kappas:
        ks = sum(ka for _, _, ka in kappas)
        print("avg kappa: " + str(ks / float(len(kappas))))
    else:
        # Avoid ZeroDivisionError when no two annotators share an oidx.
        print("avg kappa: undefined (no annotator pair labeled a common oidx)")

    # Print disagreements and decision directions.
    v_labels = {}  # oidx -> [0/1 per predicate]
    for oidx, by_annotator in labels.items():
        decided = []
        for pidx in range(len(predicates)):
            votes = [vec[pidx] for vec in by_annotator.values()]
            m = _majority_vote(votes)
            if len(set(votes)) > 1:
                print("oidx " + str(oidx) + ", predicate " + predicates[pidx] +
                      " disagreement: " + str(votes) + " -> " + str(m))
            decided.append(m)
        v_labels[oidx] = decided

    # Write annotation pickle outfile.
    with open(outfile, 'wb') as f:
        pickle.dump(v_labels, f)
if __name__ == '__main__':
    # Command-line entry point; all three options are mandatory.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--infile', type=str, required=True,
                            help="annotations csv")
    arg_parser.add_argument('--indir', type=str, required=True,
                            help="data directory")
    arg_parser.add_argument('--outfile', type=str, required=True,
                            help="labels pickle")
    parsed = arg_parser.parse_args()
    # main() reads its configuration from module-level FLAGS_<option>
    # globals, so publish every parsed option under that name first.
    globals().update(('FLAGS_' + name, value)
                     for name, value in vars(parsed).items())
    main()