-
Notifications
You must be signed in to change notification settings - Fork 16
/
data_augmentation.py
executable file
·363 lines (353 loc) · 19.9 KB
/
data_augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
from utils import *
from copy import deepcopy
import random
def generate_mask(valences, adj_mat, color, real_n_vertices, node_in_focus, check_overlap_edge, new_mol):
    """Enumerate the edges that may still be attached to ``node_in_focus``.

    A vertex ``neighbor`` in ``range(real_n_vertices)`` is a candidate when it
    is not the focus node itself, its color is below 2, and
    ``check_adjacent_sparse`` reports no existing edge to it in ``adj_mat``.
    The number of admissible edge types is the smaller of the two remaining
    valences, capped at 3.

    When ``check_overlap_edge`` is set and the candidate has color 1, a trial
    bond of type ``number_to_bond[0]`` is added to ``new_mol``, the rings
    reported by ``Chem.GetSymmSSSR`` are inspected, and the bond is removed
    again.  The candidate is dropped if any two rings share more than two
    members.

    Returns:
        A pair ``(edge_type_mask, edge_mask)``: the first lists
        ``(node_in_focus, neighbor, v)`` for every admissible edge type index
        ``v``; the second lists ``(node_in_focus, neighbor)`` for every
        candidate with at least one admissible edge type.
    """
    typed_candidates = []
    pair_candidates = []
    for neighbor in range(real_n_vertices):
        # Guard clauses: self loops, finished nodes and existing edges are out.
        if neighbor == node_in_focus or color[neighbor] >= 2:
            continue
        if check_adjacent_sparse(adj_mat, node_in_focus, neighbor)[0]:
            continue

        bond_budget = min(valences[node_in_focus], valences[neighbor], 3)

        if check_overlap_edge and bond_budget > 0 and color[neighbor] == 1:
            # Tentatively close the ring, look at the ring set, then undo.
            new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[0])
            rings = Chem.GetSymmSSSR(new_mol)
            rings_overlap = any(
                len(set(rings[i]) & set(rings[j])) > 2
                for i in range(len(rings))
                for j in range(i + 1, len(rings))
            )
            new_mol.RemoveBond(int(node_in_focus), int(neighbor))
            if rings_overlap:
                continue

        for v in range(bond_budget):
            assert v < 3
            typed_candidates.append((node_in_focus, neighbor, v))
        # At least one edge type is possible, so the pair itself is allowed.
        if bond_budget > 0:
            pair_candidates.append((node_in_focus, neighbor))
    return typed_candidates, pair_candidates
def generate_label(ground_truth_graph, incremental_adj, node_in_focus, real_neighbor, real_n_vertices, params):
    """Build the positive labels for the next edge added at ``node_in_focus``.

    An edge to ``neighbor`` is positive when ``check_adjacent_sparse`` finds it
    in ``ground_truth_graph`` but not yet in ``incremental_adj``.  With
    ``params["label_one_hot"]`` set, only the edge to ``real_neighbor`` is
    kept; otherwise every such edge is kept and its type is asserted to be
    below 3.

    Returns:
        A pair ``(edge_type_label, edge_label)`` holding
        ``(node_in_focus, neighbor, edge_type)`` triples and
        ``(node_in_focus, neighbor)`` pairs respectively.
    """
    typed_labels = []
    pair_labels = []
    for candidate in range(real_n_vertices):
        in_truth, bond_type = check_adjacent_sparse(ground_truth_graph, node_in_focus, candidate)
        already_added, _ = check_adjacent_sparse(incremental_adj, node_in_focus, candidate)
        one_hot = params["label_one_hot"]

        # Only ground-truth edges that are still missing can be labels.
        if not in_truth or already_added:
            continue
        if one_hot:
            # One-hot mode: the single edge actually taken is the label.
            if candidate != real_neighbor:
                continue
        else:
            assert bond_type < 3

        typed_labels.append((node_in_focus, candidate, bond_type))
        pair_labels.append((node_in_focus, candidate))
    return typed_labels, pair_labels
def genereate_incremental_adj(last_adj, node_in_focus, neighbor, edge_type):
    """Return a deep copy of ``last_adj`` extended by one undirected edge.

    ``last_adj`` maps a node to a list of ``(other_node, edge_type)`` tuples
    and is left untouched.  The new edge is recorded in both directions in the
    copy.
    """
    extended = deepcopy(last_adj)
    # Record the edge once per endpoint so the adjacency stays symmetric.
    for endpoint, other in ((node_in_focus, neighbor), (neighbor, node_in_focus)):
        extended[endpoint].append((other, edge_type))
    return extended
