forked from leondgarse/Keras_insightface
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
482 lines (409 loc) · 21.9 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
def print_buildin_models():
    """Print the catalog of backbone names understood by `buildin_models`.

    The message starts with a blank line and already ends with a newline, so
    `print` is called with `end=""`.
    """
    catalog_lines = (
        "",
        ">>>> buildin_models",
        "MXNet version resnet: mobilenet_m1, r18, r34, r50, r100, r101, se_r34, se_r50, se_r100",
        "Keras application: mobilenet, mobilenetv2, resnet50, resnet50v2, resnet101, resnet101v2, resnet152, resnet152v2",
        "EfficientNet: efficientnetb[0-7], efficientnetl2,",
        "Custom 1: ghostnet, mobilefacenet, mobilenetv3_small, mobilenetv3_large, se_mobilefacenet, se_resnext",
        "Or other names from keras.applications like DenseNet121 / InceptionV3 / NASNetMobile / VGG19.",
        "",
    )
    print("\n".join(catalog_lines), end="")
def __init_model_from_name__(name, input_shape=(112, 112, 3), weights="imagenet", **kwargs):
    """Instantiate a headless backbone model from a short name.

    Args:
        name: backbone identifier, matched case-insensitively against the
            families below; as a last resort the exact (case-sensitive) name is
            looked up on `keras.applications`.
        input_shape: input shape passed to the backbone constructor.
        weights: pretrained weights spec forwarded to the backbones that accept
            it (`mobilenet_m1` is always built with `weights=None`; the MXNet
            style resnets receive no `weights` argument).
        **kwargs: extra constructor arguments, forwarded to most branches. For
            the MXNet style resnets a `use_se` entry is consumed here and
            overrides the `se_` prefix detection.

    Returns:
        The backbone with `trainable = True`, or None if `name` is not recognized.
    """
    lowered = name.lower()

    if lowered == "mobilenet":
        backbone = keras.applications.MobileNet(input_shape=input_shape, include_top=False, weights=weights, **kwargs)
    elif lowered == "mobilenet_m1":
        from backbones import mobilenet_m1

        backbone = mobilenet_m1.MobileNet(input_shape=input_shape, include_top=False, weights=None, **kwargs)
    elif lowered == "mobilenetv2":
        backbone = keras.applications.MobileNetV2(input_shape=input_shape, include_top=False, weights=weights, **kwargs)
    elif any(tag in lowered for tag in ("r18", "r34", "r50", "r100", "r101")):
        # MXNet insightface version resnet: "r50" -> ResNet50, "se_r50" -> ResNet50 with use_se
        from backbones import resnet

        has_se_prefix = lowered.startswith("se_")
        depth_suffix = lowered[len("se_r"):] if has_se_prefix else lowered[len("r"):]
        model_name = "ResNet" + depth_suffix
        use_se = kwargs.pop("use_se", has_se_prefix)
        resnet_builder = getattr(resnet, model_name)
        backbone = resnet_builder(input_shape=input_shape, classes=0, use_se=use_se, model_name=model_name, **kwargs)
    elif lowered.startswith("resnet"):
        # keras.applications.ResNetxxx / ResNetxxxV2
        depth_suffix = lowered[len("resnet"):]
        if depth_suffix.endswith("v2"):
            model_name = "ResNet" + depth_suffix[:-2] + "V2"
        else:
            model_name = "ResNet" + depth_suffix
        keras_builder = getattr(keras.applications, model_name)
        backbone = keras_builder(weights=weights, include_top=False, input_shape=input_shape, **kwargs)
    elif lowered.startswith("efficientnet"):
        # import tensorflow.keras.applications.efficientnet as efficientnet
        from backbones import efficientnet

        # Last two characters select the variant, e.g. "efficientnetb0" -> "EfficientNetB0"
        effnet_builder = getattr(efficientnet, "EfficientNet" + lowered[-2:].upper())
        backbone = effnet_builder(weights=weights, include_top=False, input_shape=input_shape, **kwargs)
    elif lowered.startswith("se_resnext"):
        from keras_squeeze_excite_network import se_resnext

        # se_resnext101 vs. se_resnext50 block layout
        depth = [3, 4, 23, 3] if lowered.endswith("101") else [3, 4, 6, 3]
        backbone = se_resnext.SEResNextImageNet(weights=weights, input_shape=input_shape, include_top=False, depth=depth)
    elif lowered.startswith("mobilenetv3"):
        if "small" in lowered:
            v3_builder = keras.applications.MobileNetV3Small
        else:
            v3_builder = keras.applications.MobileNetV3Large
        backbone = v3_builder(input_shape=input_shape, include_top=False, weights=weights, include_preprocessing=False)
    elif "mobilefacenet" in lowered or "mobile_facenet" in lowered:
        from backbones import mobile_facenet

        backbone = mobile_facenet.mobile_facenet(input_shape=input_shape, include_top=False, name=name, use_se="se" in lowered)
    elif lowered == "ghostnet":
        from backbones import ghost_model

        backbone = ghost_model.GhostNet(input_shape=input_shape, include_top=False, width=1.3, **kwargs)
    elif hasattr(keras.applications, name):
        app_builder = getattr(keras.applications, name)
        backbone = app_builder(weights=weights, include_top=False, input_shape=input_shape, **kwargs)
    else:
        return None

    backbone.trainable = True
    return backbone
