-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathhard_supervision_functions.py
154 lines (127 loc) · 5.37 KB
/
hard_supervision_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import random
from collections import defaultdict
def get_ents2reltime(kg_file):
    """Index the facts of a tab-separated KG file by (column 0, column 2).

    Every non-blank line is stripped and split on tabs, and its first five
    columns are stored as one tuple of strings. Presumably the columns are
    (head entity, relation, tail entity, start time, end time) -- confirm
    against the data files.

    Args:
        kg_file: Path of the fact file to read.

    Returns:
        A ``defaultdict(set)`` mapping ``(col0, col2)`` to the set of
        ``(col0, col1, col2, col3, col4)`` tuples sharing that pair.

    Raises:
        IndexError: If a non-blank line has fewer than five columns.
    """
    facts_di = defaultdict(set)
    # Stream the file line by line instead of materialising it twice.
    with open(kg_file, encoding='utf-8') as kg:
        for line in kg:
            fields = line.strip().split('\t')
            if fields == ['']:
                # Blank line (or empty file): nothing to index.
                continue
            fact = (fields[0], fields[1], fields[2], fields[3], fields[4])
            facts_di[(fields[0], fields[2])].add(fact)
    return facts_di
def get_ent2triplet(kg_file):
    """Index the facts of a tab-separated KG file by column 0 and column 2.

    Every non-blank line is stripped and split on tabs, and its first five
    columns are stored as one tuple of strings under both its column-0 value
    and its column-2 value. Presumably these are the head and tail entities
    of the fact -- confirm against the data files.

    Args:
        kg_file: Path of the fact file to read.

    Returns:
        A ``defaultdict(set)`` mapping each column-0 / column-2 value to the
        set of ``(col0, col1, col2, col3, col4)`` tuples it appears in.

    Raises:
        IndexError: If a non-blank line has fewer than five columns.
    """
    facts_di = defaultdict(set)
    # Stream the file line by line instead of materialising it twice.
    with open(kg_file, encoding='utf-8') as kg:
        for line in kg:
            fields = line.strip().split('\t')
            if fields == ['']:
                # Blank line (or empty file): nothing to index.
                continue
            fact = (fields[0], fields[1], fields[2], fields[3], fields[4])
            facts_di[fields[0]].add(fact)
            facts_di[fields[2]].add(fact)
    return facts_di
def get_event2time(kg_file):
    """Collect the event-occurrence facts of a tab-separated KG file.

    Only lines whose column 1 equals ``'P793'`` and whose column 2 equals
    ``'Q1190554'`` are kept (presumably the Wikidata identifiers for
    "significant event" and "occurrence" -- confirm). Each kept line is
    stored as a tuple of its first five columns under its column-0 value.

    Args:
        kg_file: Path of the fact file to read.

    Returns:
        A ``defaultdict(set)`` mapping column 0 to the set of matching
        ``(col0, col1, col2, col3, col4)`` tuples.

    Raises:
        IndexError: If a non-blank line has fewer than three columns, or a
            matching line has fewer than five.
    """
    facts_di = defaultdict(set)
    # Stream the file line by line instead of materialising it twice.
    with open(kg_file, encoding='utf-8') as kg:
        for line in kg:
            fields = line.strip().split('\t')
            if fields == ['']:
                # Blank line (or empty file): nothing to index.
                continue
            if fields[1] == 'P793' and fields[2] == 'Q1190554':
                fact = (fields[0], fields[1], fields[2], fields[3], fields[4])
                facts_di[fields[0]].add(fact)
    return facts_di
def get_ent_time2rel_ent(kg_file):
    """Index the facts of a tab-separated KG file by (entity, time) pairs.

    NOTE: a later definition with the same name in this file replaces this
    one when the module is loaded, so this flat-key variant is not the one
    callers get.

    Every non-blank line is stripped and split on tabs. Its first five
    columns are stored as one tuple of strings under four keys: column 0 and
    column 2, each paired with ``int(column 3)`` and ``int(column 4)``.

    Args:
        kg_file: Path of the fact file to read.

    Returns:
        A ``defaultdict(set)`` mapping ``(col0 or col2, int time)`` to the
        set of ``(col0, col1, col2, col3, col4)`` tuples.

    Raises:
        IndexError: If a non-blank line has fewer than five columns.
        ValueError: If column 3 or column 4 is not an integer literal.
    """
    facts_di = defaultdict(set)
    # Stream the file line by line instead of materialising it twice.
    with open(kg_file, encoding='utf-8') as kg:
        for line in kg:
            fields = line.strip().split('\t')
            if fields == ['']:
                # Blank line (or empty file): nothing to index.
                continue
            fact = (fields[0], fields[1], fields[2], fields[3], fields[4])
            start, end = int(fields[3]), int(fields[4])
            for entity in (fields[0], fields[2]):
                facts_di[(entity, start)].add(fact)
                facts_di[(entity, end)].add(fact)
    return facts_di
