This repository has been archived by the owner on May 12, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path sampler.py
95 lines (73 loc) · 2.38 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
from torch.autograd import Variable
from PIL import Image
import torchvision.transforms as transforms
from model import NetG
import pickle
# Converts a C x H x W image tensor into a PIL image so it can be saved.
trans_toPIL = transforms.ToPILImage()

# Generator hyper-parameters. These presumably determine NetG's layer sizes,
# so they must match the configuration the checkpoint was trained with.
n_z = 100  # length of the noise vector fed to the generator
n_l = 150
n_c = 128
checkpoint_path = 'checkpoints/netG__epoch_100.pth'

# Build the generator and restore its weights. The map_location callable
# returns each deserialized storage unchanged, i.e. everything stays on the
# CPU, so a GPU-trained checkpoint loads on a CPU-only machine.
# NOTE(review): netG is left in its default (training) mode; if NetG has
# BatchNorm/Dropout layers, sampling may want netG.eval() -- confirm.
netG = NetG(n_z=n_z, n_l=n_l, n_c=n_c)
netG.load_state_dict(
    torch.load(
        checkpoint_path,
        map_location=lambda storage, _location: storage,
    )
)
def generate_from_caption(caption_file="enc_text.pkl", num_images=2):
    """Generate and save images for every encoded caption in a pickle file.

    The pickle must contain a dict whose ``'features'`` entry holds one
    encoded caption vector per position. For each caption ``i``, a batch of
    ``num_images`` noise vectors is drawn. One image is generated per noise
    row ``j`` and saved as ``<i><j>.png`` in the working directory.

    Args:
        caption_file: path to the pickled caption encodings.
        num_images: number of images to generate per caption.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file; only load caption files you produced yourself.
    with open(caption_file, 'rb') as f:
        train_ids = pickle.load(f)
    features = train_ids['features']
    for i, feature in enumerate(features):
        # One noise batch per caption; row j seeds image j.
        noise = Variable(torch.randn(num_images, n_z))
        noise.data.normal_(0, 2)
        # torch.from_numpy keeps the array's dtype, so a float64 array would
        # become a DoubleTensor while the noise here is float32. .float() is
        # a no-op for float32 input.
        caption = Variable(torch.from_numpy(feature).float())
        for j in range(num_images):
            out = netG(noise[j].view(1, noise.size(1)),
                       caption.view(1, caption.size(0)))
            # NOTE(review): ToPILImage expects float pixels in [0, 1]; if NetG
            # ends in tanh ([-1, 1]) the output needs rescaling -- confirm.
            img = trans_toPIL(out.data[0])
            # NOTE: names stay unique only while num_images <= 10
            # (e.g. i=1, j=10 and i=11, j=0 would both give '110.png').
            img.save(str(i) + str(j) + '.png')
# Further experiments: interpolating between captions and caption vector arithmetic.
def interpolate(inb=5):
    """Save ``inb`` images generated along a line between two random captions.

    Two random 4800-dimensional caption vectors (std 5) are drawn. The
    caption is blended linearly from ``cap2`` (first image, ``alpha == 0``)
    to ``cap1`` (last image, ``alpha == 1``). A single noise vector is shared
    by every step, so differences between frames come from the caption alone.
    Images are saved as ``interp<i>.png``.

    Args:
        inb: number of interpolation steps, i.e. images produced. With
            ``inb == 1`` only ``cap2`` is rendered; ``inb <= 0`` renders nothing.
    """
    cap1 = Variable(torch.randn(1, 4800))
    cap2 = Variable(torch.randn(1, 4800))
    cap1.data.normal_(0, 5)
    cap2.data.normal_(0, 5)
    # Draw the noise once, outside the loop. Resampling it per step would mix
    # noise changes into what should be a pure caption interpolation.
    noise = Variable(torch.randn(1, n_z))
    noise.data.normal_(0, 1)
    # Dividing by (inb - 1) makes the last step reach alpha == 1 (pure cap1).
    # The previous i / inb stopped one step short. max() avoids dividing by
    # zero when inb == 1.
    denom = float(max(inb - 1, 1))
    for i in range(inb):
        alpha = i / denom
        cap = alpha * cap1 + (1 - alpha) * cap2
        out = netG(noise, cap)
        img = trans_toPIL(out.data[0])
        img.save('interp' + str(i) + '.png')
def addDiff():
    """Render the caption-arithmetic experiment ``cap3 + (cap1 - cap2)``.

    Draws three random 4800-dimensional caption vectors (std 5), then forms
    ``diff = cap1 - cap2`` and ``final = cap3 + diff``. All five vectors are
    rendered with one shared noise vector. The results are saved as
    ``im1.png``, ``im2.png``, ``im3.png``, ``final.png`` and ``diff.png``,
    in that order.
    """
    cap1 = Variable(torch.randn(1, 4800))
    cap2 = Variable(torch.randn(1, 4800))
    cap3 = Variable(torch.randn(1, 4800))
    cap1.data.normal_(0, 5)
    cap2.data.normal_(0, 5)
    cap3.data.normal_(0, 5)
    diff = cap1 - cap2
    final = cap3 + diff
    # Size the noise with n_z rather than a literal 100, so it always matches
    # the generator configured at module level.
    noise = Variable(torch.randn(1, n_z))
    noise.data.normal_(0, 1)
    # The same noise is used for every render, so only the caption varies.
    renders = (
        (cap1, 'im1.png'),
        (cap2, 'im2.png'),
        (cap3, 'im3.png'),
        (final, 'final.png'),
        (diff, 'diff.png'),
    )
    for cap, filename in renders:
        out = netG(noise, cap)
        img = trans_toPIL(out.data[0])
        img.save(filename)