Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Highly accurate boundaries segmentation using BASNet to keras 3.0 #2038

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions examples/vision/basnet_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from glob import glob
import matplotlib.pyplot as plt

import keras_cv
import keras_hub
import tensorflow as tf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the tf import needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use tf.image.ssim and need the tf import for that. We actually even mention that in line 37. Thanks.

import keras
from keras import layers, ops
Expand Down Expand Up @@ -228,15 +228,19 @@ def segmentation_head(x_input, out_classes, final_size):
return x


def get_resnet_block(_resnet, block_num):
"""Extract and return ResNet-34 block."""
resnet_layers = [3, 4, 6, 3] # ResNet-34 layer sizes at different block.
def get_resnet_block(resnet, block_num):
"""Extract and return a ResNet-34 block."""
extractor_levels = ["P2", "P3", "P4", "P5"]
num_blocks = resnet.stackwise_num_blocks
if block_num == 0:
x = resnet.get_layer("pool1_pool").output
else:
x = resnet.pyramid_outputs[extractor_levels[block_num - 1]]
y = resnet.get_layer(f"stack{block_num}_block{num_blocks[block_num]-1}_add").output
return keras.models.Model(
inputs=_resnet.get_layer(f"v2_stack_{block_num}_block1_1_conv").input,
outputs=_resnet.get_layer(
f"v2_stack_{block_num}_block{resnet_layers[block_num]}_add"
).output,
name=f"resnet34_block{block_num + 1}",
inputs=x,
outputs=y,
name=f"resnet_block{block_num + 1}",
)


Expand All @@ -262,8 +266,13 @@ def basnet_predict(input_shape, out_classes):
# -------------Encoder--------------
x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)

resnet = keras_cv.models.ResNet34Backbone(
include_rescaling=False,
resnet = keras_hub.models.ResNetBackbone(
input_conv_filters=[64],
input_conv_kernel_sizes=[7],
stackwise_num_filters=[64, 128, 256, 512],
stackwise_num_blocks=[3, 4, 6, 3],
stackwise_num_strides=[1, 2, 2, 2],
block_type="basic_block",
)

encoder_blocks = []
Expand Down Expand Up @@ -307,7 +316,7 @@ def basnet_predict(input_shape, out_classes):
for decoder_block in decoder_blocks
]

return keras.models.Model(inputs=[x_input], outputs=decoder_blocks)
return keras.models.Model(inputs=x_input, outputs=decoder_blocks)


"""
Expand Down Expand Up @@ -352,7 +361,7 @@ def basnet_rrm(base_model, out_classes):
# ------------- refined = coarse + residual
x = layers.Add()([x_input, x]) # Add prediction + refinement output

return keras.models.Model(inputs=base_model.input[0], outputs=x)
return keras.models.Model(inputs=[base_model.input], outputs=[x])


"""
Expand All @@ -375,7 +384,7 @@ def __init__(self, input_shape, out_classes):

# Activations.
output = [layers.Activation("sigmoid")(x) for x in output]
super().__init__(inputs=predict_model.input[0], outputs=output)
super().__init__(inputs=predict_model.input, outputs=output)