def update_one_step(overlapped_edge_features, distance_to_others,node_sequence, node_in_focus, neighbor, edge_type, edge_type_masks, valences, incremental_adj_mat,
                    color, real_n_vertices, graph, edge_type_labels, local_stop, edge_masks, edge_labels, local_stop_label, params,
                    check_overlap_edge, new_mol, up_to_date_adj_mat,keep_prob):
    """Append one generation step to all per-step record lists.

    The step describes the state *before* the edge ``(node_in_focus,
    neighbor)`` is added, or a local stop when ``local_stop_label`` is true.
    When ``params["sample_transition"]`` is set, the step is dropped with
    probability ``1 - keep_prob`` and nothing is recorded.

    The lists ``node_sequence``, ``edge_type_masks``, ``edge_masks``,
    ``edge_type_labels``, ``edge_labels``, ``local_stop``,
    ``distance_to_others``, ``overlapped_edge_features`` and
    ``incremental_adj_mat`` each receive exactly one new entry.  Returns
    ``None``.
    """
    # Randomly thin out transitions if requested.
    if params["sample_transition"] and random.random() > keep_prob:
        return

    node_sequence.append(node_in_focus)

    # What could be added from the current state.
    type_mask, pair_mask = generate_mask(
        valences, up_to_date_adj_mat, color, real_n_vertices,
        node_in_focus, check_overlap_edge, new_mol)
    edge_type_masks.append(type_mask)
    edge_masks.append(pair_mask)

    # What should be added according to the ground truth (nothing on a stop).
    if local_stop_label:
        type_label, pair_label = [], []
    else:
        type_label, pair_label = generate_label(
            graph, up_to_date_adj_mat, node_in_focus, neighbor, real_n_vertices, params)
    edge_type_labels.append(type_label)
    edge_labels.append(pair_label)

    local_stop.append(local_stop_label)

    # BFS distances from the focus node, clipped at params["truncate_distance"].
    clipped = []
    for start, node, d in bfs_distance(node_in_focus, up_to_date_adj_mat):
        limit = params["truncate_distance"]
        clipped.append((start, node, limit if d > limit else d))
    distance_to_others.append(clipped)

    overlapped_edge_features.append(get_overlapped_edge_feature(pair_mask, color, new_mol))

    # Snapshot the adjacency as it is at this step.
    incremental_adj_mat.append(deepcopy(up_to_date_adj_mat))