# MXNET: bn_momentum=0.9, bn_epsilon=2e-5, TF default: bn_momentum=0.99, bn_epsilon=0.001, PyTorch default: momentum=0.1, eps=1e-05
# MXNET: use_bias=True, scale=False, cavaface.pytorch: use_bias=False, scale=True
def buildin_models(
    stem_model,
    dropout=1,
    emb_shape=512,
    input_shape=(112, 112, 3),
    output_layer="GDC",
    bn_momentum=0.99,
    bn_epsilon=0.001,
    add_pointwise_conv=False,
    use_bias=False,
    scale=True,
    weights="imagenet",
    **kwargs
):
    """Build an embedding model: backbone + output head + "embedding" layer.

    Args:
        stem_model: backbone name (resolved by `__init_model_from_name__`) or an
            existing Keras model whose first output is used as the feature map.
        dropout: dropout rate used in the head; only applied when 0 < dropout < 1.
        emb_shape: embedding dimension.
        input_shape: backbone input shape when `stem_model` is a name.
        output_layer: head type, one of "E", "GAP", "GDC", "F".
        bn_momentum, bn_epsilon: BatchNormalization settings for the head. When
            they differ from the Keras defaults (0.99 / 0.001) they are also
            written into every BatchNormalization layer of the backbone.
        add_pointwise_conv: insert a 1x1 Conv2D(512) + BN + PReLU before the head.
        use_bias: whether the head's final Dense / Conv2D uses a bias.
        scale: `scale` of the final BatchNormalization (MXNet `fix_gamma=True`
            corresponds to `scale=False`).
        weights: pretrained weights spec forwarded to the backbone builder.
        **kwargs: forwarded to the backbone builder.

    Returns:
        A `keras.models.Model`, named after the backbone, whose output is the
        float32 "embedding" layer.

    Raises:
        ValueError: if `stem_model` is an unknown name or `output_layer` is not
            a supported head type.
    """
    supported_heads = ("E", "GAP", "GDC", "F")
    if output_layer not in supported_heads:
        raise ValueError(f"output_layer should be one of {supported_heads}, got: {output_layer!r}")

    if isinstance(stem_model, str):
        xx = __init_model_from_name__(stem_model, input_shape, weights, **kwargs)
        if xx is None:
            raise ValueError(f"Unknown stem_model name: {stem_model!r}")
    else:
        xx = stem_model

    if bn_momentum != 0.99 or bn_epsilon != 0.001:
        print(">>>> Change BatchNormalization momentum and epsilon default value.")
        for layer in xx.layers:
            if isinstance(layer, keras.layers.BatchNormalization):
                layer.momentum, layer.epsilon = bn_momentum, bn_epsilon
        # clone_model rebuilds every layer from its (now updated) config, but the
        # clones get freshly initialized weights. Copy the original weights back
        # so pretrained backbone weights are not silently discarded.
        cloned = keras.models.clone_model(xx)
        cloned.set_weights(xx.get_weights())
        xx = cloned

    inputs = xx.inputs[0]
    nn = xx.outputs[0]

    if add_pointwise_conv:  # Model using `pointwise_conv + GDC` / `pointwise_conv + E` is smaller than `E`
        nn = keras.layers.Conv2D(512, 1, use_bias=False, padding="same")(nn)
        nn = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon)(nn)
        nn = keras.layers.PReLU(shared_axes=[1, 2])(nn)

    if output_layer == "E":
        # Fully Connected: BN -> (Dropout) -> Flatten -> Dense
        nn = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon, name="E_batchnorm")(nn)
        if 0 < dropout < 1:
            nn = keras.layers.Dropout(dropout)(nn)
        nn = keras.layers.Flatten(name="E_flatten")(nn)
        nn = keras.layers.Dense(emb_shape, use_bias=use_bias, kernel_initializer="glorot_normal", name="E_dense")(nn)
    elif output_layer == "GAP":
        # GlobalAveragePooling2D: BN -> GAP -> (Dropout) -> Dense
        nn = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon, name="GAP_batchnorm")(nn)
        nn = keras.layers.GlobalAveragePooling2D(name="GAP_pool")(nn)
        if 0 < dropout < 1:
            nn = keras.layers.Dropout(dropout)(nn)
        nn = keras.layers.Dense(emb_shape, use_bias=use_bias, kernel_initializer="glorot_normal", name="GAP_dense")(nn)
    elif output_layer == "GDC":
        # GDC: depthwise conv with a kernel covering the whole spatial extent -> BN -> (Dropout) -> 1x1 conv
        nn = keras.layers.DepthwiseConv2D(int(nn.shape[1]), depth_multiplier=1, use_bias=False, name="GDC_dw")(nn)
        nn = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon, name="GDC_batchnorm")(nn)
        if 0 < dropout < 1:
            nn = keras.layers.Dropout(dropout)(nn)
        nn = keras.layers.Conv2D(emb_shape, 1, use_bias=use_bias, kernel_initializer="glorot_normal", name="GDC_conv")(nn)
        nn = keras.layers.Flatten(name="GDC_flatten")(nn)
    else:  # output_layer == "F"
        # F: same as E without the first BatchNormalization
        if 0 < dropout < 1:
            nn = keras.layers.Dropout(dropout)(nn)
        nn = keras.layers.Flatten(name="E_flatten")(nn)
        nn = keras.layers.Dense(emb_shape, use_bias=use_bias, kernel_initializer="glorot_normal", name="E_dense")(nn)

    # `fix_gamma=True` in MXNet means `scale=False` in Keras
    embedding = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon, scale=scale, name="pre_embedding")(nn)
    # Force a float32 output even under a mixed_float16 policy
    embedding_fp32 = keras.layers.Activation("linear", dtype="float32", name="embedding")(embedding)
    basic_model = keras.models.Model(inputs, embedding_fp32, name=xx.name)
    return basic_model
