# test.py
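#
# Inference smoke test for the trained image-captioning model: it loads
# EncoderCNN/DecoderRNN weights, runs one COCO test image through the
# pipeline, and prints the generated caption.
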
import sys
import os
import torch
sys.path.append('/opt/cocoapi/PythonAPI')  # make the COCO API importable
from pycocotools.coco import COCO
from data_loader import get_loader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from model import EncoderCNN, DecoderRNN

def clean_sentence(output):
    """Convert a list of vocabulary indices into a readable caption string."""
    sentence = ''
    for idx in output:
        word = str(data_loader.dataset.vocab.idx2word[idx])
        if word != "<start>" and word != "<end>":  # drop the sentinel tokens
            sentence = sentence + " " + word
    sentence = sentence.strip()  # get rid of the leading space
    # Capitalize the first letter and replace the trailing " ." (the period
    # arrives as its own token) with a period attached to the last word.
    sentence = sentence[0].upper() + sentence[1:-2] + '.'
    return sentence
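
# Hypothetical example (token indices invented for illustration): if the
# decoder returns indices whose idx2word entries are
#   ["<start>", "a", "dog", "runs", ".", "<end>"]
# then clean_sentence() yields "A dog runs."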

def get_prediction():
    """Show one test image and print the caption the model generates for it."""
    orig_image, image = next(iter(data_loader))
    plt.imshow(np.squeeze(orig_image))
    plt.title('Sample Image')
    plt.show()
    image = image.to(device)
    with torch.no_grad():  # inference only; no gradients needed
        features = encoder(image).unsqueeze(1)
        output = decoder.sample(features)
    sentence = clean_sentence(output)
    print(sentence)

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),          # CenterCrop(224) would be deterministic at test time
    # transforms.RandomHorizontalFlip(), # no reason to flip at test time
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),   # ImageNet channel means
                         (0.229, 0.224, 0.225))]) # ImageNet channel stds

# Create the data loader for the test split.
data_loader = get_loader(transform=transform_test, mode='test')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Trained weights and the hyperparameters they were trained with.
encoder_file = './weights/encoder-3.pkl'
decoder_file = './weights/decoder-3.pkl'

embed_size = 512
hidden_size = 512
vocab_size = len(data_loader.dataset.vocab)
print(vocab_size)

encoder = EncoderCNN(embed_size)
encoder.eval()  # inference mode (disables dropout / batch-norm updates)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
decoder.eval()

# The *_file paths above already point at the weight files, so load them
# directly (assuming the weights live under ./weights as those paths state);
# map_location lets CPU-only machines load GPU-saved weights.
encoder.load_state_dict(torch.load(encoder_file, map_location=device))
decoder.load_state_dict(torch.load(decoder_file, map_location=device))

encoder.to(device)
decoder.to(device)

# Verify everything is OK with the loaded model: run one batch end to end.
orig_image, image = next(iter(data_loader))  # grab one test batch
image = image.to(device)
with torch.no_grad():
    features = encoder(image).unsqueeze(1)
    output = decoder.sample(features)
print('example output:', output)

assert type(output) == list, "Output needs to be a Python list."
assert all(type(x) == int for x in output), "Output should be a list of integers."
assert all(x in data_loader.dataset.vocab.idx2word for x in output), \
    "Each entry in the output needs to correspond to a token in the vocabulary."

# Caption one test image; re-run for a different sample.
get_prediction()
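
# --- Optional extension: caption an arbitrary image file ---
# A minimal sketch, not part of the original script: it reuses
# transform_test, encoder, decoder, device and clean_sentence from above.
# The 'sample.jpg' path in the usage comment is hypothetical.
from PIL import Image  # Pillow is already installed as a torchvision dependency

def caption_file(path):
    """Return the model's caption for the image stored at `path`."""
    img = Image.open(path).convert('RGB')
    tensor = transform_test(img).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        features = encoder(tensor).unsqueeze(1)
        output = decoder.sample(features)
    return clean_sentence(output)

# Example usage (hypothetical path):
# print(caption_file('sample.jpg'))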