-
Notifications
You must be signed in to change notification settings - Fork 5
/
infer.py
executable file
·48 lines (36 loc) · 1.16 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/bin/python3
import tensorflow as tf
import numpy as np
import data
import model
from glob import glob
# Path prefix of the TensorFlow checkpoint to restore. The existence check in
# main() globs for any file starting with this prefix.
checkpoint = "ckpts/model.ckpt"


def main():
    """Run an interactive prompt that feeds user input through the model.

    Builds the inference graph, restores the variables saved at
    ``checkpoint`` and then loops: read a line from stdin, convert it to ids
    with ``parser.parse_input``, run the model on a batch of one, and print
    the composed output followed by ``np.argmax(..., axis=1)`` of the
    attention alignment history.

    Returns early (after printing "No model found!") when no file matching
    ``checkpoint + "*"`` exists. The loop ends on EOF (Ctrl-D) or Ctrl-C at
    the prompt.
    """
    parser = data.Parser('./symbols')
    m = model.Model(input_size=parser.input_size,
                    output_size=parser.output_size)
    m.infer()
    # No initializer op is built: saver.restore() below assigns the
    # variables from the checkpoint.
    saver = tf.train.Saver()
    # Build the alignment fetch ONCE. Calling .concat() inside the loop would
    # add fresh ops to the graph on every query and grow it without bound.
    # NOTE(review): index 2 of final_state is assumed to be the attention
    # wrapper state that carries alignment_history -- confirm in model.Model.
    alignment_op = m.final_state[2].alignment_history.concat()

    with tf.Session() as sess:
        if not glob(checkpoint + "*"):
            print("No model found!")
            return
        saver.restore(sess, checkpoint)

        while True:
            try:
                input_ = input('in> ')
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D or Ctrl-C at the prompt: leave the loop cleanly.
                print("\nBye!")
                break
            if not input_.strip():
                # Nothing to run; re-prompt instead of feeding an empty line.
                continue
            input_ids = parser.parse_input(input_)
            # The graph takes a batch dimension; run a batch of one.
            feed = {m.input_data: np.expand_dims(input_ids, 0)}
            output_ids, align_h = sess.run([m.output_ids, alignment_op],
                                           feed_dict=feed)
            print(parser.compose_output(output_ids[0]))
            print("\nAttention alignment:")
            # Presumably the most-attended input position per decoder step;
            # depends on alignment_history's layout -- verify in model.Model.
            print(np.argmax(align_h, axis=1))
            print()
# Start the interactive prompt only when executed as a script, so importing
# this module does not block on stdin.
if __name__ == "__main__":
    main()