-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathprediction.py
58 lines (42 loc) · 1.74 KB
/
prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import tensorflow as tf
import base64
# Location of the exported SavedModel. This is a relative path, so it is
# resolved against the process's current working directory, not this file.
# NOTE(review): confirm the service is always started from the project root.
model_dir = 'models/openimages_v4_ssd_mobilenet_v2_1'
# Loaded once at import time; every later inference reuses this object.
saved_model = tf.saved_model.load(model_dir)
# The model's 'default' signature: a callable that takes a float32 image
# batch (see detect() / preload_model()) and returns a dict of tensors.
detector = saved_model.signatures['default']
def predict(body):
    """Run object detection on a base64-encoded image taken from a request body.

    Args:
        body: Mapping whose ``'image'`` entry holds the base64-encoded image
            bytes, either as a ``str`` or as ``bytes``.

    Returns:
        ``{'detections': [...]}`` where the list is built by
        ``clean_detections`` from the raw ``detect`` output.

    Raises:
        ValueError: if ``body`` has no ``'image'`` entry, or it is neither a
            ``str`` nor a bytes-like value.
    """
    base64img = body.get('image')
    if isinstance(base64img, str):
        # decodebytes() needs bytes; encode with the default UTF-8 as before.
        base64img = base64img.encode()
    elif not isinstance(base64img, (bytes, bytearray)):
        # Previously a missing key surfaced as an opaque
        # "'NoneType' object has no attribute 'encode'" AttributeError.
        raise ValueError("request body must contain a base64-encoded 'image' string")
    img_bytes = base64.decodebytes(base64img)
    detections = detect(img_bytes)
    cleaned = clean_detections(detections)
    return {'detections': cleaned}
def detect(img):
    """Run the detector on raw JPEG bytes and return plain-Python outputs.

    The image is decoded to 3 channels and converted to float32 with
    ``tf.image.convert_image_dtype``. A leading batch axis is then added
    before calling the module-level ``detector``.

    Returns:
        dict mapping each output name of the signature to ``.numpy().tolist()``
        of its tensor. ``'num_detections'`` is set to the length of the first
        axis of ``detection_scores``.
        NOTE(review): this assumes the signature returns unbatched outputs
        (one entry per detection). ``clean_detections`` indexes them that way;
        confirm against the model.
    """
    decoded = tf.image.decode_jpeg(img, channels=3)
    batch = tf.image.convert_image_dtype(decoded, tf.float32)[tf.newaxis, ...]
    raw = detector(batch)
    score_count = len(raw["detection_scores"])
    output = {}
    for name, tensor in raw.items():
        output[name] = tensor.numpy().tolist()
    output['num_detections'] = score_count
    return output
def _as_text(value):
    """Return *value* decoded as UTF-8 if it is bytes-like, otherwise unchanged."""
    if isinstance(value, (bytes, bytearray)):
        return value.decode('utf-8')
    return value


def clean_detections(detections, max_boxes=10):
    """Convert raw detector output into JSON-friendly detection dicts.

    Args:
        detections: dict as produced by ``detect``. It holds an integer
            ``'num_detections'`` plus per-detection sequences:
            - ``'detection_boxes'``: each item indexed as
              ``[yMin, xMin, yMax, xMax]``.
            - ``'detection_class_entities'``: UTF-8 ``bytes`` or ``str``.
            - ``'detection_scores'``.
        max_boxes: keep at most this many detections. The default of 10 is
            the previously hard-coded limit; ``None`` keeps all of them.

    Returns:
        list of ``{'box': {...}, 'class': name, 'label': name, 'score': s}``
        dicts, in the order the detector produced them. ``'class'`` and
        ``'label'`` carry the same entity name; both keys are kept so
        existing consumers of either keep working.

    NOTE(review): this keeps the *first* ``max_boxes`` entries. Those are the
    top-scoring ones only if the detector returns detections sorted by
    descending score; confirm against the model.
    """
    limit = detections['num_detections']
    if max_boxes is not None:
        limit = min(limit, max_boxes)
    cleaned = []
    # Keys are looked up lazily inside the loop, so an input with zero
    # detections does not need the per-detection keys at all.
    for i in range(limit):
        box = detections['detection_boxes'][i]
        name = _as_text(detections['detection_class_entities'][i])
        cleaned.append({
            'box': {
                'yMin': box[0],
                'xMin': box[1],
                'yMax': box[2],
                'xMax': box[3],
            },
            'class': name,
            'label': name,
            'score': detections['detection_scores'][i],
        })
    return cleaned
def preload_model():
    """Push one throwaway image through the detector and discard the result.

    The image is read from ``blank.jpeg``, relative to the current working
    directory. It is preprocessed the same way ``detect`` does it: 3-channel
    JPEG decode, float32 conversion, and a leading batch axis.
    NOTE(review): presumably a warm-up so the first real request avoids
    one-off model initialisation cost; confirm.
    """
    warmup_bytes = tf.io.read_file('blank.jpeg')
    warmup_image = tf.image.decode_jpeg(warmup_bytes, channels=3)
    warmup_float = tf.image.convert_image_dtype(warmup_image, tf.float32)
    detector(warmup_float[tf.newaxis, ...])
# Runs at import time: performs one inference on blank.jpeg before any
# request is served (presumably to absorb one-off initialisation cost).
preload_model()