class NormDense(keras.layers.Layer):
    """Dense layer computing cosine similarity between inputs and class weights.

    Input rows and weight columns are both L2-normalized before the matmul.
    With `loss_top_k > 1` each class owns `loss_top_k` consecutive weight
    columns (sub-centers), and only the largest similarity per class is kept.
    """

    def __init__(self, units=1000, kernel_regularizer=None, loss_top_k=1, **kwargs):
        super().__init__(**kwargs)
        self.init = keras.initializers.glorot_normal()
        self.units = units
        self.loss_top_k = loss_top_k
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.supports_masking = False

    def build(self, input_shape):
        total_columns = self.units * self.loss_top_k
        self.w = self.add_weight(
            name="norm_dense_w",
            shape=(input_shape[-1], total_columns),
            initializer=self.init,
            trainable=True,
            regularizer=self.kernel_regularizer,
        )
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        unit_columns = K.l2_normalize(self.w, axis=0)
        unit_rows = K.l2_normalize(inputs, axis=1)
        cosine = K.dot(unit_rows, unit_columns)
        if self.loss_top_k <= 1:
            return cosine
        # (batch, units * k) -> (batch, units, k), then keep the best sub-center per class
        per_class = K.reshape(cosine, (-1, self.units, self.loss_top_k))
        return K.max(per_class, axis=2)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.units)

    def get_config(self):
        config = super().get_config()
        config.update(
            units=self.units,
            loss_top_k=self.loss_top_k,
            kernel_regularizer=keras.regularizers.serialize(self.kernel_regularizer),
        )
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
def add_l2_regularizer_2_model(model, weight_decay, custom_objects=None, apply_to_batch_normal=False, apply_to_bias=False):
    """Attach L2 regularizers to a model's weights and return a rebuilt model.

    On every trainable layer, the regularizer attributes listed below are set
    to `keras.regularizers.L2(weight_decay / 2)`, i.e. a penalty of
    `weight_decay / 2 * sum(w ** 2)`:
      - DepthwiseConv2D: `depthwise_regularizer`
      - SeparableConv2D: `pointwise_regularizer`, `depthwise_regularizer`
      - Dense / Conv2D: `kernel_regularizer`
      - plus `bias_regularizer` for the above when `apply_to_bias` and the layer uses a bias
      - BatchNormalization `beta_regularizer` / `gamma_regularizer` and PReLU
        `alpha_regularizer` when `apply_to_batch_normal`

    Setting the attributes only changes the layer configs; the model is then
    cloned so the regularization losses are actually registered, and the
    original weights are copied into the clone.

    Args:
        model: Keras model. Its layers' regularizer attributes are mutated in place.
        weight_decay: weight decay factor.
        custom_objects: unused, kept for backward compatibility.
        apply_to_batch_normal: also regularize BatchNormalization and PReLU weights.
        apply_to_bias: also regularize biases of conv / dense layers.

    Returns:
        A new model with regularization losses and the same weights as `model`.
    """
    # https://github.com/keras-team/keras/issues/2717#issuecomment-456254176
    for layer in model.layers:
        attrs = []
        has_bias_option = False
        if isinstance(layer, keras.layers.DepthwiseConv2D):
            # Checked before Conv2D: in some TF versions DepthwiseConv2D subclasses
            # Conv2D, and its kernel is governed by `depthwise_regularizer`
            # (setting `kernel_regularizer` there has no effect).
            attrs = ["depthwise_regularizer"]
            has_bias_option = True
        elif isinstance(layer, keras.layers.SeparableConv2D):
            attrs = ["pointwise_regularizer", "depthwise_regularizer"]
            has_bias_option = True
        elif isinstance(layer, (keras.layers.Dense, keras.layers.Conv2D)):
            attrs = ["kernel_regularizer"]
            has_bias_option = True
        elif apply_to_batch_normal and isinstance(layer, keras.layers.BatchNormalization):
            if layer.center:
                attrs.append("beta_regularizer")
            if layer.scale:
                attrs.append("gamma_regularizer")
        elif apply_to_batch_normal and isinstance(layer, keras.layers.PReLU):
            attrs = ["alpha_regularizer"]

        if has_bias_option and apply_to_bias and layer.use_bias:
            attrs.append("bias_regularizer")

        if not layer.trainable:
            continue
        for attr in attrs:
            if hasattr(layer, attr):
                setattr(layer, attr, keras.regularizers.L2(weight_decay / 2))

    # So far, the regularizers only exist in the layer configs. Cloning rebuilds
    # every layer from its config so Keras adds them to each layer's losses.
    # clone_model initializes fresh weights, so copy the originals back.
    out_model = keras.models.clone_model(model)
    out_model.set_weights(model.get_weights())
    return out_model
