-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathRR_Unet.py
178 lines (143 loc) · 7.57 KB
/
RR_Unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from tensorflow.keras.layers import (
    Activation,
    Add,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Input,
    Lambda,
    UpSampling2D,
    concatenate,
)
from tensorflow.keras.models import Model
def residual_block_1(inputs, num_filters, strides=1):
    """Residual unit: conv-BN-ReLU-conv main path plus a 1x1 projection shortcut.

    Main path: 3x3 conv (with ``strides``) -> BatchNorm -> ReLU -> 3x3 conv.
    Shortcut:  1x1 conv with the same ``strides``, so both paths end up with
    ``num_filters`` channels and the same spatial size.

    Args:
        inputs: input tensor.
        num_filters: number of channels produced by both paths.
        strides: stride of the first 3x3 conv and of the 1x1 shortcut conv.

    Returns:
        Element-wise sum of the main path and the shortcut. No normalization
        or activation is applied after the sum.
    """
    body = inputs
    # Layers are listed in the exact order they are applied to the main path.
    for layer in (
        Conv2D(num_filters, 3, padding="same", strides=strides),
        BatchNormalization(),
        Activation("relu"),
        Conv2D(num_filters, 3, padding="same", strides=1),
    ):
        body = layer(body)

    projection = Conv2D(num_filters, 1, padding="same", strides=strides)(inputs)
    return body + projection
def residual_block_2(inputs, num_filters, strides=1):
    """Residual unit with BN -> ReLU -> conv ordering, applied twice.

    The first 3x3 conv carries ``strides``; the second uses stride 1. The
    shortcut is a 1x1 conv with the same stride followed by BatchNorm and
    ReLU, so both paths have ``num_filters`` channels and equal spatial size.

    Args:
        inputs: input tensor.
        num_filters: number of channels produced by both paths.
        strides: stride of the first 3x3 conv and of the 1x1 shortcut conv.

    Returns:
        Element-wise sum of the main path and the activated shortcut.
    """
    residual = inputs
    for conv_strides in (strides, 1):
        residual = BatchNormalization()(residual)
        residual = Activation("relu")(residual)
        residual = Conv2D(num_filters, 3, padding="same", strides=conv_strides)(residual)

    # Projection shortcut. Unlike a plain identity/projection, it is also
    # batch-normalized and ReLU-activated before the sum.
    projected = Conv2D(num_filters, 1, padding="same", strides=strides)(inputs)
    projected = BatchNormalization()(projected)
    projected = Activation("relu")(projected)

    return residual + projected
def decoder_block(inputs, skip_features, num_filters):
    """Upsample 2x, merge with an encoder skip tensor, refine with residual_block_2.

    Args:
        inputs: tensor from the previous (coarser) decoder stage.
        skip_features: encoder tensor whose spatial size must equal that of
            ``inputs`` after 2x upsampling (required by the concatenation).
        num_filters: number of channels produced by the residual refinement.

    Returns:
        Output of residual_block_2 (stride 1) applied to the concatenation of
        the upsampled tensor and ``skip_features`` along the last axis.
    """
    upsampled = UpSampling2D(size=(2, 2))(inputs)
    merged = Concatenate()([upsampled, skip_features])
    return residual_block_2(merged, num_filters, strides=1)
def conv2d_bn(x, filters, num_row, num_col, padding='same', stride=1, dilation_rate=1, relu=True):
    """Bias-free Conv2D followed by BatchNormalization and an optional ReLU.

    Args:
        x: input tensor.
        filters: number of output channels.
        num_row: kernel height.
        num_col: kernel width.
        padding: Keras padding mode forwarded to Conv2D.
        stride: stride used for both spatial axes.
        dilation_rate: dilation used for both spatial axes.
        relu: when True, a ReLU activation is appended.

    Returns:
        The convolved, normalized (and optionally activated) tensor.
    """
    conv = Conv2D(
        filters,
        kernel_size=(num_row, num_col),
        strides=(stride,) * 2,
        padding=padding,
        dilation_rate=(dilation_rate,) * 2,
        use_bias=False,  # the following BatchNorm provides the shift term
    )
    # scale=False: BatchNorm does not use a learned gamma multiplier.
    normalized = BatchNormalization(scale=False)(conv(x))
    if not relu:
        return normalized
    return Activation("relu")(normalized)