def construct_incremental_graph(dataset, edges, max_n_vertices, real_n_vertices, node_symbol, params, is_training_data, initial_idx=0):
    """Replay a molecule edge by edge in BFS order and record every step.

    Starting from ``initial_idx`` (reset to 0 if it is not below
    ``real_n_vertices``), nodes are taken from a queue; for the node in focus
    each ground-truth edge to a neighbor whose color is below 2 is recorded
    via ``update_one_step`` and then added to the running adjacency, valences
    and ``new_mol``.  After its neighbors are handled, a local-stop step is
    recorded and the node gets color 2.  Colors: 0 not found yet, 1 in the
    queue, 2 already searched.

    If ``params["generation"]`` and ``is_training_data`` are both true the
    work is skipped and nine empty lists are returned.

    Returns:
        ``(incremental_adj_mat, distance_to_others, node_sequence,
        edge_type_masks, edge_type_labels, local_stop, edge_masks,
        edge_labels, overlapped_edge_features)``.
    """
    # Nothing to precompute when we only generate new molecules.
    if params["generation"] and is_training_data:
        return tuple([] for _ in range(9))

    if initial_idx >= real_n_vertices:
        initial_idx = 0

    # Maximum valence of every node.
    valences = get_initial_valence([np.argmax(symbol) for symbol in node_symbol], dataset)

    # Undirected graph: forward edges first, then their reversed copies.
    edges = edges + [(dst, bond, src) for src, bond, dst in edges]
    graph = defaultdict(list)
    for src, bond, dst in edges:
        graph[src].append((dst, bond))

    color = [0] * max_n_vertices
    color[initial_idx] = 1
    queue = deque([initial_idx])

    # Adjacency built so far (no edges yet) and the per-step records.
    up_to_date_adj_mat = defaultdict(list)
    incremental_adj_mat = []
    distance_to_others = []
    overlapped_edge_features = []
    node_sequence = []
    edge_type_masks = []
    edge_type_labels = []
    edge_masks = []
    edge_labels = []
    local_stop = []

    # Molecule that grows together with the adjacency.
    new_mol = Chem.rdchem.RWMol(Chem.MolFromSmiles(''))
    add_atoms(new_mol, sample_node_symbol([node_symbol], [len(node_symbol)], dataset)[0], dataset)

    # Probability of keeping a transition (used to form a binomial distribution).
    sample_transition_count = real_n_vertices + len(edges) / 2
    keep_prob = float(sample_transition_count) / (
        (real_n_vertices + len(edges) / 2) * params["bfs_path_count"])

    def record(focus, target, bond, is_stop, adjacency):
        # Single place that spells out the long positional argument list.
        update_one_step(
            overlapped_edge_features, distance_to_others, node_sequence,
            focus, target, bond, edge_type_masks, valences,
            incremental_adj_mat, color, real_n_vertices, graph,
            edge_type_labels, local_stop, edge_masks, edge_labels, is_stop,
            params, params["check_overlap_edge"], new_mol, adjacency,
            keep_prob)

    while queue:
        node_in_focus = queue.popleft()
        current_adj_list = graph[node_in_focus]
        # Random order shuffles in place, canonical order sorts a copy.
        if params["path_random_order"]:
            random.shuffle(current_adj_list)
        else:
            current_adj_list = sorted(current_adj_list)

        for neighbor, edge_type in current_adj_list:
            if color[neighbor] < 2:
                record(node_in_focus, neighbor, edge_type, False, up_to_date_adj_mat)
                # Commit the edge: adjacency, valences, molecule.
                up_to_date_adj_mat = genereate_incremental_adj(
                    up_to_date_adj_mat, node_in_focus, neighbor, edge_type)
                used = edge_type + 1
                valences[node_in_focus] -= used
                valences[neighbor] -= used
                new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[edge_type])
            # Newly discovered nodes are queued for later exploration.
            if color[neighbor] == 0:
                queue.append(neighbor)
                color[neighbor] = 1

        # Local stop: done with this node, move on or finish.
        record(node_in_focus, None, None, True, up_to_date_adj_mat)
        color[node_in_focus] = 2

    return (incremental_adj_mat, distance_to_others, node_sequence,
            edge_type_masks, edge_type_labels, local_stop, edge_masks,
            edge_labels, overlapped_edge_features)
#freq_dict = pickle.load(open("./freq_dict_zinc_250k_smarts_parallel.pkl",'rb'))
# Look up frequencies in freq_dict for the bonds that could be added next,
# based on the valences and adjacency matrix so far.
# A (node_in_focus, neighbor, edge_type) candidate is considered when the neighbor's color < 2,
# there is no edge so far between node_in_focus and neighbor, the valence constraint is satisfied,
# and node_in_focus != neighbor.
# Returns (transition_frequences, transition_frequences_edge):
#   transition_frequences      -- (neighbor, score) tuples
#   transition_frequences_edge -- (neighbor, v, frequency) tuples, one per tried edge type v
# new_mol is modified temporarily (trial bonds are added and removed again).
# NOTE(review): the nesting below was reconstructed from a copy of this file that had lost
# its indentation -- confirm the loop structure against the repository before relying on it.
def generate_frequencies(valences, adj_mat, color, real_n_vertices, node_in_focus, check_overlap_edge, new_mol, freq_dict):
    transition_frequences = []
    transition_frequences_edge = []
    for neighbor in range(real_n_vertices):
        if neighbor != node_in_focus and color[neighbor] < 2 and \
            not check_adjacent_sparse(adj_mat, node_in_focus, neighbor)[0]:
            # Number of edge types to try: the smaller remaining valence, capped at 3
            min_valence = min(valences[node_in_focus], valences[neighbor], 3)
            # Check whether two cycles have more than two overlap edges here
            # the neighbor color = 1 and there are left valences and
            # adding that edge will not cause overlap edges.
            if check_overlap_edge and min_valence > 0 and color[neighbor] == 1:
                # attempt to add the edge
                new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[0])
                # Check whether there are two cycles having more than two overlap edges
                ssr = Chem.GetSymmSSSR(new_mol)
                overlap_flag = False
                for idx1 in range(len(ssr)):
                    for idx2 in range(idx1+1, len(ssr)):
                        if len(set(ssr[idx1]) & set(ssr[idx2])) > 2:
                            overlap_flag=True
                # remove that edge
                new_mol.RemoveBond(int(node_in_focus), int(neighbor))
                if overlap_flag:
                    continue
            score = 0
            for v in range(min_valence):
                assert v < 3
                # Get transition compound at lookup frequency for each possible edge type
                new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[v])
                # Start with the environment of radius 5 around the focus atom and shrink
                # the radius until the extracted sub-molecule gives a non-empty SMARTS string
                radius = 5
                submol_smiles = ""
                while submol_smiles == "":
                    env = Chem.FindAtomEnvironmentOfRadiusN(new_mol,radius,int(node_in_focus))
                    amap={}
                    submol=Chem.PathToSubmol(new_mol,env,atomMap=amap)
                    submol_smiles = Chem.MolToSmarts(submol)
                    radius -= 1
                # Environments missing from freq_dict count as frequency 0
                if submol_smiles in freq_dict:
                    score += freq_dict[submol_smiles]
                    transition_frequences_edge.append((neighbor, v, freq_dict[submol_smiles]))
                else:
                    score += 0
                    transition_frequences_edge.append((neighbor, v, 0))
                # NOTE(review): this division comes before the RemoveBond that undoes the
                # per-iteration AddBond, so it appears to run on every iteration rather than
                # once after the loop -- confirm whether a plain average was intended.
                score /= (min_valence+1)
                # remove the trial bond again
                new_mol.RemoveBond(int(node_in_focus), int(neighbor))
            # NOTE(review): presumably one entry per candidate neighbor (after the loop) -- confirm.
            transition_frequences.append((neighbor, score))
            # there might be an edge between node in focus and neighbor
            if min_valence > 0:
                #edge_mask.append((node_in_focus, neighbor))
                continue
    return transition_frequences, transition_frequences_edge
