Enhanced Deep Residual Networks for single-image super-resolution - Keras 3 migration (Only Tensorflow Backend) (#1920)

* Keras 3 migration

* trim output
chunduriv authored Oct 23, 2024
1 parent 3117146 commit 695a0b5
Showing 9 changed files with 146 additions and 311 deletions.
54 changes: 33 additions & 21 deletions examples/vision/edsr.py
@@ -2,7 +2,7 @@
Title: Enhanced Deep Residual Networks for single-image super-resolution
Author: Gitesh Chawda
Date created: 2022/04/07
Last modified: 2022/04/07
Last modified: 2024/08/27
Description: Training an EDSR model on the DIV2K Dataset.
Accelerator: GPU
"""
@@ -40,14 +40,18 @@
"""
## Imports
"""
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras import layers
from keras import ops

AUTOTUNE = tf.data.AUTOTUNE
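The new import block pins the backend before keras is imported, since Keras 3 resolves the backend at import time; a minimal sanity-check sketch (the assertion is my illustration, not part of the commit):

import os

# The backend must be chosen before `import keras`; setting it afterwards has no effect.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras

# The example's tf.data pipeline and tf.image calls assume the TensorFlow backend.
assert keras.backend.backend() == "tensorflow"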

@@ -81,15 +85,15 @@ def flip_left_right(lowres_img, highres_img):
"""Flips Images to left and right."""

# Outputs random values from a uniform distribution in between 0 to 1
rn = tf.random.uniform(shape=(), maxval=1)
rn = keras.random.uniform(shape=(), maxval=1)
# If rn is less than 0.5 it returns original lowres_img and highres_img
# If rn is greater than 0.5 it returns flipped image
return tf.cond(
return ops.cond(
rn < 0.5,
lambda: (lowres_img, highres_img),
lambda: (
tf.image.flip_left_right(lowres_img),
tf.image.flip_left_right(highres_img),
ops.flip(lowres_img),
ops.flip(highres_img),
),
)
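The flip augmentation above swaps tf.cond and tf.image.flip_left_right for ops.cond and ops.flip. One hedged caveat: in Keras 3, ops.flip with no axis argument reverses every axis, while tf.image.flip_left_right only reverses the width axis, so an axis-explicit variant may be closer to the original behaviour. A standalone sketch (flip_left_right_explicit is an illustrative name, not the committed code):

import keras
from keras import ops


def flip_left_right_explicit(lowres_img, highres_img):
    """Randomly mirror an (H, W, C) image pair along the width axis."""
    rn = keras.random.uniform(shape=(), maxval=1)
    return ops.cond(
        rn < 0.5,
        lambda: (lowres_img, highres_img),
        # axis=1 is the width dimension for HWC images, matching tf.image.flip_left_right.
        lambda: (ops.flip(lowres_img, axis=1), ops.flip(highres_img, axis=1)),
    )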

@@ -98,7 +102,9 @@ def random_rotate(lowres_img, highres_img):
"""Rotates Images by 90 degrees."""

# Outputs random values from uniform distribution in between 0 to 4
rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
rn = ops.cast(
keras.random.uniform(shape=(), maxval=4, dtype="float32"), dtype="int32"
)
# Here rn signifies number of times the image(s) are rotated by 90 degrees
return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)
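keras.random.uniform does not draw integer samples the way tf.random.uniform(..., dtype=tf.int32) did, hence the float draw followed by ops.cast. Assuming the installed Keras 3 version exposes keras.random.randint, a sketch of an equivalent rotation helper without the cast (random_rotate_alt is an illustrative name):

import keras
import tensorflow as tf


def random_rotate_alt(lowres_img, highres_img, seed=None):
    """Rotate both images by the same random multiple of 90 degrees."""
    # Draw k in {0, 1, 2, 3} directly as an int32 scalar.
    k = keras.random.randint(shape=(), minval=0, maxval=4, seed=seed)
    return tf.image.rot90(lowres_img, k), tf.image.rot90(highres_img, k)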

@@ -110,13 +116,19 @@ def random_crop(lowres_img, highres_img, hr_crop_size=96, scale=4):
high resolution images: 96x96
"""
lowres_crop_size = hr_crop_size // scale # 96//4=24
lowres_img_shape = tf.shape(lowres_img)[:2] # (height,width)
lowres_img_shape = ops.shape(lowres_img)[:2] # (height,width)

lowres_width = tf.random.uniform(
shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=tf.int32
lowres_width = ops.cast(
keras.random.uniform(
shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype="float32"
),
dtype="int32",
)
lowres_height = tf.random.uniform(
shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=tf.int32
lowres_height = ops.cast(
keras.random.uniform(
shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype="float32"
),
dtype="int32",
)

highres_width = lowres_width * scale
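The crop offsets are sampled on the low-resolution grid and multiplied by scale, so the 24x24 and 96x96 patches stay aligned; the slicing itself is truncated in this hunk but visible in the notebook diff below. A small NumPy illustration of that alignment with fixed (non-random) offsets, purely to show the arithmetic:

import numpy as np

scale, hr_crop = 4, 96
lr_crop = hr_crop // scale                      # 24
lowres = np.zeros((128, 128, 3), dtype=np.uint8)
highres = np.zeros((512, 512, 3), dtype=np.uint8)

lr_h, lr_w = 10, 7                              # corner sampled on the low-res grid
hr_h, hr_w = lr_h * scale, lr_w * scale         # (40, 28) on the high-res grid

lowres_patch = lowres[lr_h : lr_h + lr_crop, lr_w : lr_w + lr_crop]      # (24, 24, 3)
highres_patch = highres[hr_h : hr_h + hr_crop, hr_w : hr_w + hr_crop]    # (96, 96, 3)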
@@ -218,7 +230,7 @@ def PSNR(super_resolution, high_resolution):
"""


class EDSRModel(tf.keras.Model):
class EDSRModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
@@ -242,16 +254,16 @@ def train_step(self, data):

def predict_step(self, x):
# Adding dummy dimension using tf.expand_dims and converting to float32 using tf.cast
x = tf.cast(tf.expand_dims(x, axis=0), tf.float32)
x = ops.cast(tf.expand_dims(x, axis=0), dtype="float32")
# Passing low resolution image to model
super_resolution_img = self(x, training=False)
# Clips the tensor from min(0) to max(255)
super_resolution_img = tf.clip_by_value(super_resolution_img, 0, 255)
super_resolution_img = ops.clip(super_resolution_img, 0, 255)
# Rounds the values of a tensor to the nearest integer
super_resolution_img = tf.round(super_resolution_img)
super_resolution_img = ops.round(super_resolution_img)
# Removes dimensions of size 1 from the shape of a tensor and converting to uint8
super_resolution_img = tf.squeeze(
tf.cast(super_resolution_img, tf.uint8), axis=0
super_resolution_img = ops.squeeze(
ops.cast(super_resolution_img, dtype="uint8"), axis=0
)
return super_resolution_img
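predict_step keeps tf.expand_dims (the in-line comment still mentions tf.cast, though the call is now ops.cast), while the clip/round/cast/squeeze post-processing moves to keras.ops; ops.expand_dims would also work under the TensorFlow backend. A hedged sketch of the same post-processing as a standalone helper (postprocess_prediction is an illustrative name):

from keras import ops


def postprocess_prediction(super_resolution_img):
    """Convert a float model output with a batch dimension to a displayable uint8 image."""
    x = ops.clip(super_resolution_img, 0, 255)   # keep pixel values in [0, 255]
    x = ops.round(x)                             # round to the nearest integer per pixel
    x = ops.cast(x, "uint8")
    return ops.squeeze(x, axis=0)                # drop the batch dimension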

@@ -267,9 +279,9 @@ def ResBlock(inputs):
# Upsampling Block
def Upsampling(inputs, factor=2, **kwargs):
x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(inputs)
x = tf.nn.depth_to_space(x, block_size=factor)
x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)
x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(x)
x = tf.nn.depth_to_space(x, block_size=factor)
x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)
return x
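In Keras 3, raw TensorFlow ops such as tf.nn.depth_to_space can no longer be applied directly to symbolic Keras tensors in a functional model, which is why the call is now wrapped in layers.Lambda; this also ties the example to the TensorFlow backend, as the commit title notes. An illustrative alternative (not part of the commit) is a small custom layer, sketched below with assumed names DepthToSpace and upsampling_block:

import tensorflow as tf
from keras import layers


class DepthToSpace(layers.Layer):
    """Pixel-shuffle upsampling: rearranges channels into spatial blocks (TF backend only)."""

    def __init__(self, block_size, **kwargs):
        super().__init__(**kwargs)
        self.block_size = block_size

    def call(self, x):
        return tf.nn.depth_to_space(x, block_size=self.block_size)


def upsampling_block(inputs, factor=2):
    x = layers.Conv2D(64 * factor**2, 3, padding="same")(inputs)
    return DepthToSpace(factor)(x)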


Binary file added examples/vision/img/edsr/edsr_11_0.png
Binary file modified examples/vision/img/edsr/edsr_11_1.png
Binary file modified examples/vision/img/edsr/edsr_17_0.png
Binary file modified examples/vision/img/edsr/edsr_17_1.png
Binary file modified examples/vision/img/edsr/edsr_17_2.png
Binary file modified examples/vision/img/edsr/edsr_17_3.png
88 changes: 51 additions & 37 deletions examples/vision/ipynb/edsr.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** Gitesh Chawda<br>\n",
"**Date created:** 2022/04/07<br>\n",
"**Last modified:** 2022/04/07<br>\n",
"**Last modified:** 2024/08/27<br>\n",
"**Description:** Training an EDSR model on the DIV2K Dataset."
]
},
@@ -39,7 +39,7 @@
"you can do super-resolution using an ESPCN Model. According to the survey paper, EDSR is one of the top-five\n",
"best-performing super-resolution methods based on PSNR scores. However, it has more\n",
"parameters and requires more computational power than other approaches.\n",
"It has a PSNR value (≈34db) that is slightly higher than ESPCN (≈32db).\n",
"It has a PSNR value (\u224834db) that is slightly higher than ESPCN (\u224832db).\n",
"As per the survey paper, EDSR performs better than ESPCN.\n",
"\n",
"Paper:\n",
@@ -60,19 +60,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"\n",
"AUTOTUNE = tf.data.AUTOTUNE"
]
@@ -93,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -123,7 +128,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -134,15 +139,15 @@
" \"\"\"Flips Images to left and right.\"\"\"\n",
"\n",
" # Outputs random values from a uniform distribution in between 0 to 1\n",
" rn = tf.random.uniform(shape=(), maxval=1)\n",
" rn = keras.random.uniform(shape=(), maxval=1)\n",
" # If rn is less than 0.5 it returns original lowres_img and highres_img\n",
" # If rn is greater than 0.5 it returns flipped image\n",
" return tf.cond(\n",
" return ops.cond(\n",
" rn < 0.5,\n",
" lambda: (lowres_img, highres_img),\n",
" lambda: (\n",
" tf.image.flip_left_right(lowres_img),\n",
" tf.image.flip_left_right(highres_img),\n",
" ops.flip(lowres_img),\n",
" ops.flip(highres_img),\n",
" ),\n",
" )\n",
"\n",
@@ -151,7 +156,9 @@
" \"\"\"Rotates Images by 90 degrees.\"\"\"\n",
"\n",
" # Outputs random values from uniform distribution in between 0 to 4\n",
" rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)\n",
" rn = ops.cast(\n",
" keras.random.uniform(shape=(), maxval=4, dtype=\"float32\"), dtype=\"int32\"\n",
" )\n",
" # Here rn signifies number of times the image(s) are rotated by 90 degrees\n",
" return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)\n",
"\n",
@@ -163,13 +170,19 @@
" high resolution images: 96x96\n",
" \"\"\"\n",
" lowres_crop_size = hr_crop_size // scale # 96//4=24\n",
" lowres_img_shape = tf.shape(lowres_img)[:2] # (height,width)\n",
" lowres_img_shape = ops.shape(lowres_img)[:2] # (height,width)\n",
"\n",
" lowres_width = tf.random.uniform(\n",
" shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=tf.int32\n",
" lowres_width = ops.cast(\n",
" keras.random.uniform(\n",
" shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=\"float32\"\n",
" ),\n",
" dtype=\"int32\",\n",
" )\n",
" lowres_height = tf.random.uniform(\n",
" shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=tf.int32\n",
" lowres_height = ops.cast(\n",
" keras.random.uniform(\n",
" shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=\"float32\"\n",
" ),\n",
" dtype=\"int32\",\n",
" )\n",
"\n",
" highres_width = lowres_width * scale\n",
Expand All @@ -184,7 +197,8 @@
" highres_width : highres_width + hr_crop_size,\n",
" ] # 96x96\n",
"\n",
" return lowres_img_cropped, highres_img_cropped\n"
" return lowres_img_cropped, highres_img_cropped\n",
""
]
},
{
@@ -202,15 +216,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"def dataset_object(dataset_cache, training=True):\n",
"\n",
" ds = dataset_cache\n",
" ds = ds.map(\n",
" lambda lowres, highres: random_crop(lowres, highres, scale=4),\n",
@@ -248,7 +261,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -277,7 +290,8 @@
" \"\"\"Compute the peak signal-to-noise ratio, measures quality of image.\"\"\"\n",
" # Max value of pixel is 255\n",
" psnr_value = tf.image.psnr(high_resolution, super_resolution, max_val=255)[0]\n",
" return psnr_value\n"
" return psnr_value\n",
""
]
},
{
@@ -305,14 +319,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"class EDSRModel(tf.keras.Model):\n",
"class EDSRModel(keras.Model):\n",
" def train_step(self, data):\n",
" # Unpack the data. Its structure depends on your model and\n",
" # on what you pass to `fit()`.\n",
@@ -336,16 +350,16 @@
"\n",
" def predict_step(self, x):\n",
" # Adding dummy dimension using tf.expand_dims and converting to float32 using tf.cast\n",
" x = tf.cast(tf.expand_dims(x, axis=0), tf.float32)\n",
" x = ops.cast(tf.expand_dims(x, axis=0), dtype=\"float32\")\n",
" # Passing low resolution image to model\n",
" super_resolution_img = self(x, training=False)\n",
" # Clips the tensor from min(0) to max(255)\n",
" super_resolution_img = tf.clip_by_value(super_resolution_img, 0, 255)\n",
" super_resolution_img = ops.clip(super_resolution_img, 0, 255)\n",
" # Rounds the values of a tensor to the nearest integer\n",
" super_resolution_img = tf.round(super_resolution_img)\n",
" super_resolution_img = ops.round(super_resolution_img)\n",
" # Removes dimensions of size 1 from the shape of a tensor and converting to uint8\n",
" super_resolution_img = tf.squeeze(\n",
" tf.cast(super_resolution_img, tf.uint8), axis=0\n",
" super_resolution_img = ops.squeeze(\n",
" ops.cast(super_resolution_img, dtype=\"uint8\"), axis=0\n",
" )\n",
" return super_resolution_img\n",
"\n",
@@ -360,10 +374,10 @@
"\n",
"# Upsampling Block\n",
"def Upsampling(inputs, factor=2, **kwargs):\n",
" x = layers.Conv2D(64 * (factor ** 2), 3, padding=\"same\", **kwargs)(inputs)\n",
" x = tf.nn.depth_to_space(x, block_size=factor)\n",
" x = layers.Conv2D(64 * (factor ** 2), 3, padding=\"same\", **kwargs)(x)\n",
" x = tf.nn.depth_to_space(x, block_size=factor)\n",
" x = layers.Conv2D(64 * (factor**2), 3, padding=\"same\", **kwargs)(inputs)\n",
" x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)\n",
" x = layers.Conv2D(64 * (factor**2), 3, padding=\"same\", **kwargs)(x)\n",
" x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)\n",
" return x\n",
"\n",
"\n",
@@ -402,7 +416,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -431,7 +445,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -473,7 +487,7 @@
"\n",
"| Trained Model | Demo |\n",
"| :--: | :--: |\n",
"| [![Generic badge](https://img.shields.io/badge/🤗%20Model-EDSR-red.svg)](https://huggingface.co/keras-io/EDSR) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-EDSR-red.svg)](https://huggingface.co/spaces/keras-io/EDSR) |"
"| [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Model-EDSR-red.svg)](https://huggingface.co/keras-io/EDSR) | [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Spaces-EDSR-red.svg)](https://huggingface.co/spaces/keras-io/EDSR) |"
]
}
],
@@ -506,4 +520,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}