-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcorres_sampler.py
176 lines (141 loc) · 8.38 KB
/
corres_sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import numpy as np
import torch
import scipy.io
import random
def pdist(vectors):
    """Return the matrix of pairwise squared Euclidean distances between rows.

    Relies on the identity ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>,
    so the whole matrix comes from a single matrix product.

    Args:
        vectors: 2-D torch tensor (``mm`` requires a matrix); one vector per row.

    Returns:
        Square tensor whose entry (i, j) is the squared distance between
        row i and row j of ``vectors``.
    """
    squared_norms = vectors.pow(2).sum(dim=1)
    gram = vectors.mm(torch.t(vectors))
    # Broadcast the norms once as a row and once as a column.
    return -2 * gram + squared_norms.view(1, -1) + squared_norms.view(-1, 1)
def sample_correspondence(known_correspondence, img1, img2, sample_size=1024):
    """Sample positive, random-negative and hard-negative correspondences.

    Args:
        known_correspondence: dict ``{'a': Nx2, 'b': Nx2}`` of known positive
            matches; row i of 'a' (pixel in image a) corresponds to row i of
            'b' (pixel in image b).
        img1: first image. Not used by the sampling itself; kept so the
            signature stays stable for callers.
        img2: second image. Not used by the sampling itself.
        sample_size: number of correspondences requested from every sampler.

    Returns:
        Tuple ``(pos_correspondences, neg_correspondences_random,
        neg_correspondences_hard)``. Each element is a dict with keys 'a' and
        'b' holding row-aligned point arrays. ``neg_correspondences_hard`` is
        ``None`` when the hard-negative helper reports that it cannot produce
        a result (it returns ``None`` in that case).
    """
    matches_in_1 = np.array(known_correspondence['a'])
    matches_in_2 = np.array(known_correspondence['b'])

    # Positives can be picked at random. The helper indexes ``[0]`` on its
    # inputs and returns a dict whose values carry a leading axis, so add that
    # axis going in and strip it coming out instead of tuple-unpacking the
    # dict (which would only yield its keys).
    pos_batched = random_select_positive_matches(
        matches_in_1[None, ...], matches_in_2[None, ...], num_of_pairs=sample_size)
    pos_correspondences = {'a': pos_batched['a'][0], 'b': pos_batched['b'][0]}

    # Random negatives: a point in image a paired with some other point in b.
    neg_in_1, neg_in_2 = random_select_negative_matches(
        matches_in_1, matches_in_2, num_of_pairs=sample_size)
    neg_correspondences_random = {'a': neg_in_1, 'b': neg_in_2}

    # Hard negatives: the helper returns a single array (row i is the negative
    # chosen for matches_in_1[i]) or None; it does not return a pair.
    hard_neg_in_2 = hard_select_negative_matches(
        matches_in_1, matches_in_2, num_of_pairs=sample_size)
    if hard_neg_in_2 is None:
        neg_correspondences_hard = None
    else:
        # Keep 'a' and 'b' row-aligned even if the helper returned fewer rows
        # than there are known matches.
        neg_correspondences_hard = {
            'a': matches_in_1[:len(hard_neg_in_2)],
            'b': hard_neg_in_2,
        }

    return pos_correspondences, neg_correspondences_random, neg_correspondences_hard
def random_select_positive_matches(matches_in_1, matches_in_2, num_of_pairs=1024):
    """Randomly pick ``num_of_pairs`` positive correspondences (with replacement).

    Only the first entry along the leading axis of each input is used.
    # assumes inputs are shaped (1, N, 2) -- TODO confirm against callers

    Args:
        matches_in_1: points in image 1, indexed as ``matches_in_1[0][i]``.
        matches_in_2: points in image 2, row-aligned with ``matches_in_1``.
        num_of_pairs: number of indices to draw with ``torch.randint``.

    Returns:
        dict ``{'a': ..., 'b': ...}`` holding the selected points from image 1
        and image 2; the same indices are used for both, and a leading axis
        of size 1 is put back in front of each selection.
    """
    points_1, points_2 = matches_in_1[0], matches_in_2[0]
    # One shared index draw keeps the two sides paired up.
    picks = torch.randint(0, points_1.shape[0], (num_of_pairs,))
    chosen_1 = points_1[picks]
    chosen_2 = points_2[picks]
    return {'a': chosen_1[None, ...], 'b': chosen_2[None, ...]}