def update_one_step_freqs(overlapped_edge_features, distance_to_others,node_sequence, node_in_focus, neighbor, edge_type, edge_type_masks, valences, incremental_adj_mat,
                          color, real_n_vertices, graph, edge_type_labels, local_stop, edge_masks, edge_labels, local_stop_label, params,
                          check_overlap_edge, new_mol, up_to_date_adj_mat,keep_prob, new_compound_frequences, new_compound_frequences_edge, freq_dict):
    """Append one generation step, including frequency lookups, to the records.

    The step describes the state *before* the edge ``(node_in_focus,
    neighbor)`` is added, or a local stop when ``local_stop_label`` is true.
    When ``params["sample_transition"]`` is set, the step is dropped with
    probability ``1 - keep_prob`` and nothing is recorded.

    Besides the lists filled by the plain variant (``node_sequence``, masks,
    labels, ``local_stop``, ``distance_to_others``,
    ``overlapped_edge_features``, ``incremental_adj_mat``), the two results of
    ``generate_frequencies`` are appended to ``new_compound_frequences`` and
    ``new_compound_frequences_edge``.  Returns ``None``.
    """
    # Randomly thin out transitions if requested.
    if params["sample_transition"] and random.random() > keep_prob:
        return

    node_sequence.append(node_in_focus)

    # What could be added from the current state.
    type_mask, pair_mask = generate_mask(
        valences, up_to_date_adj_mat, color, real_n_vertices,
        node_in_focus, check_overlap_edge, new_mol)
    edge_type_masks.append(type_mask)
    edge_masks.append(pair_mask)

    # Frequency information for the same set of candidate transitions.
    per_neighbor, per_edge = generate_frequencies(
        valences, up_to_date_adj_mat, color, real_n_vertices,
        node_in_focus, check_overlap_edge, new_mol, freq_dict)
    new_compound_frequences.append(per_neighbor)
    new_compound_frequences_edge.append(per_edge)

    # What should be added according to the ground truth (nothing on a stop).
    if local_stop_label:
        type_label, pair_label = [], []
    else:
        type_label, pair_label = generate_label(
            graph, up_to_date_adj_mat, node_in_focus, neighbor, real_n_vertices, params)
    edge_type_labels.append(type_label)
    edge_labels.append(pair_label)

    local_stop.append(local_stop_label)

    # BFS distances from the focus node, clipped at params["truncate_distance"].
    clipped = []
    for start, node, d in bfs_distance(node_in_focus, up_to_date_adj_mat):
        limit = params["truncate_distance"]
        clipped.append((start, node, limit if d > limit else d))
    distance_to_others.append(clipped)

    overlapped_edge_features.append(get_overlapped_edge_feature(pair_mask, color, new_mol))

    # Snapshot the adjacency as it is at this step.
    incremental_adj_mat.append(deepcopy(up_to_date_adj_mat))