def BasicRFB(x, input_filters, output_filters, stride=1):
    """RFB-style block: three conv2d_bn branches with dilations 1/3/5 plus a shortcut.

    Each branch starts from ``x`` with a 1x1 conv2d_bn, applies ``stride``
    exactly once, and ends in a linear (no-ReLU) 3x3 conv2d_bn with dilation
    1, 3 or 5. The branches are concatenated on the last axis, fused by a
    linear 1x1 conv2d_bn to ``output_filters`` channels, added to a strided
    linear 1x1 conv2d_bn shortcut of ``x``, and passed through ReLU.

    Args:
        x: input tensor.
        input_filters: only used to size the branches (base width is
            ``input_filters // 8``); presumably the channel count of ``x`` —
            not checked here.
        output_filters: number of channels in the result.
        stride: spatial stride applied once per branch and on the shortcut.

    Returns:
        Output tensor with ``output_filters`` channels.
    """
    base = input_filters // 8
    double = base * 2

    def stack(tensor, *specs):
        # Apply conv2d_bn repeatedly; each spec is (filters, rows, cols, kwargs).
        for filters, rows, cols, kwargs in specs:
            tensor = conv2d_bn(tensor, filters, rows, cols, **kwargs)
        return tensor

    path_a = stack(
        x,
        (double, 1, 1, {"stride": stride}),
        (double, 3, 3, {"relu": False}),
    )
    path_b = stack(
        x,
        (base, 1, 1, {}),
        (double, 3, 3, {"stride": stride}),
        (double, 3, 3, {"dilation_rate": 3, "relu": False}),
    )
    path_c = stack(
        x,
        (base, 1, 1, {}),
        ((base // 2) * 3, 3, 3, {}),
        (double, 3, 3, {"stride": stride}),
        (double, 3, 3, {"dilation_rate": 5, "relu": False}),
    )

    merged = concatenate([path_a, path_b, path_c], axis=-1)
    fused = conv2d_bn(merged, output_filters, 1, 1, relu=False)
    shortcut = conv2d_bn(x, output_filters, 1, 1, stride=stride, relu=False)
    summed = Add()([fused, shortcut])
    return Activation("relu")(summed)
def BasicRFB_A(x, input_filters, output_filters, stride=1):
    """RFB-style block, variant A: narrow branches (width ``input_filters // 8``).

    Branch 0: 1x1 -> linear 3x3.
    Branch 1: 1x1 -> 2x2 -> 2x2 -> linear 3x3 with dilation 3.
    Branch 2: 1x1 -> 1x5 -> 5x1 -> linear 3x3 with dilation 5.
    Branches are concatenated on the last axis, fused by a linear 1x1
    conv2d_bn to ``output_filters`` channels, added to a linear 1x1 shortcut
    of ``x`` and passed through ReLU.

    ``stride`` is applied exactly once in every branch and on the shortcut.
    This keeps all tensors at the same spatial size when ``stride != 1``.
    Previously only branch 0 and the shortcut were strided, which broke the
    concatenation.

    Args:
        x: input tensor.
        input_filters: only used to size the branches; presumably the channel
            count of ``x`` — not checked here.
        output_filters: number of channels in the result.
        stride: spatial stride (1 keeps the input resolution).

    Returns:
        Output tensor with ``output_filters`` channels.
    """
    input_filters_div = input_filters // 8

    branch_0 = conv2d_bn(x, input_filters_div, 1, 1, stride=stride)
    branch_0 = conv2d_bn(branch_0, input_filters_div, 3, 3, relu=False)

    # NOTE(review): even (2x2) kernels under 'same' padding pad asymmetrically;
    # confirm this is intended.
    branch_1 = conv2d_bn(x, input_filters_div, 1, 1)
    branch_1 = conv2d_bn(branch_1, input_filters_div, 2, 2, stride=stride)
    branch_1 = conv2d_bn(branch_1, input_filters_div, 2, 2)
    branch_1 = conv2d_bn(branch_1, input_filters_div, 3, 3, dilation_rate=3, relu=False)

    branch_2 = conv2d_bn(x, input_filters_div, 1, 1)
    branch_2 = conv2d_bn(branch_2, input_filters_div, 1, 5, stride=stride)
    branch_2 = conv2d_bn(branch_2, input_filters_div, 5, 1)
    branch_2 = conv2d_bn(branch_2, input_filters_div, 3, 3, dilation_rate=5, relu=False)

    out = concatenate([branch_0, branch_1, branch_2], axis=-1)
    out = conv2d_bn(out, output_filters, 1, 1, relu=False)
    short = conv2d_bn(x, output_filters, 1, 1, stride=stride, relu=False)
    # Built-in Add layer instead of a Lambda: same result, but serializable.
    out = Add()([out, short])
    out = Activation("relu")(out)
    return out
def BasicRFB_B(x, input_filters, output_filters, stride=1):
    """RFB-style block, variant B: branches of width ``2 * (input_filters // 8)``.

    Branch 0: 1x1 -> linear 3x3.
    Branch 1: 1x1 -> 1x3 -> 3x1 -> linear 3x3 with dilation 3.
    Branch 2: 1x1 -> 3x3 -> 3x3 -> linear 3x3 with dilation 5.
    Branches are concatenated on the last axis, fused by a linear 1x1
    conv2d_bn to ``output_filters`` channels, added to a linear 1x1 shortcut
    of ``x`` and passed through ReLU.

    ``stride`` is applied exactly once per branch and on the shortcut.
    Previously branches 1 and 2 applied it twice (to both convs of each
    pair), so with ``stride=2`` they were downsampled 4x and the
    concatenation/add failed.

    Args:
        x: input tensor.
        input_filters: only used to size the branches; presumably the channel
            count of ``x`` — not checked here.
        output_filters: number of channels in the result.
        stride: spatial stride (1 keeps the input resolution).

    Returns:
        Output tensor with ``output_filters`` channels.
    """
    input_filters_div = input_filters // 8

    branch_0 = conv2d_bn(x, input_filters_div * 2, 1, 1, stride=stride)
    branch_0 = conv2d_bn(branch_0, input_filters_div * 2, 3, 3, relu=False)

    branch_1 = conv2d_bn(x, input_filters_div, 1, 1)
    branch_1 = conv2d_bn(branch_1, input_filters_div * 2, 1, 3, stride=stride)
    branch_1 = conv2d_bn(branch_1, input_filters_div * 2, 3, 1)
    branch_1 = conv2d_bn(branch_1, input_filters_div * 2, 3, 3, dilation_rate=3, relu=False)

    branch_2 = conv2d_bn(x, input_filters_div, 1, 1)
    branch_2 = conv2d_bn(branch_2, input_filters_div * 2, 3, 3, stride=stride)
    branch_2 = conv2d_bn(branch_2, input_filters_div * 2, 3, 3)
    branch_2 = conv2d_bn(branch_2, input_filters_div * 2, 3, 3, dilation_rate=5, relu=False)

    out = concatenate([branch_0, branch_1, branch_2], axis=-1)
    out = conv2d_bn(out, output_filters, 1, 1, relu=False)
    short = conv2d_bn(x, output_filters, 1, 1, stride=stride, relu=False)
    # Built-in Add layer instead of a Lambda: same result, but serializable.
    out = Add()([out, short])
    out = Activation("relu")(out)
    return out
def BasicRFB_C(x, input_filters, output_filters, stride=1):
    """RFB-style block, variant C: factorized (1xk then kx1) branches.

    Branch 0: 1x1 -> linear 3x3.
    Branch 1: 1x1 -> 1x3 -> 3x1 -> linear 3x3 with dilation 3.
    Branch 2: 1x1 -> 1x5 -> 5x1 -> linear 3x3 with dilation 5.
    Branches are concatenated on the last axis, fused by a linear 1x1
    conv2d_bn to ``output_filters`` channels, added to a linear 1x1 shortcut
    of ``x`` and passed through ReLU.

    ``stride`` is applied exactly once per branch and on the shortcut.
    Previously both convs of each factorized pair were strided, so with
    ``stride=2`` branches 1 and 2 were downsampled 4x and the
    concatenation/add failed.

    Args:
        x: input tensor.
        input_filters: only used to size the branches; presumably the channel
            count of ``x`` — not checked here.
        output_filters: number of channels in the result.
        stride: spatial stride (1 keeps the input resolution).

    Returns:
        Output tensor with ``output_filters`` channels.
    """
    input_filters_div = input_filters // 8

    branch_0 = conv2d_bn(x, input_filters_div * 2, 1, 1, stride=stride)
    branch_0 = conv2d_bn(branch_0, input_filters_div * 2, 3, 3, relu=False)

    branch_1 = conv2d_bn(x, input_filters_div, 1, 1)
    branch_1 = conv2d_bn(branch_1, input_filters_div * 2, 1, 3, stride=stride)
    branch_1 = conv2d_bn(branch_1, input_filters_div * 2, 3, 1)
    branch_1 = conv2d_bn(branch_1, input_filters_div * 2, 3, 3, dilation_rate=3, relu=False)

    branch_2 = conv2d_bn(x, input_filters_div, 1, 1)
    branch_2 = conv2d_bn(branch_2, input_filters_div * 2, 1, 5, stride=stride)
    branch_2 = conv2d_bn(branch_2, input_filters_div * 2, 5, 1)
    branch_2 = conv2d_bn(branch_2, input_filters_div * 2, 3, 3, dilation_rate=5, relu=False)

    out = concatenate([branch_0, branch_1, branch_2], axis=-1)
    out = conv2d_bn(out, output_filters, 1, 1, relu=False)
    short = conv2d_bn(x, output_filters, 1, 1, stride=stride, relu=False)
    # Built-in Add layer instead of a Lambda: same result, but serializable.
    out = Add()([out, short])
    out = Activation("relu")(out)
    return out
def BasicRFB_D(x, input_filters, output_filters, stride=1):
    """RFB-style block, variant D: narrow branches (width ``input_filters // 8``).

    Branch 0: 1x1 -> linear 3x3.
    Branch 1: 1x1 -> 2x2 -> 2x2 -> linear 3x3 with dilation 3.
    Branch 2: 1x1 -> 3x3 -> 3x3 -> linear 3x3 with dilation 5.
    Branches are concatenated on the last axis, fused by a linear 1x1
    conv2d_bn to ``output_filters`` channels, added to a linear 1x1 shortcut
    of ``x`` and passed through ReLU.

    ``stride`` is applied exactly once in every branch and on the shortcut.
    This keeps all tensors at the same spatial size when ``stride != 1``.
    Previously only branch 0 and the shortcut were strided, which broke the
    concatenation.

    Args:
        x: input tensor.
        input_filters: only used to size the branches; presumably the channel
            count of ``x`` — not checked here.
        output_filters: number of channels in the result.
        stride: spatial stride (1 keeps the input resolution).

    Returns:
        Output tensor with ``output_filters`` channels.
    """
    input_filters_div = input_filters // 8

    branch_0 = conv2d_bn(x, input_filters_div, 1, 1, stride=stride)
    branch_0 = conv2d_bn(branch_0, input_filters_div, 3, 3, relu=False)

    # NOTE(review): even (2x2) kernels under 'same' padding pad asymmetrically;
    # confirm this is intended.
    branch_1 = conv2d_bn(x, input_filters_div, 1, 1)
    branch_1 = conv2d_bn(branch_1, input_filters_div, 2, 2, stride=stride)
    branch_1 = conv2d_bn(branch_1, input_filters_div, 2, 2)
    branch_1 = conv2d_bn(branch_1, input_filters_div, 3, 3, dilation_rate=3, relu=False)

    branch_2 = conv2d_bn(x, input_filters_div, 1, 1)
    branch_2 = conv2d_bn(branch_2, input_filters_div, 3, 3, stride=stride)
    branch_2 = conv2d_bn(branch_2, input_filters_div, 3, 3)
    branch_2 = conv2d_bn(branch_2, input_filters_div, 3, 3, dilation_rate=5, relu=False)

    out = concatenate([branch_0, branch_1, branch_2], axis=-1)
    out = conv2d_bn(out, output_filters, 1, 1, relu=False)
    short = conv2d_bn(x, output_filters, 1, 1, stride=stride, relu=False)
    # Built-in Add layer instead of a Lambda: same result, but serializable.
    out = Add()([out, short])
    out = Activation("relu")(out)
    return out
# Build the model.
# NOTE(review): im_height, im_width and depth are not defined anywhere in this
# file; they must already exist in the executing namespace (e.g. an earlier
# notebook cell) before this code runs — TODO confirm or parameterize.
# The encoder downsamples three times (stride 2) and the decoder upsamples
# three times (2x), so im_height and im_width should be multiples of 8 for the
# skip concatenations to line up.
inputs = Input((im_height, im_width, depth))

# Encoder
s1 = residual_block_1(inputs, 64, strides=1)
s2 = residual_block_2(s1, 128, strides=2)
s3 = residual_block_2(s2, 256, strides=2)

# Bridge module: one more downsampling residual unit, then an RFB block that
# keeps the resolution (stride=1) and reduces 512 -> 256 channels.
s4 = residual_block_1(s3, 512, strides=2)
Ra = BasicRFB_A(s4, 512, 256, stride=1)

# Decoder: each stage upsamples 2x and merges the matching encoder output.
d1 = decoder_block(Ra, s3, 256)
d2 = decoder_block(d1, s2, 128)
d3 = decoder_block(d2, s1, 64)

# Single-channel sigmoid map at input resolution.
outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d3)
model = Model(inputs, outputs)