def random_select_negative_matches_whole_image(matches_in_1, matches_in_2, h=768, w=1024, num_of_pairs=1024):
    """Build negative pairs from known match points plus random pixels.

    For each image, ``num_of_pairs // 2`` points are taken from the known
    matches (indices drawn independently per image, so the two sides are not
    aligned) and another ``num_of_pairs // 2`` are random ``(x, y)`` pixels
    with ``0 <= x < w`` and ``0 <= y < h``.

    Args:
        matches_in_1: known match points in image 1, one point per row.
        matches_in_2: known match points in image 2, one point per row.
        h: exclusive upper bound for the random y coordinate.
        w: exclusive upper bound for the random x coordinate.
        num_of_pairs: total number of pairs requested.

    Returns:
        Tuple ``(neg_match_in_1, neg_match_in_2)``: match-derived points
        first, random pixels after, concatenated along axis 0.
    """
    half = num_of_pairs // 2
    count_1 = matches_in_1.shape[0]
    count_2 = matches_in_2.shape[0]

    if count_1 >= num_of_pairs / 2:
        # Enough points in image 1: sample without replacement.
        picks_1 = random.sample(range(0, count_1), half)
        picks_2 = random.sample(range(0, count_2), half)
    else:
        # Too few points: fall back to drawing with replacement.
        picks_1 = random.choices(range(0, count_1), k=half)
        picks_2 = random.choices(range(0, count_2), k=half)

    from_matches_1 = np.array([matches_in_1[i] for i in picks_1])
    from_matches_2 = np.array([matches_in_2[i] for i in picks_2])

    def _random_pixels():
        # x coordinates are drawn before y coordinates.
        xs = np.array(random.choices(range(0, w), k=half))
        ys = np.array(random.choices(range(0, h), k=half))
        return np.stack((xs, ys), axis=1)

    from_pixels_1 = _random_pixels()
    from_pixels_2 = _random_pixels()

    neg_match_in_1 = np.concatenate((from_matches_1, from_pixels_1), axis=0)
    neg_match_in_2 = np.concatenate((from_matches_2, from_pixels_2), axis=0)
    return neg_match_in_1, neg_match_in_2
def random_select_negative_matches(matches_in_1, matches_in_2, num_of_pairs=1024):
    """Pair randomly chosen points of image 1 with other points of image 2.

    Args:
        matches_in_1: known match points in image 1, one point per row.
        matches_in_2: known match points in image 2, row-aligned with
            ``matches_in_1``.
        num_of_pairs: number of negative pairs to produce.

    Returns:
        Tuple ``(points_1, points_2)`` of numpy arrays with ``num_of_pairs``
        rows each. Row k of ``points_2`` is taken at an index obtained from
        ``get_random(0, N - 1, anchor)``, which is meant to differ from the
        index used for row k of ``points_1``.
    """
    total = matches_in_1.shape[0]
    candidates = range(0, total)

    if total >= num_of_pairs:
        # Enough correspondences: sample distinct anchors.
        anchor_ids = random.sample(candidates, num_of_pairs)
    else:
        # Not enough: allow repeated anchors.
        anchor_ids = random.choices(candidates, k=num_of_pairs)

    # For every anchor, ask for a partner index other than the anchor's own.
    partner_ids = [get_random(0, total - 1, anchor) for anchor in anchor_ids]

    points_1 = np.array([matches_in_1[i] for i in anchor_ids])
    points_2 = np.array([matches_in_2[j] for j in partner_ids])
    return points_1, points_2
