-
Notifications
You must be signed in to change notification settings - Fork 5
/
capgen.py
62 lines (50 loc) · 2.33 KB
/
capgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# download checkpoint model from http://cs.stanford.edu/people/karpathy/neuraltalk/
import os
import numpy as np
from imagernn.imagernn_utils import decodeGenerator
import _pickle as pickle
from keras.applications import VGG16,imagenet_utils
from keras.preprocessing.image import load_img,img_to_array
from keras.models import Model
import tensorflow as tf
# Locate the checkpoint relative to this source file (with symlinks resolved),
# so loading it does not depend on the process's current working directory.
FILE_DIR = os.path.dirname(os.path.realpath(__file__))
CHECKPOINT_PATH = os.path.join(FILE_DIR, 'models', 'flickr8k_cnn_lstm_v1.p')

# An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA, forcing CPU execution.
# NOTE(review): this runs after tensorflow/keras are imported above, so it relies
# on TensorFlow not having initialised its devices yet - confirm, or move it
# above those imports.
os.environ.update({'CUDA_VISIBLE_DEVICES': ''})

# Keras' ImageNet input preprocessing, applied to image arrays before VGG16 sees them.
preprocess = imagenet_utils.preprocess_input
class CaptionGenerator:
    """Generate natural-language captions for images.

    A Keras VGG16 network with ImageNet weights, truncated at its 'fc2'
    layer, turns an image into a feature array.  The language model stored
    in a pickled checkpoint then decodes that array into a caption using
    beam search via ``decodeGenerator``.
    """

    def __init__(self, checkpoint_path=None, beam_size=2):
        """Load the language-model checkpoint and build the visual model.

        Args:
            checkpoint_path: path to the pickled checkpoint.  ``None`` (the
                default) means the module-level ``CHECKPOINT_PATH``.
            beam_size: beam width used when decoding captions.  The default
                of 2 matches the value this class previously hard-coded.
        """
        if checkpoint_path is None:
            checkpoint_path = CHECKPOINT_PATH
        self.checkpoint_path = checkpoint_path

        # SECURITY: pickle.load can execute arbitrary code, so only point this
        # at checkpoints from a trusted source.  Per the pickle docs,
        # encoding='latin1' is what lets Python 3 read NumPy arrays that were
        # pickled under Python 2.  The with-block closes the file, which the
        # previous bare open() call never did.
        with open(checkpoint_path, 'rb') as checkpoint_file:
            self.checkpoint = pickle.load(checkpoint_file, encoding='latin1')
        self.checkpoint_params = self.checkpoint['params']
        self.language_model = self.checkpoint['model']
        self.ixtoword = self.checkpoint['ixtoword']  # word index -> word

        model = VGG16(weights="imagenet")
        # 'fc2' is the layer at index 21 of the full VGG16 model (the one used
        # before); selecting it by name makes the choice self-describing.
        self.visual_model = Model(inputs=model.input,
                                  outputs=model.get_layer('fc2').output)
        # Private Keras API: builds the predict function now, presumably so
        # predict() can later be called from another thread - confirm.
        self.visual_model._make_predict_function()
        # Captured so get_caption can re-enter this graph from other threads.
        self.graph = tf.get_default_graph()
        self.BEAM_SIZE = beam_size

    def convert_img_to_vector(self, img_path):
        """Load the image at ``img_path`` as a preprocessed VGG16 input batch.

        The image is resized to 224x224, converted to an array, given a
        leading batch axis of size 1, and passed through ``preprocess``.
        """
        image = load_img(img_path, target_size=(224, 224))
        image = img_to_array(image)
        image = np.expand_dims(image, axis=0)
        return preprocess(image)

    def get_image_feature(self, img_path):
        """Return the transposed visual-model output for the image at ``img_path``.

        The model's batch output is transposed so that the features of the
        single input image end up in column 0, the layout ``predict`` reads.
        """
        batch = self.convert_img_to_vector(img_path)
        return np.transpose(self.visual_model.predict(batch))

    def predict(self, features):
        """Decode the best caption for ``features``.

        Args:
            features: array as returned by ``get_image_feature``; column 0
                is used as the image's feature vector.

        Returns:
            The highest-ranked caption as a space-separated string of words.
        """
        # NOTE(review): the meaning of decodeGenerator's argument is defined in
        # imagernn_utils (not visible here); the checkpoint path is passed as before.
        batch_generator = decodeGenerator(self.checkpoint_path)
        img = {'feat': features[:, 0]}
        ys = batch_generator.predict([{'image': img}], self.language_model,
                                     self.checkpoint_params,
                                     beam_size=self.BEAM_SIZE)
        # ys[0]: predictions for the first (and only) image we passed in;
        # they are sorted with the highest-scoring candidate first.
        top_prediction = ys[0][0]
        # top_prediction[1] holds word indices; ix 0 is the END token, so skip it.
        return ' '.join(self.ixtoword[ix] for ix in top_prediction[1] if ix > 0)

    def get_caption(self, file):
        """Caption the image at path ``file``; print the caption and return it."""
        # Run inside the graph captured in __init__ so this also works when
        # called from a thread other than the one that built the model.
        with self.graph.as_default():
            feat = self.get_image_feature(file)
            caption = self.predict(feat)
        print(caption)
        return caption