def get_ent_time2rel_ent(kg_file):
    """Index the facts of a tab-separated KG file by entity, then by time.

    Every non-blank line is stripped and split on tabs. Its first five
    columns are stored as one tuple of strings under column 0 and under
    column 2, each at both ``int(column 3)`` and ``int(column 4)``.

    Args:
        kg_file: Path of the fact file to read.

    Returns:
        A two-level ``defaultdict``: ``result[entity][int time]`` is the set
        of ``(col0, col1, col2, col3, col4)`` tuples. Because both levels
        are defaultdicts, merely indexing a missing key inserts an empty
        entry.

    Raises:
        IndexError: If a non-blank line has fewer than five columns.
        ValueError: If column 3 or column 4 is not an integer literal.
    """
    facts_di = defaultdict(lambda: defaultdict(set))
    # Stream the file line by line instead of materialising it twice.
    with open(kg_file, encoding='utf-8') as kg:
        for line in kg:
            fields = line.strip().split('\t')
            if fields == ['']:
                # Blank line (or empty file): nothing to index.
                continue
            fact = (fields[0], fields[1], fields[2], fields[3], fields[4])
            start, end = int(fields[3]), int(fields[4])
            for entity in (fields[0], fields[2]):
                facts_di[entity][start].add(fact)
                facts_di[entity][end].add(fact)
    return facts_di
def get_kg_facts_for_datapoint(e, e2tr, e2rt, et2re, event2time, thresh, time_delta=10):
    """Look up the KG facts relevant to one annotated question.

    The lookup depends on which keys ``e['annotation']`` contains, checked
    in this order:

    1. ``head``, ``tail`` and ``tail2``: union of ``e2rt`` for the pairs
       (head, tail), (head, tail2) and (tail, tail2).
    2. ``head`` and ``tail``: ``e2rt[(head, tail)]`` (the stored set itself,
       not a copy).
    3. ``event_head`` and ``tail``: the facts in ``event2time[event_head]``
       plus the facts of ``tail`` from ``et2re``. When the event has at
       least one fact, an arbitrary one of them defines a window and only
       tail facts with ``start >= event_start - time_delta`` and
       ``end <= event_end + time_delta`` are kept. If more than
       ``thresh - 1`` tail facts remain, that many are drawn at random.
    4. ``time``: ``et2re[ent][int(time)]`` where ``ent`` is ``head`` if
       present, otherwise ``tail``.
    5. Otherwise: ``e2tr[head]`` if ``head`` is present, else ``e2tr[tail]``.

    Args:
        e: Question record; only ``e['annotation']`` is read.
        e2tr: Mapping entity -> set of fact tuples.
        e2rt: Mapping (entity, entity) -> set of fact tuples.
        et2re: Mapping entity -> {int time -> set of fact tuples}.
        event2time: Mapping event entity -> set of fact tuples.
        thresh: Upper bound on facts returned by the event branch (event
            facts are not counted against the ``thresh - 1`` tail cap).
        time_delta: Slack added on both sides of the event window.

    Returns:
        A set of fact tuples; positions 3 and 4 of each tuple are read as
        integer-convertible times in the event branch.

    Raises:
        KeyError: If the annotation has none of the expected keys, or a
            mapping is a plain dict lacking the looked-up key.
    """
    annotation = e['annotation']
    keys = annotation.keys()

    if ('head' in keys) and ('tail' in keys) and ('tail2' in keys):
        head, tail, tail2 = annotation['head'], annotation['tail'], annotation['tail2']
        return e2rt[(head, tail)].union(e2rt[(head, tail2)].union(e2rt[(tail, tail2)]))

    if ('head' in keys) and ('tail' in keys):
        return e2rt[(annotation['head'], annotation['tail'])]

    if ('event_head' in keys) and ('tail' in keys):
        event_occ = event2time[annotation['event_head']]
        tail_facts = [f for facts in et2re[annotation['tail']].values() for f in facts]
        if event_occ:
            # Any one occurrence of the event fixes the time window.
            event = next(iter(event_occ))
            earliest = int(event[3]) - time_delta
            latest = int(event[4]) + time_delta
            tail_facts = [
                f for f in tail_facts
                if int(f[3]) >= earliest and int(f[4]) <= latest
            ]
        # Cap the tail facts; clamp at zero so thresh < 1 yields no tail
        # facts instead of a ValueError from random.sample.
        # NOTE(review): the cap is applied whether or not the event was
        # found -- confirm this matches the intended nesting.
        limit = max(thresh - 1, 0)
        if len(tail_facts) > limit:
            tail_facts = random.sample(tail_facts, limit)
        return set(list(event_occ) + tail_facts)

    if 'time' in keys:
        ent = annotation['head'] if 'head' in keys else annotation['tail']
        return et2re[ent][int(annotation['time'])]

    ent = annotation['head'] if 'head' in keys else annotation['tail']
    return e2tr[ent]