def replace_ReLU_with_PReLU(model, target_activation="PReLU", **kwargs):
    """Return a clone of `model` with every ReLU activation layer replaced.

    Matched layers are `keras.layers.ReLU` instances and `keras.layers.Activation`
    layers whose activation is `keras.activations.relu`. All other layers are
    reused as-is by `clone_model`, so they keep their weights.

    Args:
        model: Keras model to convert.
        target_activation: replacement for each ReLU:
            - "PReLU": `PReLU(shared_axes=[1, 2])` with alpha initialized to 0.25;
            - any other string: `Activation(target_activation)`;
            - otherwise a layer class, called as `target_activation(name=..., **kwargs)`.
        **kwargs: extra constructor arguments for every new layer.

    Returns:
        The converted model. Each new layer's name is the old one with "_relu"
        replaced by "_prelu", "_<target_activation>" or "_<class name>".
    """
    from tensorflow.keras.layers import ReLU, PReLU, Activation

    def convert_ReLU(layer):
        is_relu = isinstance(layer, ReLU) or (isinstance(layer, Activation) and layer.activation == keras.activations.relu)
        if not is_relu:
            return layer

        if target_activation == "PReLU":
            layer_name = layer.name.replace("_relu", "_prelu")
            print(">>>> Convert ReLU:", layer.name, "-->", layer_name)
            # Default initial value in mxnet and pytorch is 0.25
            return PReLU(shared_axes=[1, 2], alpha_initializer=tf.initializers.Constant(0.25), name=layer_name, **kwargs)
        if isinstance(target_activation, str):
            layer_name = layer.name.replace("_relu", "_" + target_activation)
            print(">>>> Convert ReLU:", layer.name, "-->", layer_name)
            return Activation(activation=target_activation, name=layer_name, **kwargs)

        layer_name = layer.name.replace("_relu", "_" + target_activation.__name__)
        print(">>>> Convert ReLU:", layer.name, "-->", layer_name)
        # Pass the name explicitly so the new layer really is `layer_name`,
        # consistent with the other branches and the printed message.
        return target_activation(name=layer_name, **kwargs)

    input_tensors = keras.layers.Input(model.input_shape[1:])
    return keras.models.clone_model(model, input_tensors=input_tensors, clone_function=convert_ReLU)
