Skip to content

Commit

Permalink
add cpu mode
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweeo committed Dec 7, 2017
2 parents d966f77 + 5dc9d5f commit 967d698
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 50 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Learned-SIFT-Descriptor
HW6 for 16-720
# Prerequisites
* MATLAB
* Anaconda
* PyTorch: conda install pytorch torchvision -c soumith

# To run the code
run testMatch.m

5 changes: 1 addition & 4 deletions computeLearnedSIFT.m
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
function [locs,desc] = computeLearnedSIFT(im, GaussianPyramid, locsDoG, k, levels)
[R,C,L] = size(GaussianPyramid);
patchWidth = 9;
locs = zeros(0,3);
norm_patch_size = 64;
patches=zeros(1,64,64);
for i=1:size(locsDoG,1)
level = find(levels==locsDoG(i,3));
patch_half_size = floor(8*k^(1+levels(level)-levels(1)));

% bound = floor(sqrt(2)*patch_half_size
bound = 32;
if locsDoG(i,1)>=bound && locsDoG(i,2)>=bound && locsDoG(i,1)<=C-bound && locsDoG(i,2)<=R-bound
Expand Down
33 changes: 0 additions & 33 deletions computeLearnedSIFT.m~

This file was deleted.

6 changes: 3 additions & 3 deletions compute_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
import numpy as np
def compute_desc():
net=SiameseNetwork().cuda()
net=SiameseNetwork()

net.load_state_dict(torch.load('model.pt'))
net.train(False)
Expand All @@ -20,10 +20,10 @@ def compute_desc():
img=img.reshape(1,1,img.shape[0],img.shape[1])
X=torch.from_numpy(img)

_,output=net(Variable(X.float()).cuda(),Variable(X.float()).cuda())
_,output=net(Variable(X.float()),Variable(X.float()))

#print 'time:'+str(end-start)
mat[i]=output.data.cpu().numpy()
mat[i]=output.data.numpy()
scipy.io.savemat('learned_desc.mat',mdict={'desc':mat})
# print mat.tolist()
return 0
Expand Down
32 changes: 32 additions & 0 deletions compute_desc_cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from net1 import SiameseNetwork
from torchvision import transforms
from PIL import Image
import torch
from torch.autograd import Variable
import scipy.io
import time
import numpy as np
def compute_desc():
net=SiameseNetwork().cuda()

net.load_state_dict(torch.load('model.pt'))
net.train(False)
# transform=transforms.ToTensor()
imgs=scipy.io.loadmat('patch.mat')['patches']
N=len(imgs)
mat=np.zeros([N-1,128])
for i in range(N-1):
img=imgs[i]
img=img.reshape(1,1,img.shape[0],img.shape[1])
X=torch.from_numpy(img)

_,output=net(Variable(X.float()).cuda(),Variable(X.float()).cuda())

#print 'time:'+str(end-start)
mat[i]=output.data.cpu().numpy()
scipy.io.savemat('learned_desc.mat',mdict={'desc':mat})
# print mat.tolist()
return 0
if __name__ == "__main__":
compute_desc()

Binary file modified learned_desc.mat
Binary file not shown.
Binary file modified patch.mat
Binary file not shown.
7 changes: 4 additions & 3 deletions siftMatch.m
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
% into desc1 and the second column are indices into desc2

if nargin<3
ratio = .8;
ratio = .6;
end

% compute the pairwise Hamming distance
Expand All @@ -16,8 +16,9 @@

% suprress match between descriptors that are not distriminative.
r = D(1,:)./D(2,:);
ix = ix((r < ratio & D(1,:)<0.5) | isnan(r));
I2 = I(1,(r < ratio & D(1,:)<0.5) | isnan(r));

ix = ix((r < ratio)& D(1,:)<2 | isnan(r));
I2 = I(1,(r < ratio)& D(1,:)<2 | isnan(r));

%output
matches = [ix' I2'];
Expand Down
9 changes: 9 additions & 0 deletions siftMatch.m~
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ function [matches] = siftMatch(desc1, desc2, ratio)
% into desc1 and the second column are indices into desc2

if nargin<3
<<<<<<< HEAD
ratio = .;
=======
ratio = .7;
>>>>>>> 5dc9d5f515babe0140f91bc4ea2a6b668de3ceee
end

% compute the pairwise Hamming distance
Expand All @@ -16,8 +20,13 @@ ix = 1:size(desc1,1);

% suprress match between descriptors that are not distriminative.
r = D(1,:)./D(2,:);
<<<<<<< HEAD
ix = ix(r < ratio | isnan(r));
I2 = I(1,r < ratio | isnan(r));
=======
ix = ix((r < ratio)& D(1,:)<2 | isnan(r));
I2 = I(1,(r < ratio)& D(1,:)<2| isnan(r));
>>>>>>> 5dc9d5f515babe0140f91bc4ea2a6b668de3ceee

%output
matches = [ix' I2'];
Expand Down
12 changes: 5 additions & 7 deletions testMatch.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@
% im1 = imread('../data/pf_scan_scaled.jpg');
im1 = im2double(im1);
if size(im1,3)==3
im1= rgb2gray(im1);
im1_gray= rgb2gray(im1);
end
[locs1, desc1] = learned_siftLite(im1);
[locs1, desc1] = learned_siftLite(im1_gray);

%im2 = imread('../data/chickenbroth_01.jpg');
im2 = imrotate(im1,20);
im2_gray=imrotate(im1_gray,20);
% im2 = imread('../data/incline_R.png');
% im2 = imread('../data/pf_stand.jpg');
im2 = im2double(im2);
if size(im2,3)==3
im2= rgb2gray(im2);
end
[locs2, desc2] = learned_siftLite(im2);
[locs2, desc2] = learned_siftLite(im2_gray);

[matches] = siftMatch(desc1, desc2);
plotMatches(im1, im2, matches, locs1, locs2)
img=getNewImg(im1, im2, matches, locs1, locs2);

0 comments on commit 967d698

Please sign in to comment.