self.smooth = 1.0e-9
# Binary Cross Entropy loss.
Expand Down Expand Up @@ -453,9 +462,9 @@ def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):
training parameters, please check the given link.
"""

"""shell
!gdown 1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg
"""
import gdown

gdown.download(id="1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg", output="basnet_weights.h5")


def normalize_output(prediction):
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
177 changes: 88 additions & 89 deletions examples/vision/ipynb/basnet_segmentation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [Hamid Ali](https://github.com/hamidriasat)<br>\n",
"**Date created:** 2023/05/30<br>\n",
"**Last modified:** 2024/10/02<br>\n",
"**Last modified:** 2025/01/24<br>\n",
"**Description:** Boundaries aware segmentation model trained on the DUTS dataset."
]
},
Expand Down Expand Up @@ -68,10 +68,12 @@
"from glob import glob\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import keras_cv\n",
"import keras_hub\n",
"import tensorflow as tf\n",
"import keras\n",
"from keras import layers, ops"
"from keras import layers, ops\n",
"\n",
"keras.config.disable_traceback_filtering()"
]
},
{
Expand Down Expand Up @@ -117,10 +119,11 @@
},
"outputs": [],
"source": [
"DATA_DIR = keras.utils.get_file(\n",
"data_dir = keras.utils.get_file(\n",
" origin=\"http://saliencydetection.net/duts/download/DUTS-TE.zip\",\n",
" extract=True,\n",
")\n",
"data_dir = os.path.join(data_dir, \"DUTS-TE\")\n",
"\n",
"\n",
"def load_paths(path, split_ratio):\n",
Expand Down Expand Up @@ -159,7 +162,9 @@
" batch_x, batch_y = [], []\n",
" for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):\n",
" x, y = self.preprocess(\n",
" self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes\n",
" self.image_paths[i],\n",
" self.mask_paths[i],\n",
" self.img_size,\n",
" )\n",
" batch_x.append(x)\n",
" batch_y.append(y)\n",
Expand All @@ -173,13 +178,13 @@
" x = (x / 255.0).astype(np.float32)\n",
" return x\n",
"\n",
" def preprocess(self, x_batch, y_batch, img_size, out_classes):\n",
" def preprocess(self, x_batch, y_batch, img_size):\n",
" images = self.read_image(x_batch, (img_size, img_size), mode=\"rgb\") # image\n",
" masks = self.read_image(y_batch, (img_size, img_size), mode=\"grayscale\") # mask\n",
" return images, masks\n",
"\n",
"\n",
"train_paths, val_paths = load_paths(DATA_DIR, TRAIN_SPLIT_RATIO)\n",
"train_paths, val_paths = load_paths(data_dir, TRAIN_SPLIT_RATIO)\n",
"\n",
"train_dataset = Dataset(\n",
" train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True\n",
Expand Down Expand Up @@ -318,17 +323,20 @@
" return x\n",
"\n",
"\n",
"def get_resnet_block(_resnet, block_num):\n",
" \"\"\"Extract and return ResNet-34 block.\"\"\"\n",
" resnet_layers = [3, 4, 6, 3] # ResNet-34 layer sizes at different block.\n",
"def get_resnet_block(resnet, block_num):\n",
" \"\"\"Extract and return a ResNet-34 block.\"\"\"\n",
" extractor_levels = [\"P2\", \"P3\", \"P4\", \"P5\"]\n",
" num_blocks = resnet.stackwise_num_blocks\n",
" if block_num == 0:\n",
" x = resnet.get_layer(\"pool1_pool\").output\n",
" else:\n",
" x = resnet.pyramid_outputs[extractor_levels[block_num - 1]]\n",
" y = resnet.get_layer(f\"stack{block_num}_block{num_blocks[block_num]-1}_add\").output\n",
" return keras.models.Model(\n",
" inputs=_resnet.get_layer(f\"v2_stack_{block_num}_block1_1_conv\").input,\n",
" outputs=_resnet.get_layer(\n",
" f\"v2_stack_{block_num}_block{resnet_layers[block_num]}_add\"\n",
" ).output,\n",
" name=f\"resnet34_block{block_num + 1}\",\n",
" )\n",
""
" inputs=x,\n",
" outputs=y,\n",
" name=f\"resnet_block{block_num + 1}\",\n",
" )\n"
]
},
{
Expand Down Expand Up @@ -366,8 +374,13 @@
" # -------------Encoder--------------\n",
" x = layers.Conv2D(filters, kernel_size=(3, 3), padding=\"same\")(x_input)\n",
"\n",
" resnet = keras_cv.models.ResNet34Backbone(\n",
" include_rescaling=False,\n",
" resnet = keras_hub.models.ResNetBackbone(\n",
" input_conv_filters=[64],\n",
" input_conv_kernel_sizes=[7],\n",
" stackwise_num_filters=[64, 128, 256, 512],\n",
" stackwise_num_blocks=[3, 4, 6, 3],\n",
" stackwise_num_strides=[1, 2, 2, 2],\n",
" block_type=\"basic_block\",\n",
" )\n",
"\n",
" encoder_blocks = []\n",
Expand Down Expand Up @@ -411,8 +424,7 @@
" for decoder_block in decoder_blocks\n",
" ]\n",
"\n",
" return keras.models.Model(inputs=[x_input], outputs=decoder_blocks)\n",
""
" return keras.models.Model(inputs=x_input, outputs=decoder_blocks)\n"
]
},
{
Expand Down Expand Up @@ -470,8 +482,7 @@
" # ------------- refined = coarse + residual\n",
" x = layers.Add()([x_input, x]) # Add prediction + refinement output\n",
"\n",
" return keras.models.Model(inputs=[base_model.input], outputs=[x])\n",
""
" return keras.models.Model(inputs=[base_model.input], outputs=[x])\n"
]
},
{
Expand All @@ -492,22 +503,56 @@
"outputs": [],
"source": [
"\n",
"def basnet(input_shape, out_classes):\n",
" \"\"\"BASNet, it's a combination of two modules\n",
" Prediction Module and Residual Refinement Module(RRM).\"\"\"\n",
"class BASNet(keras.Model):\n",
" def __init__(self, input_shape, out_classes):\n",
" \"\"\"BASNet, it's a combination of two modules\n",
" Prediction Module and Residual Refinement Module(RRM).\"\"\"\n",
"\n",
" # Prediction model.\n",
" predict_model = basnet_predict(input_shape, out_classes)\n",
" # Refinement model.\n",
" refine_model = basnet_rrm(predict_model, out_classes)\n",
"\n",
" output = refine_model.outputs # Combine outputs.\n",
" output.extend(predict_model.output)\n",
"\n",
" # Activations.\n",
" output = [layers.Activation(\"sigmoid\")(x) for x in output]\n",
" super().__init__(inputs=predict_model.input, outputs=output)\n",
"\n",
" self.smooth = 1.0e-9\n",
" # Binary Cross Entropy loss.\n",
" self.cross_entropy_loss = keras.losses.BinaryCrossentropy()\n",
" # Structural Similarity Index value.\n",
" self.ssim_value = tf.image.ssim\n",
" # Jaccard / IoU loss.\n",
" self.iou_value = self.calculate_iou\n",
"\n",
" def calculate_iou(\n",
" self,\n",
" y_true,\n",
" y_pred,\n",
" ):\n",
" \"\"\"Calculate intersection over union (IoU) between images.\"\"\"\n",
" intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])\n",
" union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])\n",
" union = union - intersection\n",
" return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)\n",
"\n",
" # Prediction model.\n",
" predict_model = basnet_predict(input_shape, out_classes)\n",
" # Refinement model.\n",
" refine_model = basnet_rrm(predict_model, out_classes)\n",
" def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):\n",
" total = 0.0\n",
" for y_pred_i in y_pred: # y_pred = refine_model.outputs + predict_model.output\n",
" cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred_i)\n",
"\n",
" output = refine_model.outputs # Combine outputs.\n",
" output.extend(predict_model.output)\n",
" ssim_value = self.ssim_value(y_true, y_pred, max_val=1)\n",
" ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)\n",
"\n",
" output = [layers.Activation(\"sigmoid\")(_) for _ in output] # Activations.\n",
" iou_value = self.iou_value(y_true, y_pred)\n",
" iou_loss = 1 - iou_value\n",
"\n",
" return keras.models.Model(inputs=[predict_model.input], outputs=output)\n",
""
" # Add all three losses.\n",
" total += cross_entropy_loss + ssim_loss + iou_loss\n",
" return total\n"
]
},
{
Expand All @@ -532,53 +577,14 @@
"outputs": [],
"source": [
"\n",
"class BasnetLoss(keras.losses.Loss):\n",
" \"\"\"BASNet hybrid loss.\"\"\"\n",
"\n",
" def __init__(self, **kwargs):\n",
" super().__init__(name=\"basnet_loss\", **kwargs)\n",
" self.smooth = 1.0e-9\n",
"\n",
" # Binary Cross Entropy loss.\n",
" self.cross_entropy_loss = keras.losses.BinaryCrossentropy()\n",
" # Structural Similarity Index value.\n",
" self.ssim_value = tf.image.ssim\n",
" # Jaccard / IoU loss.\n",
" self.iou_value = self.calculate_iou\n",
"\n",
" def calculate_iou(\n",
" self,\n",
" y_true,\n",
" y_pred,\n",
" ):\n",
" \"\"\"Calculate intersection over union (IoU) between images.\"\"\"\n",
" intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])\n",
" union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])\n",
" union = union - intersection\n",
" return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)\n",
"\n",
" def call(self, y_true, y_pred):\n",
" cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)\n",
"\n",
" ssim_value = self.ssim_value(y_true, y_pred, max_val=1)\n",
" ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)\n",
"\n",
" iou_value = self.iou_value(y_true, y_pred)\n",
" iou_loss = 1 - iou_value\n",
"\n",
" # Add all three losses.\n",
" return cross_entropy_loss + ssim_loss + iou_loss\n",
"\n",
"\n",
"basnet_model = basnet(\n",
"basnet_model = BASNet(\n",
" input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3], out_classes=OUT_CLASSES\n",
") # Create model.\n",
"basnet_model.summary() # Show model summary.\n",
"\n",
"optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)\n",
"# Compile model.\n",
"basnet_model.compile(\n",
" loss=BasnetLoss(),\n",
" optimizer=optimizer,\n",
" metrics=[keras.metrics.MeanAbsoluteError(name=\"mae\") for _ in basnet_model.outputs],\n",
")"
Expand Down Expand Up @@ -631,17 +637,10 @@
},
"outputs": [],
"source": [
"!!gdown 1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import gdown\n",
"\n",
"gdown.download(id=\"1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg\", output=\"basnet_weights.h5\")\n",
"\n",
"\n",
"def normalize_output(prediction):\n",
" max_value = np.max(prediction)\n",
Expand Down Expand Up @@ -686,7 +685,7 @@
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "evn1",
"language": "python",
"name": "python3"
},
Expand All @@ -700,9 +699,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Loading
Loading