-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathu2net_portrait-keras.py
50 lines (33 loc) · 987 Bytes
/
u2net_portrait-keras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from keras.models import load_model
import cv2
import numpy as np
import sys
# Portrait inference with a Keras port of U^2-Net.
#
# Usage: python u2net_portrait-keras.py INPUT_IMAGE OUTPUT_IMAGE
#
# Reads INPUT_IMAGE with OpenCV, resizes it to MODEL_SIZE x MODEL_SIZE,
# normalizes it into a (1, 3, H, W) tensor and runs the model. It then
# min-max rescales the inverted first model output to 0..255, resizes it
# back to the input's resolution and writes it to OUTPUT_IMAGE.

# Model file; a relative path, so it is resolved against the current
# working directory, not against this script's location.
MODEL_PATH = './u2net_portrait_keras.h5'
# Square side length the input image is resized to before inference.
MODEL_SIZE = 512


def normalize_image(bgr):
    """Convert a BGR image into the model's ``(1, 3, H, W)`` float64 input.

    The image is scaled by its own maximum value (not by a fixed 255).
    Each channel is then shifted and divided by fixed constants.

    Args:
        bgr: ``(H, W, 3)`` array in OpenCV's BGR channel order.

    Returns:
        ``(1, 3, H, W)`` float64 array. An all-zero image is left
        unscaled instead of producing NaNs from ``0 / 0``.
    """
    scaled = np.asarray(bgr, dtype=np.float64)
    peak = scaled.max()
    if peak > 0:
        scaled = scaled / peak
    # The BGR -> RGB reorder happens here: output channel 0 takes the R
    # plane (BGR index 2) and output channel 2 takes the B plane.
    # np.stack along axis 0 also performs the HWC -> CHW layout change.
    # NOTE(review): the mean/std pairs are mirrored relative to the usual
    # ImageNet R/G/B constants (R normally uses 0.485/0.229). This appears
    # to follow the upstream U^2-Net portrait demo, so it is kept as-is;
    # confirm against the converted model before changing it.
    chw = np.stack(
        [
            (scaled[:, :, 2] - 0.406) / 0.225,
            (scaled[:, :, 1] - 0.456) / 0.224,
            (scaled[:, :, 0] - 0.485) / 0.229,
        ]
    )
    return chw[np.newaxis]


def prediction_to_mask(d1):
    """Turn the model's first output into an 8-bit grayscale image.

    Args:
        d1: array indexed as ``d1[batch, channel, y, x]``. Only batch 0,
            channel 0 is used.

    Returns:
        ``(y, x)`` uint8 array holding ``1 - d1`` min-max rescaled to
        0..255. Values are truncated, not rounded. A constant prediction
        yields an all-zero image instead of NaNs from ``0 / 0``.
    """
    pred = 1.0 - np.asarray(d1)[0, 0]
    lo = pred.min()
    hi = pred.max()
    if hi == lo:
        return np.zeros(pred.shape, dtype=np.uint8)
    pred = (pred - lo) / (hi - lo)
    return (pred * 255).astype(np.uint8)


def main(argv=None):
    """Run the model on one image and save the result.

    Args:
        argv: ``[input_path, output_path, ...]``. Defaults to
            ``sys.argv[1:]``. Extra arguments are ignored.

    Exits with a message on stderr and status 1 in these cases:
    arguments are missing, the input cannot be read, or the output
    cannot be written.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        sys.exit(f"usage: {sys.argv[0]} INPUT_IMAGE OUTPUT_IMAGE")
    input_path, output_path = args[0], args[1]

    # Read the image before loading the model, so a bad path fails fast.
    image = cv2.imread(input_path)
    if image is None:
        sys.exit(f"error: could not read image {input_path!r}")

    model = load_model(MODEL_PATH, compile=False)
    resized = cv2.resize(
        image, (MODEL_SIZE, MODEL_SIZE), interpolation=cv2.INTER_CUBIC
    )
    # The model returns several outputs; only the first one is used.
    d1 = model.predict(normalize_image(resized))[0]
    mask = prediction_to_mask(d1)

    # cv2.resize takes dsize as (width, height).
    height, width = image.shape[:2]
    out = cv2.resize(mask, (width, height), interpolation=cv2.INTER_CUBIC)
    if not cv2.imwrite(output_path, out):
        sys.exit(f"error: could not write image {output_path!r}")


if __name__ == "__main__":
    main()