def append_time_to_question(question, facts):
    """Append the first two entries of ``facts`` to a question, in place.

    The two entries are stored as ``annotation['time1']`` and
    ``annotation['time2']``, their ``str()`` forms are appended (comma
    separated) to ``question['question']`` and to the first paraphrase, and
    the ``{time1}``/``{time2}`` placeholders are appended to the template.

    Nothing is changed when ``facts`` holds fewer than two entries; the
    template needs both values, and indexing ``facts[1]`` on a one-element
    list would otherwise raise IndexError.

    Args:
        question: Dict with ``'annotation'`` (dict), ``'paraphrases'``
            (non-empty list of str), ``'question'`` (str) and ``'template'``
            (str) entries; modified in place.
        facts: Sequence whose first two items are appended.

    Returns:
        None.
    """
    if not facts or len(facts) < 2:
        return
    first, second = facts[0], facts[1]
    suffix = ', ' + str(first) + ', ' + str(second)
    question['annotation']['time1'] = first
    question['annotation']['time2'] = second
    question['paraphrases'][0] = question['paraphrases'][0] + suffix
    question['question'] += suffix
    question['template'] = question['template'] + ', {time1}, {time2}'
def retrieve_time_for_question(d, facts, corrupt_p):
    """Store the time spans of the surviving facts in ``d['fact']``.

    For each fact, one weighted draw decides whether it is dropped
    (probability ``corrupt_p``) or kept (probability ``1 - corrupt_p``).
    Kept facts contribute ``[fact[3], fact[4]]``, in input order.
    ``d['fact']`` is always (re)assigned, to an empty list if nothing
    survives or ``facts`` is empty.

    Args:
        d: Question record, modified in place.
        facts: Iterable of fact tuples with at least five positions.
        corrupt_p: Probability of discarding each individual fact.

    Returns:
        None.
    """
    outcomes = [0, 1]
    weights = [(1 - corrupt_p), corrupt_p]
    kept_spans = []
    for fact in list(facts):
        # probability of corruption during QA: outcome 0 means "keep"
        drawn = random.choices(outcomes, weights, k=1)[0]
        if drawn != 0:
            continue
        kept_spans.append([fact[3], fact[4]])
    d['fact'] = kept_spans
def add_facts_to_data(data, corrupt_p, fuse, e2tr, e2rt, et2re, event2time, thresh=5):
    """Attach KG-derived time information to every question in ``data``.

    For each record the relevant facts are fetched with
    ``get_kg_facts_for_datapoint`` and ordered by position 3 of the fact
    tuple. They are then either appended to the question text
    (``fuse == 'att'``) or stored on the record via
    ``retrieve_time_for_question`` (any other ``fuse`` value).

    Args:
        data: Iterable of question records; each is modified in place.
        corrupt_p: Per-fact drop probability, used only when
            ``fuse != 'att'``.
        fuse: Fusion mode; ``'att'`` selects the question-appending path.
        e2tr, e2rt, et2re, event2time: Fact indexes forwarded unchanged to
            ``get_kg_facts_for_datapoint``.
        thresh: Forwarded to ``get_kg_facts_for_datapoint``.

    Returns:
        The same ``data`` object that was passed in.

    Raises:
        ValueError: If position 3 of a fact is not integer-convertible.
    """
    for d in data:
        facts = get_kg_facts_for_datapoint(d, e2tr, e2rt, et2re, event2time, thresh)
        # Order numerically: the values are read from text, and comparing
        # them as strings would put e.g. '999' after '1990'.
        facts = sorted(facts, key=lambda x: int(x[3]))
        # remove `no_time' if corrupted
        # NOTE(review): facts are tuples here, so comparing an entry with
        # the integer 9620 never filters anything -- confirm what this was
        # meant to test.
        facts = [x for x in facts if x != 9620]
        # TempoQR-att appends to the question
        if fuse == 'att':
            append_time_to_question(d, facts)
        else:
            retrieve_time_for_question(d, facts, corrupt_p)
    return data
def retrieve_times(kg_file, dataset_name, data, corrupt_p, fuse):
    """Load the KG indexes for a dataset and annotate ``data`` with times.

    Args:
        kg_file: File name inside ``data/<dataset_name>/kg/``; it may name
            a corrupted TKG.
        dataset_name: Dataset directory name under ``data/``.
        data: Question records, forwarded to ``add_facts_to_data``.
        corrupt_p: Forwarded to ``add_facts_to_data``.
        fuse: Forwarded to ``add_facts_to_data``.

    Returns:
        Whatever ``add_facts_to_data`` returns for ``data``.
    """
    kg_path = f'data/{dataset_name}/kg/' + kg_file

    # One index per combination of annotated entities/timestamps.
    by_entity = get_ent2triplet(kg_path)
    by_entity_pair = get_ents2reltime(kg_path)
    by_entity_time = get_ent_time2rel_ent(kg_path)
    by_event = get_event2time(kg_path)

    # Collect all the question-specific timestamps.
    return add_facts_to_data(
        data,
        corrupt_p,
        fuse,
        by_entity,
        by_entity_pair,
        by_entity_time,
        by_event,
    )