class AconC(keras.layers.Layer):
    """ACON-C activation: `(p1 - p2) * x * sigmoid(beta * (p1 - p2) * x) + p2 * x`.

    - [Github nmaac/acon](https://github.com/nmaac/acon/blob/main/acon.py)
    - [Activate or Not: Learning Customized Activation, CVPR 2021](https://arxiv.org/pdf/2009.04759.pdf)

    `p1`, `p2` and `beta` are learnable per-channel weights of shape
    `(1, 1, 1, channels)`; given that shape, inputs are presumably 4D
    channels-last tensors (TODO confirm for other ranks). With the default
    initial values (p1=1, p2=0, beta=1) the layer starts as `x * sigmoid(x)`.
    """

    def __init__(self, p1=1, p2=0, beta=1, **kwargs):
        super(AconC, self).__init__(**kwargs)
        # Keep the raw initial values so get_config can round-trip them.
        self.p1_value, self.p2_value, self.beta_value = p1, p2, beta
        self.p1_init = tf.initializers.Constant(p1)
        self.p2_init = tf.initializers.Constant(p2)
        self.beta_init = tf.initializers.Constant(beta)
        self.supports_masking = False

    def build(self, input_shape):
        param_shape = (1, 1, 1, input_shape[-1])
        self.p1 = self.add_weight(name="p1", shape=param_shape, initializer=self.p1_init, trainable=True)
        self.p2 = self.add_weight(name="p2", shape=param_shape, initializer=self.p2_init, trainable=True)
        self.beta = self.add_weight(name="beta", shape=param_shape, initializer=self.beta_init, trainable=True)
        super(AconC, self).build(input_shape)

    def call(self, inputs, **kwargs):
        # As in the reference acon.py: the (p1 - p2) * x term both scales the
        # sigmoid and appears inside it, multiplied by beta.
        dpx = (self.p1 - self.p2) * inputs
        return dpx * tf.nn.sigmoid(self.beta * dpx) + self.p2 * inputs

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super(AconC, self).get_config()
        config.update({"p1": self.p1_value, "p2": self.p2_value, "beta": self.beta_value})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
