
Commit

Merge branch 'main' of https://github.com/songhan89/chexpert-aml into main
gabriellapauline committed Jul 31, 2021
2 parents 9da1831 + 27c8de4 commit 452bba0
Showing 2 changed files with 112 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/data/imgproc.py
@@ -48,6 +48,8 @@ def tf_read_image(x_features, filename, label, cnn_model,
image = imgproc.transform(image, transformations)

if cnn_model in ["MobileNetv2_keras",
"MobileNetv2_pop1",
"MobileNetv2_pop2",
"DenseNet121_keras",
"ResNet152_keras"]:
image = tf.image.grayscale_to_rgb(image)
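For reference, a minimal standalone sketch (not part of the diff) of the grayscale-to-RGB expansion this branch applies for the listed models; the 320x320 image size is an assumption:

import tensorflow as tf

gray = tf.zeros((320, 320, 1))           # single-channel (grayscale) image tensor
rgb = tf.image.grayscale_to_rgb(gray)    # replicates the channel -> shape (320, 320, 3)
print(rgb.shape)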
110 changes: 110 additions & 0 deletions src/models/tensorflow_models.py
@@ -262,6 +262,7 @@ def MobileNetv2_keras(output_size,
    cnn_base = tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=not_transfer,
                                                               weights='imagenet')
    cnn_base.trainable = not_transfer
    cnn_base.summary()  # print the base model architecture

    # inputs
    inputs_feature = tf.keras.Input(shape=feature_shape)
@@ -291,6 +292,114 @@ def MobileNetv2_keras(output_size,

    return model

# MobileNetv2 with the top layer group removed
def MobileNetv2_pop1(output_size,
                     not_transfer=False,
                     feature_shape=(4,),
                     image_shape=(320,320,3)):
    cnn_base = tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=not_transfer,
                                                               weights='imagenet')
    cnn_base.trainable = not_transfer

    # truncate the backbone at the end of block 15
    pop_model = tf.keras.Model(inputs=cnn_base.inputs, outputs=cnn_base.get_layer('block_15_project_BN').output)

    print('len(pop_model.layers):', len(pop_model.layers))
    pop_model.summary()

    # get the configs (and weights) of the final layer group in MobileNetv2;
    # the weights are retrieved but not re-applied to the rebuilt layers
    conv_1_config = cnn_base.get_layer('Conv_1').get_config()
    conv_1_weights = cnn_base.get_layer('Conv_1').get_weights()

    conv_1_bn_config = cnn_base.get_layer('Conv_1_bn').get_config()
    conv_1_bn_weights = cnn_base.get_layer('Conv_1_bn').get_weights()

    out_relu_config = cnn_base.get_layer('out_relu').get_config()

    # inputs
    inputs_feature = tf.keras.Input(shape=feature_shape)
    inputs_image = tf.keras.Input(shape=image_shape)

    # branch 1: rebuild the final layer group on top of the truncated backbone
    x1 = pop_model(inputs_image, training=False)
    x1 = tf.keras.layers.Conv2D.from_config(conv_1_config)(x1)
    x1 = tf.keras.layers.BatchNormalization.from_config(conv_1_bn_config)(x1)
    x1 = tf.keras.layers.ReLU.from_config(out_relu_config)(x1)
    x1 = tf.keras.layers.GlobalAveragePooling2D()(x1)
    x1 = tf.keras.layers.Flatten()(x1)

    # branch 2 for the non-image features
    x2 = tf.keras.layers.Dense(10)(inputs_feature)
    x2 = tf.keras.layers.Activation("relu")(x2)
    x2 = tf.keras.layers.Dropout(0.5)(x2)

    # concatenate the features
    x = tf.keras.layers.Concatenate()([x1, x2])
    x = tf.keras.layers.Activation('relu')(x)

    # create output layer
    x = tf.keras.layers.Dense(output_size)(x)
    x = tf.keras.layers.Activation("sigmoid", name='predicted_observations')(x)

    # create model class
    model = tf.keras.Model(inputs=[inputs_feature, inputs_image],
                           outputs=x,
                           name='MobileNetv2_pop1')

    return model
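A minimal usage sketch (not part of the commit) for the pop1 variant; the output size of 14 and the compile settings are assumptions:

model = MobileNetv2_pop1(output_size=14, feature_shape=(4,), image_shape=(320, 320, 3))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[tf.keras.metrics.AUC()])
model.summary()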

# MobileNetv2 with the top two layer groups removed
def MobileNetv2_pop2(output_size,
                     not_transfer=False,
                     feature_shape=(4,),
                     image_shape=(320,320,3)):
    cnn_base = tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=not_transfer,
                                                               weights='imagenet')
    cnn_base.trainable = not_transfer

    # truncate the backbone at the end of block 14
    pop_model = tf.keras.Model(inputs=cnn_base.inputs, outputs=cnn_base.get_layer('block_14_project_BN').output)

    print('len(pop_model.layers):', len(pop_model.layers))
    pop_model.summary()

    # get the configs (and weights) of the final layer group in MobileNetv2;
    # the weights are retrieved but not re-applied to the rebuilt layers
    conv_1_config = cnn_base.get_layer('Conv_1').get_config()
    conv_1_weights = cnn_base.get_layer('Conv_1').get_weights()

    conv_1_bn_config = cnn_base.get_layer('Conv_1_bn').get_config()
    conv_1_bn_weights = cnn_base.get_layer('Conv_1_bn').get_weights()

    out_relu_config = cnn_base.get_layer('out_relu').get_config()

    # inputs
    inputs_feature = tf.keras.Input(shape=feature_shape)
    inputs_image = tf.keras.Input(shape=image_shape)

    # branch 1: rebuild the final layer group on top of the truncated backbone
    x1 = pop_model(inputs_image, training=False)
    x1 = tf.keras.layers.Conv2D.from_config(conv_1_config)(x1)
    x1 = tf.keras.layers.BatchNormalization.from_config(conv_1_bn_config)(x1)
    x1 = tf.keras.layers.ReLU.from_config(out_relu_config)(x1)
    x1 = tf.keras.layers.GlobalAveragePooling2D()(x1)
    x1 = tf.keras.layers.Flatten()(x1)

    # branch 2 for the non-image features
    x2 = tf.keras.layers.Dense(10)(inputs_feature)
    x2 = tf.keras.layers.Activation("relu")(x2)
    x2 = tf.keras.layers.Dropout(0.5)(x2)

    # concatenate the features
    x = tf.keras.layers.Concatenate()([x1, x2])
    x = tf.keras.layers.Activation('relu')(x)

    # create output layer
    x = tf.keras.layers.Dense(output_size)(x)
    x = tf.keras.layers.Activation("sigmoid", name='predicted_observations')(x)

    # create model class
    model = tf.keras.Model(inputs=[inputs_feature, inputs_image],
                           outputs=x,
                           name='MobileNetv2_pop2')

    return model
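The two variants differ only in where the backbone is truncated; a small sketch (not part of the commit) to inspect the two cut points:

base = tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, weights=None)
print(base.get_layer('block_15_project_BN').output.shape)   # cut point used by MobileNetv2_pop1
print(base.get_layer('block_14_project_BN').output.shape)   # cut point used by MobileNetv2_pop2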

# keras standard DenseNet121
def DenseNet121_keras(output_size,
not_transfer=False,
@@ -374,6 +483,7 @@ def ResNet152_keras(output_size,
"ResNet152_new": ResNet152_new,
"DenseNet121_new": DenseNet121_new,
"MobileNetv2_keras": MobileNetv2_keras,
"MobileNetv2_pop1": MobileNetv2_pop1,
"DenseNet121_keras": DenseNet121_keras,
"ResNet152_keras": ResNet152_keras
}
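The string checked in imgproc.py must match a key in this registry; a hypothetical lookup (the dict's name is not shown in the diff and is assumed here, as is the output size of 14):

build_fn = models_dict["MobileNetv2_pop1"]   # 'models_dict' is an assumed name for the registry above
model = build_fn(output_size=14)             # 14 output labels is an assumption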