def get_random(a, b, not_equal_num):
    """Return a random integer in ``[a, b]`` that differs from ``not_equal_num``.

    Args:
        a: inclusive lower bound.
        b: inclusive upper bound.
        not_equal_num: value that must not be returned.

    Returns:
        An integer ``n`` with ``a <= n <= b`` and ``n != not_equal_num``.

    Raises:
        ValueError: if the only value in the range is ``not_equal_num``, so
            no valid result exists (also raised by ``random.randint`` when
            ``a > b``).
    """
    if a == b == not_equal_num:
        raise ValueError(
            "no integer in [%r, %r] differs from %r" % (a, b, not_equal_num))
    num = random.randint(a, b)
    # Redraw until the forbidden value is avoided; the retry result must be
    # kept, otherwise the forbidden value can leak out.
    while num == not_equal_num:
        num = random.randint(a, b)
    return num
def hard_select_negative_matches(matches_in_1, matches_in_2, num_of_pairs=1024):
    """Select, for every point of image 1, its hardest negative in image 2.

    The hardest negative for ``matches_in_1[i]`` is the row of
    ``matches_in_2`` with the smallest Euclidean distance to it, excluding
    row ``i`` itself (the true correspondence). When several rows are equally
    close, the one with the highest index is chosen.

    Args:
        matches_in_1: Nx2 array of points in image 1, one point per row.
        matches_in_2: Mx2 array of points in image 2; row i is the positive
            match of ``matches_in_1[i]``.
        num_of_pairs: minimum number of correspondences required.

    Returns:
        Array with one row per row of ``matches_in_1``; row i is the hardest
        negative for ``matches_in_1[i]``, taken unchanged from
        ``matches_in_2``. ``None`` if there are fewer than ``num_of_pairs``
        correspondences, or fewer than two candidate points in image 2 (then
        no negative exists once the positive is excluded).
    """
    source_2 = np.asarray(matches_in_2)
    # Work in floating point so unsigned-integer pixel coordinates cannot
    # wrap around when subtracted.
    points_1 = np.asarray(matches_in_1, dtype=float)
    points_2 = source_2.astype(float)

    # Points are rows, so the number of correspondences is shape[0].
    num_points_1 = points_1.shape[0]
    num_points_2 = points_2.shape[0]
    if num_points_1 < num_of_pairs or num_points_2 < 2:
        return None

    hardest_ids = np.empty(num_points_1, dtype=np.intp)
    for index1 in range(num_points_1):
        distances = np.linalg.norm(points_2 - points_1[index1], axis=1)
        if index1 < num_points_2:
            # Never pick the positive correspondence as its own negative.
            distances[index1] = np.inf
        # argmin returns the first minimum; search the reversed array so that
        # ties resolve to the highest index.
        hardest_ids[index1] = num_points_2 - 1 - np.argmin(distances[::-1])

    return source_2[hardest_ids]
def corres_sampler():
    """Load 'all_correspondences.mat', sample correspondences and print them.

    The .mat file is read for the keys 'matches', 'img1' and 'img2'. The
    first two columns of 'matches' become the points of image a and the
    remaining columns the points of image b (rows are ``[x, y, x', y']``).
    The three results of ``sample_correspondence`` (positives, random
    negatives, hard negatives) are printed in that order.
    """
    mat_contents = scipy.io.loadmat('all_correspondences.mat')

    # Split each [x, y, x', y'] row into the {'a': Nx2, 'b': Nx2} layout
    # expected by sample_correspondence.
    all_matches = np.array(mat_contents['matches'])
    known_correspondence = {
        'a': all_matches[:, 0:2],
        'b': all_matches[:, 2:],
    }

    img1 = np.array(mat_contents['img1'])
    img2 = np.array(mat_contents['img2'])

    sampled = sample_correspondence(known_correspondence, img1, img2, sample_size=1024)
    pos_correspondences, neg_correspondences_random, neg_correspondences_hard = sampled

    for correspondences in (pos_correspondences, neg_correspondences_random, neg_correspondences_hard):
        print(correspondences)