class SAMModel(tf.keras.models.Model):
    """
    Arxiv article: [Sharpness-Aware Minimization for Efficiently Improving Generalization](https://arxiv.org/pdf/2010.01412.pdf)
    Implementation by: [Keras SAM (Sharpness-Aware Minimization)](https://qiita.com/T-STAR/items/8c3afe3a116a8fc08429)
    Usage is same with `keras.models.Model`: `model = SAMModel(inputs, outputs, rho=sam_rho, name=name)`
    """

    def __init__(self, *args, rho=0.05, **kwargs):
        super().__init__(*args, **kwargs)
        self.rho = tf.constant(rho, dtype=tf.float32)

    def train_step(self, data):
        """Run one SAM update.

        `data` is `(x, y)` or `(x, y, sample_weight)`. The step:
          1. computes gradients g at the current weights w;
          2. moves to w + e(w), with e(w) = rho * g / global_norm(g);
          3. computes gradients at w + e(w), restores w and applies those gradients.
        Variables with no gradient (not connected to the loss) are not perturbed.
        Metrics are updated with the predictions made at the unperturbed weights.
        """
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        # 1st step: gradient at the current weights
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight, regularization_losses=self.losses)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        norm = tf.linalg.global_norm(gradients)  # None entries are ignored
        scale = self.rho / (norm + 1e-12)
        e_w_list = []
        for v, grad in zip(trainable_vars, gradients):
            if grad is None:
                # `None * scale` would raise; such variables are simply not perturbed.
                e_w_list.append(None)
                continue
            e_w = grad * scale
            v.assign_add(e_w)
            e_w_list.append(e_w)

        # 2nd step: gradient at the perturbed weights
        with tf.GradientTape() as tape:
            y_pred_adv = self(x, training=True)
            loss_adv = self.compiled_loss(y, y_pred_adv, sample_weight=sample_weight, regularization_losses=self.losses)
        gradients_adv = tape.gradient(loss_adv, trainable_vars)
        for v, e_w in zip(trainable_vars, e_w_list):
            if e_w is not None:
                v.assign_sub(e_w)

        # optimize
        self.optimizer.apply_gradients(zip(gradients_adv, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        return return_metrics
def replace_add_with_stochastic_depth(model, survivals=(1, 0.8)):
    """Replace `keras.layers.Add` layers with `tfa.layers.StochasticDepth`.

    - [Deep Networks with Stochastic Depth](https://arxiv.org/pdf/1603.09382.pdf)
    - [tfa.layers.StochasticDepth](https://www.tensorflow.org/addons/api_docs/python/tfa/layers/StochasticDepth)

    Args:
        model: Keras model. Its Add layers are presumably merging
            `[shortcut, residual]`, the input order StochasticDepth expects
            (TODO confirm for custom models).
        survivals: survival probability per Add layer, in `model.layers` order:
            - a number: the same probability for every Add layer;
            - a 2-item (start, end) list/tuple: linear decay, the i-th of n Add
              layers (i from 0) gets `start - (start - end) * i / n`;
            - any other sequence: used as-is, one value per Add layer.
            Add layers whose probability is >= 1 are kept unchanged.

    Returns:
        A cloned model; layers that are not replaced are shared with `model`.
    """
    from tensorflow_addons.layers import StochasticDepth

    add_layers = [ii.name for ii in model.layers if isinstance(ii, keras.layers.Add)]
    total_adds = len(add_layers)
    if isinstance(survivals, (int, float)):
        survivals = [survivals] * total_adds
    elif isinstance(survivals, (list, tuple)) and len(survivals) == 2:
        start, end = survivals
        # Slope uses (start - end) so the schedule heads to `end` for any `start`;
        # identical to the previous (1 - end) for the default start of 1.
        survivals = [start - (start - end) * float(ii) / total_adds for ii in range(total_adds)]
    survivals_dict = dict(zip(add_layers, survivals))

    def __replace_add_with_stochastic_depth__(layer):
        if isinstance(layer, keras.layers.Add):
            layer_name = layer.name
            # Handle both "xxx_add" style and Keras default "add_N" style names.
            new_layer_name = layer_name.replace("_add", "_stochastic_depth").replace("add_", "stochastic_depth_")
            survival_probability = survivals_dict[layer_name]
            if survival_probability < 1:
                print("Converting:", layer_name, "-->", new_layer_name, ", survival_probability:", survival_probability)
                return StochasticDepth(survival_probability, name=new_layer_name)
            else:
                return layer
        return layer

    input_tensors = keras.layers.Input(model.input_shape[1:])
    return keras.models.clone_model(model, input_tensors=input_tensors, clone_function=__replace_add_with_stochastic_depth__)
def replace_stochastic_depth_with_add(model, drop_survival=False):
    """Replace `tfa.layers.StochasticDepth` layers with deterministic merges.

    Each StochasticDepth layer is renamed ("_stochastic_depth" -> "_lambda") and
    becomes:
      - `Lambda(inputs[0] + inputs[1] * survival)` when its survival
        probability is < 1 and `drop_survival` is False;
      - a plain `keras.layers.Add` otherwise.
    Other layers are reused unchanged by `clone_model`.
    """
    from tensorflow_addons.layers import StochasticDepth

    def __replace_stochastic_depth_with_add__(layer):
        if not isinstance(layer, StochasticDepth):
            return layer
        old_name = layer.name
        new_name = old_name.replace("_stochastic_depth", "_lambda")
        survival = layer.survival_probability
        print("Converting:", old_name, "-->", new_name, ", survival_probability:", survival)
        if not drop_survival and survival < 1:
            # Expected value of the stochastic branch at inference time
            return keras.layers.Lambda(lambda pair: pair[0] + pair[1] * survival, name=new_name)
        return keras.layers.Add(name=new_name)

    fresh_input = keras.layers.Input(model.input_shape[1:])
    return keras.models.clone_model(model, input_tensors=fresh_input, clone_function=__replace_stochastic_depth_with_add__)
def convert_to_mixed_float16(model, convert_batch_norm=False):
    """Rebuild `model` with the "mixed_float16" dtype policy on its layers.

    Each layer is re-created from its config with the policy as `dtype`, built
    on the original input shape and given the original weights. The following
    layers are kept as-is:
      - InputLayer;
      - `Activation` layers with the linear activation (e.g. a float32 output);
      - BatchNormalization, unless `convert_batch_norm` is True.
    """
    policy_config = keras.utils.serialize_keras_object(keras.mixed_precision.Policy("mixed_float16"))
    from tensorflow.keras.layers import InputLayer, Activation
    from tensorflow.keras.activations import linear

    def do_convert_to_mixed_float16(layer):
        keep_bn = isinstance(layer, keras.layers.BatchNormalization) and not convert_batch_norm
        if keep_bn:
            return layer
        is_linear_activation = isinstance(layer, Activation) and layer.activation == linear
        if isinstance(layer, InputLayer) or is_linear_activation:
            return layer
        config = layer.get_config()
        config["dtype"] = policy_config
        converted = layer.__class__.from_config(config)
        converted.build(layer.input_shape)
        converted.set_weights(layer.get_weights())
        return converted

    fresh_input = keras.layers.Input(model.input_shape[1:])
    return keras.models.clone_model(model, input_tensors=fresh_input, clone_function=do_convert_to_mixed_float16)
def convert_mixed_float16_to_float32(model):
    """Rebuild `model` with every layer's dtype set to "float32".

    InputLayer and linear `Activation` layers are kept as-is; every other layer
    is re-created from its config with `dtype="float32"`, built on the
    original input shape and given the original weights.
    """
    from tensorflow.keras.layers import InputLayer, Activation
    from tensorflow.keras.activations import linear

    def do_convert_to_float32(layer):
        is_linear_activation = isinstance(layer, Activation) and layer.activation == linear
        if isinstance(layer, InputLayer) or is_linear_activation:
            return layer
        config = layer.get_config()
        config["dtype"] = "float32"
        converted = layer.__class__.from_config(config)
        converted.build(layer.input_shape)
        converted.set_weights(layer.get_weights())
        return converted

    fresh_input = keras.layers.Input(model.input_shape[1:])
    return keras.models.clone_model(model, input_tensors=fresh_input, clone_function=do_convert_to_float32)
def convert_to_batch_renorm(model):
    """Rebuild `model` with Batch Renormalization enabled on its BN layers.

    Each BatchNormalization layer is re-created from its config with
    `renorm=True`, `renorm_clipping={}` and `renorm_momentum` equal to its
    `momentum`, then built on the original input shape. The original weights
    are copied in, and the new layer's last three weights keep their freshly
    initialized values. Other layers are reused unchanged.
    """

    def do_convert_to_batch_renorm(layer):
        if not isinstance(layer, keras.layers.BatchNormalization):
            return layer
        config = layer.get_config()
        config.update(renorm=True, renorm_clipping={}, renorm_momentum=config["momentum"])
        renorm_layer = layer.__class__.from_config(config)
        renorm_layer.build(layer.input_shape)
        # NOTE(review): assumes renorm appends exactly three weights after the
        # regular BN weights; they are not seeded from the copied moving
        # statistics — confirm this is acceptable for the TF version in use.
        fresh_extra_weights = renorm_layer.get_weights()[-3:]
        renorm_layer.set_weights(layer.get_weights() + fresh_extra_weights)
        return renorm_layer

    fresh_input = keras.layers.Input(model.input_shape[1:])
    return keras.models.clone_model(model, input_tensors=fresh_input, clone_function=do_convert_to_batch_renorm)