Skip to content

Commit

Permalink
add siameseNet
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweeo committed Dec 3, 2017
1 parent a407cfa commit 9947424
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 0 deletions.
12 changes: 12 additions & 0 deletions ConstracstiveLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from torch.nn import functional as F
import torch
class ContrastiveLoss(torch.nn.Module):

def __init__(self, margin=2.0):
super(ContrastiveLoss, self).__init__()
self.margin=margin
def forward(self, output1,output2,label):
euclidean_distance = F.pairwise_distance(output1, output2)
loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
(label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
return loss_contrastive
26 changes: 26 additions & 0 deletions SiameseNetworkDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from torch.utils.data import Dataset
import torch
import numpy as np
from PIL import Image
class SiameseNetworkDataset(Dataset):
def __init__(self, root_dir, label, transform=None):
self.root_dir = root_dir
self.label = label
self.transform=transform
def __getitem__(self, index):

img0 = Image.open(self.root_dir + '/pair1/' + str(index) + '.bmp')
img1 = Image.open(self.root_dir + '/pair2/' + str(index) + '.bmp')

img0=img0.convert("L")
img1=img1.convert("L")

if self.transform is not None:
img0=self.transform(img0)
img1=self.transform(img1)

return img0,img1, torch.from_numpy(np.array([int(self.label[index])],dtype=np.float32))
def __len__(self):
return len(self.label)


19 changes: 19 additions & 0 deletions compute_desc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from net1 import SiameseNetwork
from torchvision import transforms
from PIL import Image
import torch
from torch.autograd import Variable

net=SiameseNetwork()

net.load_state_dict(torch.load('model.pt'))
transform=transforms.ToTensor()
img=Image.open('tmp.bmp')
img=img.convert('L')
img=transform(img)
img=img.unsqueeze(0)
img.view(img.size()[0],-1)
_,output=net(Variable(img),Variable(img))
print output


48 changes: 48 additions & 0 deletions extract_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

from scipy import misc

imglist='liberty/m50_200000_200000_0.txt'

with open(imglist,'r') as f:
content=f.read();

lines=content.split('\n')

num=len(lines)-1

f.close()

f=open('label.txt','w');

for i in range(num):
if (i%1000==0):
print i

list=lines[i].split(' ')
id1=int(list[0])
id2=int(list[3])
label=(list[1]==list[4])


#first patch
m=id1/256
count=(id1)%256
x=(count)%16
y=count/16
img=misc.imread('liberty/patches'+str(m).zfill(4)+'.bmp')
x_=x*64
y_=y*64
patch = img[y_:y_ + 64, x_:x_ + 64]
misc.imsave('pair1/'+str(i)+'.bmp',patch)

#second patch
m=id2/256
count=(id2)%256
x=(count)%16
y=count/16
img=misc.imread('liberty/patches'+str(m).zfill(4)+'.bmp')
x_=x*64
y_=y*64
patch=img[y_:y_+64,x_:x_+64]
misc.imsave('pair2/'+str(i)+'.bmp',patch)
f.write(str(int(label)))
Binary file added model.pt
Binary file not shown.
30 changes: 30 additions & 0 deletions net1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch.nn as nn

class SiameseNetwork(nn.Module):
def __init__(self):
super(SiameseNetwork, self).__init__()
self.cnn1 = nn.Sequential(
nn.Conv2d(1,32, kernel_size=7),
nn.MaxPool2d(2,stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size= 6),
nn.MaxPool2d(3,stride=3),
nn.ReLU(inplace=True),
nn.Conv2d(64,128,kernel_size= 5,),
nn.MaxPool2d(4, stride=4)
)
self.fc1=nn.Sequential(
nn.Linear(128,128)
)
def forward_once(self,x):
output=self.cnn1(x)
output=output.view(output.size()[0],-1)
output =self.fc1(output)
return output

def forward(self, input1,input2):
output1=self.forward_once(input1)
output2=self.forward_once(input2)
return output1,output2


30 changes: 30 additions & 0 deletions visualize_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from torch.utils.data import DataLoader
import torchvision
import torch
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import numpy as np
import torchvision.datasets as dset
from SiameseNetworkDataset import SiameseNetworkDataset



def imshow(img,text,should_save=False):
npimg = img.numpy()
plt.axis("off")
if text:
plt.text(75, 8, text, style='italic',fontweight='bold',
bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()

f=open('label.txt','r');
label=f.read()
dataset=SiameseNetworkDataset('.',label,transform=transforms.ToTensor())

vis_dataloader=DataLoader(dataset,shuffle=True,batch_size=8)
dataiter=iter(vis_dataloader)
example_batch = next(dataiter)
concatenated = torch.cat((example_batch[0],example_batch[1]),0)
imshow(torchvision.utils.make_grid(concatenated),'img')
print(example_batch[2].numpy())

0 comments on commit 9947424

Please sign in to comment.