-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathResUnet.py
72 lines (56 loc) · 2.25 KB
/
ResUnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from tensorflow.keras.layers import Conv2D, UpSampling2D, Concatenate, Input
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
def bn_act(x, act=True):
    """Apply batch normalization, optionally followed by a ReLU.

    Args:
        x: Keras tensor to normalize.
        act: when truthy, a ReLU ``Activation`` layer is applied after the
            ``BatchNormalization`` layer; pass ``False`` to return the bare
            normalized tensor (the shortcut branches in this file do this).

    Returns:
        The output tensor of the last layer applied.
    """
    # Use the layer names provided by ``from tensorflow.keras.layers import *``,
    # like the rest of this file. The previous ``tf.keras.layers`` spelling
    # raised NameError because ``tensorflow`` is never imported as ``tf``.
    x = BatchNormalization()(x)
    if act:
        x = Activation('relu')(x)
    return x
def conv_block(x, filters, kernel_size=3, padding='same', strides=1):
    """Pre-activation convolution unit: BatchNorm -> ReLU -> Conv2D.

    Args:
        x: input Keras tensor.
        filters: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        padding: Keras padding mode for the convolution.
        strides: convolution stride.

    Returns:
        The convolution output tensor.
    """
    activated = bn_act(x)
    conv_layer = Conv2D(filters, kernel_size, padding=padding, strides=strides)
    return conv_layer(activated)
def stem(x, filters, kernel_size=3, padding='same', strides=1):
    """First encoder stage: Conv2D -> (BN, ReLU, Conv2D) plus a 1x1 projection shortcut.

    Args:
        x: input Keras tensor.
        filters: number of output channels for every convolution in the stage.
        kernel_size: kernel size of the two main-path convolutions.
        padding: Keras padding mode passed to all convolutions.
        strides: stride of the first main-path convolution and of the
            shortcut projection. The second main-path convolution always uses
            stride 1, so (with padding='same') both branches reach the same
            spatial size before they are added.

    Returns:
        Element-wise sum of the main path and the batch-normalized shortcut.
    """
    conv = Conv2D(filters, kernel_size, padding=padding, strides=strides)(x)
    # Stride only once on the main path. The shortcut below downsamples once,
    # so striding again here would make Add() fail on mismatched shapes.
    # residual_block uses stride 1 for its second conv for the same reason.
    conv = conv_block(conv, filters, kernel_size, padding, 1)

    shortcut = Conv2D(filters, kernel_size=1, padding=padding, strides=strides)(x)
    shortcut = bn_act(shortcut, act=False)

    return Add()([conv, shortcut])
def residual_block(x, filters, kernel_size=3, padding='same', strides=1):
    """Residual unit made of two pre-activation convs plus a projection shortcut.

    The first conv uses ``strides`` and the second uses stride 1. The shortcut
    is a Conv2D with the same ``kernel_size`` and ``strides``, followed by
    BatchNorm without activation.

    NOTE(review): the shortcut uses ``kernel_size`` (3x3 by default), whereas
    ``stem`` projects with a 1x1 conv. Confirm this is intentional; changing it
    alters weight shapes of saved models.

    Returns:
        Element-wise sum of the shortcut and the main path.
    """
    main_path = x
    for conv_strides in (strides, 1):
        main_path = conv_block(main_path, filters, kernel_size, padding, conv_strides)

    projected = Conv2D(filters, kernel_size, padding=padding, strides=strides)(x)
    skip_path = bn_act(projected, act=False)

    return Add()([skip_path, main_path])
def upsample_concat_block(x, xskip):
    """Upsample ``x`` by 2 in both spatial dims and concatenate ``xskip``.

    Concatenation uses Keras' default last axis, which is the channel axis
    under the default channels_last layout. ``xskip`` is expected to match
    the upsampled spatial size.

    Returns:
        The concatenated tensor, with the upsampled ``x`` first.
    """
    upsampled = UpSampling2D(size=(2, 2))(x)
    return Concatenate()([upsampled, xskip])
def ResUNet(img_h, img_w, channels=3):
    """Build a ResUNet model with a single-channel sigmoid output.

    The encoder is a ``stem`` followed by four stride-2 residual blocks
    (16 -> 256 filters). Two pre-activation convs form the bridge. Four
    upsample/concat/residual stages form the decoder, which reuses the
    encoder outputs as skips. A final 1x1 Conv2D with sigmoid produces
    the output.

    Args:
        img_h: input height; must be a positive multiple of 16, or None.
        img_w: input width; must be a positive multiple of 16, or None.
        channels: number of input channels (default 3, the previous
            hard-coded value).

    Returns:
        A ``Model`` mapping ``(img_h, img_w, channels)`` inputs to
        ``(img_h, img_w, 1)`` sigmoid outputs.

    Raises:
        ValueError: if ``img_h`` or ``img_w`` is given but is not a positive
            multiple of 16.
    """
    # The encoder halves the resolution four times. The decoder doubles it
    # back and concatenates with the matching encoder output. Any size not
    # divisible by 2**4 therefore gives mismatched skip shapes (e.g. 64 vs 63)
    # deep inside Concatenate, so reject such sizes up front.
    downsample_factor = 2 ** 4
    for dim_name, dim in (("img_h", img_h), ("img_w", img_w)):
        if dim is not None and (dim <= 0 or dim % downsample_factor != 0):
            raise ValueError(
                f"{dim_name} must be a positive multiple of "
                f"{downsample_factor}, got {dim!r}"
            )

    f = [16, 32, 64, 128, 256]
    inputs = Input((img_h, img_w, channels))

    ## Encoder
    e0 = inputs
    e1 = stem(e0, f[0])
    e2 = residual_block(e1, f[1], strides=2)
    e3 = residual_block(e2, f[2], strides=2)
    e4 = residual_block(e3, f[3], strides=2)
    e5 = residual_block(e4, f[4], strides=2)

    ## Bridge
    b0 = conv_block(e5, f[4], strides=1)
    b1 = conv_block(b0, f[4], strides=1)

    ## Decoder
    u1 = upsample_concat_block(b1, e4)
    d1 = residual_block(u1, f[4])

    u2 = upsample_concat_block(d1, e3)
    d2 = residual_block(u2, f[3])

    u3 = upsample_concat_block(d2, e2)
    d3 = residual_block(u3, f[2])

    u4 = upsample_concat_block(d3, e1)
    d4 = residual_block(u4, f[1])

    outputs = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(d4)
    model = Model(inputs, outputs)
    return model
# Module-level default instance for 256x256 inputs.
# NOTE(review): this builds the whole graph on every import of this module;
# consider an ``if __name__ == "__main__":`` guard if nothing imports ``model``.
model = ResUNet(img_h=256, img_w=256)