Skip to content

Commit

Permalink
Update Keras CV for Keras 3
Browse files Browse the repository at this point in the history
  • Loading branch information
sampathweb committed Nov 29, 2023
1 parent 44d023e commit 916402f
Show file tree
Hide file tree
Showing 15 changed files with 293 additions and 92 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
84 changes: 62 additions & 22 deletions guides/ipynb/keras_cv/classification_with_keras_cv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@
"- Fine-tuning a pretrained backbone\n",
"- Training a image classifier from scratch\n",
"\n",
"KerasCV uses Keras 3 to work with any of TensorFlow, Pytorch and Jax. In the\n",
"guide below, we will use the `jax` backend. This guide runs in\n",
"TensorFlow or PyTorch backends with zero changes, simply update the\n",
"`KERAS_BACKEND` below.\n",
"\n",
"We use Professor Keras, the official Keras mascot, as a\n",
"visual reference for the complexity of the material:\n",
"\n",
Expand All @@ -47,18 +52,38 @@
},
"outputs": [],
"source": [
"!pip install -q --upgrade keras-cv\n",
"!pip install -q --upgrade keras # Upgrade to Keras 3."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # @param [\"tensorflow\", \"jax\", \"torch\"]\n",
"\n",
"import json\n",
"import math\n",
"import keras_cv\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"import numpy as np\n",
"\n",
"import keras\n",
"from keras import losses\n",
"import numpy as np\n",
"from keras import ops\n",
"from keras import optimizers\n",
"from tensorflow.keras.optimizers import schedules\n",
"from keras.optimizers import schedules\n",
"from keras import metrics\n",
"\n",
"import keras_cv\n",
"\n",
"# Import tensorflow for `tf.data` and its preprocessing functions\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
""
]
},
Expand Down Expand Up @@ -126,11 +151,11 @@
},
"outputs": [],
"source": [
"filepath = tf.keras.utils.get_file(origin=\"https://i.imgur.com/9i63gLN.jpg\")\n",
"filepath = keras.utils.get_file(origin=\"https://i.imgur.com/9i63gLN.jpg\")\n",
"image = keras.utils.load_img(filepath)\n",
"image = np.array(image)\n",
"keras_cv.visualization.plot_image_gallery(\n",
" [image], rows=1, cols=1, value_range=(0, 255), show=True, scale=4\n",
" np.array([image]), rows=1, cols=1, value_range=(0, 255), show=True, scale=4\n",
")"
]
},
Expand Down Expand Up @@ -325,7 +350,7 @@
")\n",
"model.compile(\n",
" loss=\"categorical_crossentropy\",\n",
" optimizer=tf.optimizers.SGD(learning_rate=0.01),\n",
" optimizer=keras.optimizers.SGD(learning_rate=0.01),\n",
" metrics=[\"accuracy\"],\n",
")"
]
Expand Down Expand Up @@ -618,8 +643,10 @@
"source": [
"rand_augment = keras_cv.layers.RandAugment(\n",
" augmentations_per_image=3,\n",
" magnitude=0.3,\n",
" value_range=(0, 255),\n",
" magnitude=0.3,\n",
" magnitude_stddev=0.2,\n",
" rate=1.0,\n",
")\n",
"augmenters += [rand_augment]\n",
"\n",
Expand Down Expand Up @@ -793,8 +820,18 @@
},
"outputs": [],
"source": [
"augmenter = keras.Sequential(augmenters)\n",
"train_ds = train_ds.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)\n",
"\n",
"def create_augmenter_fn(augmenters):\n",
" def augmenter_fn(inputs):\n",
" for augmenter in augmenters:\n",
" inputs = augmenter(inputs)\n",
" return inputs\n",
"\n",
" return augmenter_fn\n",
"\n",
"\n",
"augmenter_fn = create_augmenter_fn(augmenters)\n",
"train_ds = train_ds.map(augmenter_fn, num_parallel_calls=tf.data.AUTOTUNE)\n",
"\n",
"image_batch = next(iter(train_ds.take(1)))[\"images\"]\n",
"keras_cv.visualization.plot_image_gallery(\n",
Expand Down Expand Up @@ -913,23 +950,26 @@
" * target_lr\n",
" * (\n",
" 1\n",
" + tf.cos(\n",
" tf.constant(math.pi)\n",
" * tf.cast(global_step - warmup_steps - hold, tf.float32)\n",
" / float(total_steps - warmup_steps - hold)\n",
" + ops.cos(\n",
" math.pi\n",
" * ops.convert_to_tensor(\n",
" global_step - warmup_steps - hold, dtype=\"float32\"\n",
" )\n",
" / ops.convert_to_tensor(\n",
" total_steps - warmup_steps - hold, dtype=\"float32\"\n",
" )\n",
" )\n",
" )\n",
" )\n",
"\n",
" warmup_lr = tf.cast(target_lr * (global_step / warmup_steps), tf.float32)\n",
" target_lr = tf.cast(target_lr, tf.float32)\n",
" warmup_lr = target_lr * (global_step / warmup_steps)\n",
"\n",
" if hold > 0:\n",
" learning_rate = tf.where(\n",
" learning_rate = ops.where(\n",
" global_step > warmup_steps + hold, learning_rate, target_lr\n",
" )\n",
"\n",
" learning_rate = tf.where(global_step < warmup_steps, warmup_lr, learning_rate)\n",
" learning_rate = ops.where(global_step < warmup_steps, warmup_lr, learning_rate)\n",
" return learning_rate\n",
"\n",
"\n",
Expand All @@ -952,7 +992,7 @@
" hold=self.hold,\n",
" )\n",
"\n",
" return tf.where(step > self.total_steps, 0.0, lr, name=\"learning_rate\")\n",
" return ops.where(step > self.total_steps, 0.0, lr)\n",
""
]
},
Expand Down Expand Up @@ -989,7 +1029,7 @@
" hold=hold_steps,\n",
")\n",
"optimizer = optimizers.SGD(\n",
" decay=5e-4,\n",
" weight_decay=5e-4,\n",
" learning_rate=schedule,\n",
" momentum=0.9,\n",
")"
Expand All @@ -1015,7 +1055,7 @@
},
"outputs": [],
"source": [
"backbone = keras_cv.models.EfficientNetV2B1Backbone()\n",
"backbone = keras_cv.models.EfficientNetV2B0Backbone()\n",
"model = keras.Sequential(\n",
" [\n",
" backbone,\n",
Expand Down
39 changes: 22 additions & 17 deletions guides/keras_cv/classification_with_keras_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@

import keras
from keras import losses
from keras import ops
from keras import optimizers
from keras.optimizers import schedules
from keras import metrics

import keras_cv

# Import tensorflow for `tf.data` and its preprocessing functions
import tensorflow as tf
import tensorflow_datasets as tfds
Expand Down Expand Up @@ -489,14 +491,14 @@ def package_inputs(image, label):
Now let's apply our final augmenter to the training data:
"""

def create_augmenter_fn(augmenters):

def augmenter_fn(inputs):
for augmenter in augmenters:
inputs = augmenter(inputs)
return inputs
def create_augmenter_fn(augmenters):
def augmenter_fn(inputs):
for augmenter in augmenters:
inputs = augmenter(inputs)
return inputs

return augmenter_fn
return augmenter_fn


augmenter_fn = create_augmenter_fn(augmenters)
Expand Down Expand Up @@ -576,23 +578,26 @@ def lr_warmup_cosine_decay(
* target_lr
* (
1
+ tf.cos(
tf.constant(math.pi)
* tf.cast(global_step - warmup_steps - hold, tf.float32)
/ float(total_steps - warmup_steps - hold)
+ ops.cos(
math.pi
* ops.convert_to_tensor(
global_step - warmup_steps - hold, dtype="float32"
)
/ ops.convert_to_tensor(
total_steps - warmup_steps - hold, dtype="float32"
)
)
)
)

warmup_lr = tf.cast(target_lr * (global_step / warmup_steps), tf.float32)
target_lr = tf.cast(target_lr, tf.float32)
warmup_lr = target_lr * (global_step / warmup_steps)

if hold > 0:
learning_rate = tf.where(
learning_rate = ops.where(
global_step > warmup_steps + hold, learning_rate, target_lr
)

learning_rate = tf.where(global_step < warmup_steps, warmup_lr, learning_rate)
learning_rate = ops.where(global_step < warmup_steps, warmup_lr, learning_rate)
return learning_rate


Expand All @@ -615,7 +620,7 @@ def __call__(self, step):
hold=self.hold,
)

return tf.where(step > self.total_steps, 0.0, lr, name="learning_rate")
return ops.where(step > self.total_steps, 0.0, lr)


"""
Expand All @@ -638,7 +643,7 @@ def __call__(self, step):
hold=hold_steps,
)
optimizer = optimizers.SGD(
decay=5e-4,
weight_decay=5e-4,
learning_rate=schedule,
momentum=0.9,
)
Expand All @@ -649,7 +654,7 @@ def __call__(self, step):
Note that this preset does not come with any pretrained weights.
"""

backbone = keras_cv.models.EfficientNetV2B1Backbone()
backbone = keras_cv.models.EfficientNetV2B0Backbone()
model = keras.Sequential(
[
backbone,
Expand Down
Loading

0 comments on commit 916402f

Please sign in to comment.