def construct_incremental_graph_freqs(dataset, edges, max_n_vertices, real_n_vertices, node_symbol, params, is_training_data, freq_dict, initial_idx=0):
    """Replay a molecule in BFS order, recording steps and frequency lookups.

    Starting from ``initial_idx`` (reset to 0 if it is not below
    ``real_n_vertices``), nodes are taken from a queue; for the node in focus
    each ground-truth edge to a neighbor whose color is below 2 is recorded
    via ``update_one_step_freqs`` and then added to the running adjacency,
    valences and ``new_mol``.  After its neighbors are handled, a local-stop
    step is recorded and the node gets color 2.  Colors: 0 not found yet,
    1 in the queue, 2 already searched.

    If ``params["generation"]`` and ``is_training_data`` are both true the
    work is skipped and eleven empty lists are returned.

    Returns:
        ``(incremental_adj_mat, distance_to_others, node_sequence,
        edge_type_masks, edge_type_labels, local_stop, edge_masks,
        edge_labels, overlapped_edge_features, new_compound_frequences,
        new_compound_frequences_edge)``.
    """
    # Nothing to precompute when we only generate new molecules.
    if params["generation"] and is_training_data:
        return tuple([] for _ in range(11))

    if initial_idx >= real_n_vertices:
        initial_idx = 0

    # Maximum valence of every node.
    valences = get_initial_valence([np.argmax(symbol) for symbol in node_symbol], dataset)

    # Undirected graph: forward edges first, then their reversed copies.
    edges = edges + [(dst, bond, src) for src, bond, dst in edges]
    graph = defaultdict(list)
    for src, bond, dst in edges:
        graph[src].append((dst, bond))

    color = [0] * max_n_vertices
    color[initial_idx] = 1
    queue = deque([initial_idx])

    # Adjacency built so far (no edges yet) and the per-step records.
    up_to_date_adj_mat = defaultdict(list)
    incremental_adj_mat = []
    distance_to_others = []
    overlapped_edge_features = []
    node_sequence = []
    edge_type_masks = []
    edge_type_labels = []
    edge_masks = []
    edge_labels = []
    local_stop = []
    # Frequency lookups for the bonds allowed at each step.
    new_compound_frequences = []
    new_compound_frequences_edge = []

    # Molecule that grows together with the adjacency.
    new_mol = Chem.rdchem.RWMol(Chem.MolFromSmiles(''))
    add_atoms(new_mol, sample_node_symbol([node_symbol], [len(node_symbol)], dataset)[0], dataset)

    # Probability of keeping a transition (used to form a binomial distribution).
    sample_transition_count = real_n_vertices + len(edges) / 2
    keep_prob = float(sample_transition_count) / (
        (real_n_vertices + len(edges) / 2) * params["bfs_path_count"])

    def record(focus, target, bond, is_stop, adjacency):
        # Single place that spells out the long positional argument list.
        update_one_step_freqs(
            overlapped_edge_features, distance_to_others, node_sequence,
            focus, target, bond, edge_type_masks, valences,
            incremental_adj_mat, color, real_n_vertices, graph,
            edge_type_labels, local_stop, edge_masks, edge_labels, is_stop,
            params, params["check_overlap_edge"], new_mol, adjacency,
            keep_prob, new_compound_frequences, new_compound_frequences_edge,
            freq_dict)

    while queue:
        node_in_focus = queue.popleft()
        current_adj_list = graph[node_in_focus]
        # Random order shuffles in place, canonical order sorts a copy.
        if params["path_random_order"]:
            random.shuffle(current_adj_list)
        else:
            current_adj_list = sorted(current_adj_list)

        for neighbor, edge_type in current_adj_list:
            if color[neighbor] < 2:
                record(node_in_focus, neighbor, edge_type, False, up_to_date_adj_mat)
                # Commit the edge: adjacency, valences, molecule.
                up_to_date_adj_mat = genereate_incremental_adj(
                    up_to_date_adj_mat, node_in_focus, neighbor, edge_type)
                used = edge_type + 1
                valences[node_in_focus] -= used
                valences[neighbor] -= used
                new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[edge_type])
            # Newly discovered nodes are queued for later exploration.
            if color[neighbor] == 0:
                queue.append(neighbor)
                color[neighbor] = 1

        # Local stop: done with this node, move on or finish.
        record(node_in_focus, None, None, True, up_to_date_adj_mat)
        color[node_in_focus] = 2

    return (incremental_adj_mat, distance_to_others, node_sequence,
            edge_type_masks, edge_type_labels, local_stop, edge_masks,
            edge_labels, overlapped_edge_features, new_compound_frequences,
            new_compound_frequences_edge)