Updates Keras CV - Classification guide for Keras 3 #1661

Merged (5 commits) on Nov 30, 2023
86 changes: 63 additions & 23 deletions guides/ipynb/keras_cv/classification_with_keras_cv.ipynb
@@ -33,6 +33,11 @@
"- Fine-tuning a pretrained backbone\n",
"- Training a image classifier from scratch\n",
"\n",
"KerasCV uses Keras 3 to work with any of TensorFlow, PyTorch or Jax. In the\n",
"guide below, we will use the `jax` backend. This guide runs in\n",
"TensorFlow or PyTorch backends with zero changes, simply update the\n",
"`KERAS_BACKEND` below.\n",
"\n",
"We use Professor Keras, the official Keras mascot, as a\n",
"visual reference for the complexity of the material:\n",
"\n",
@@ -47,18 +52,38 @@
},
"outputs": [],
"source": [
"!pip install -q --upgrade keras-cv\n",
"!pip install -q --upgrade keras # Upgrade to Keras 3."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # @param [\"tensorflow\", \"jax\", \"torch\"]\n",
"\n",
"import json\n",
"import math\n",
"import keras_cv\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"import numpy as np\n",
"\n",
"import keras\n",
"from keras import losses\n",
"import numpy as np\n",
"from keras import ops\n",
"from keras import optimizers\n",
"from tensorflow.keras.optimizers import schedules\n",
"from keras.optimizers import schedules\n",
"from keras import metrics\n",
"\n",
"import keras_cv\n",
"\n",
"# Import tensorflow for `tf.data` and its preprocessing functions\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
""
]
},
@@ -126,11 +151,11 @@
},
"outputs": [],
"source": [
"filepath = tf.keras.utils.get_file(origin=\"https://i.imgur.com/9i63gLN.jpg\")\n",
"filepath = keras.utils.get_file(origin=\"https://i.imgur.com/9i63gLN.jpg\")\n",
"image = keras.utils.load_img(filepath)\n",
"image = np.array(image)\n",
"keras_cv.visualization.plot_image_gallery(\n",
" [image], rows=1, cols=1, value_range=(0, 255), show=True, scale=4\n",
" np.array([image]), rows=1, cols=1, value_range=(0, 255), show=True, scale=4\n",
")"
]
},
@@ -325,7 +350,7 @@
")\n",
"model.compile(\n",
" loss=\"categorical_crossentropy\",\n",
" optimizer=tf.optimizers.SGD(learning_rate=0.01),\n",
" optimizer=keras.optimizers.SGD(learning_rate=0.01),\n",
" metrics=[\"accuracy\"],\n",
")"
]
@@ -618,8 +643,10 @@
"source": [
"rand_augment = keras_cv.layers.RandAugment(\n",
" augmentations_per_image=3,\n",
" magnitude=0.3,\n",
" value_range=(0, 255),\n",
" magnitude=0.3,\n",
" magnitude_stddev=0.2,\n",
" rate=1.0,\n",
")\n",
"augmenters += [rand_augment]\n",
"\n",
@@ -793,8 +820,18 @@
},
"outputs": [],
"source": [
"augmenter = keras.Sequential(augmenters)\n",
"train_ds = train_ds.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)\n",
"\n",
"def create_augmenter_fn(augmenters):\n",
" def augmenter_fn(inputs):\n",
" for augmenter in augmenters:\n",
" inputs = augmenter(inputs)\n",
" return inputs\n",
"\n",
" return augmenter_fn\n",
"\n",
"\n",
"augmenter_fn = create_augmenter_fn(augmenters)\n",
"train_ds = train_ds.map(augmenter_fn, num_parallel_calls=tf.data.AUTOTUNE)\n",
"\n",
"image_batch = next(iter(train_ds.take(1)))[\"images\"]\n",
"keras_cv.visualization.plot_image_gallery(\n",
@@ -913,23 +950,26 @@
" * target_lr\n",
" * (\n",
" 1\n",
" + tf.cos(\n",
" tf.constant(math.pi)\n",
" * tf.cast(global_step - warmup_steps - hold, tf.float32)\n",
" / float(total_steps - warmup_steps - hold)\n",
" + ops.cos(\n",
" math.pi\n",
" * ops.convert_to_tensor(\n",
" global_step - warmup_steps - hold, dtype=\"float32\"\n",
" )\n",
" / ops.convert_to_tensor(\n",
" total_steps - warmup_steps - hold, dtype=\"float32\"\n",
" )\n",
" )\n",
" )\n",
" )\n",
"\n",
" warmup_lr = tf.cast(target_lr * (global_step / warmup_steps), tf.float32)\n",
" target_lr = tf.cast(target_lr, tf.float32)\n",
" warmup_lr = target_lr * (global_step / warmup_steps)\n",
"\n",
" if hold > 0:\n",
" learning_rate = tf.where(\n",
" learning_rate = ops.where(\n",
" global_step > warmup_steps + hold, learning_rate, target_lr\n",
" )\n",
"\n",
" learning_rate = tf.where(global_step < warmup_steps, warmup_lr, learning_rate)\n",
" learning_rate = ops.where(global_step < warmup_steps, warmup_lr, learning_rate)\n",
" return learning_rate\n",
"\n",
"\n",
@@ -952,7 +992,7 @@
" hold=self.hold,\n",
" )\n",
"\n",
" return tf.where(step > self.total_steps, 0.0, lr, name=\"learning_rate\")\n",
" return ops.where(step > self.total_steps, 0.0, lr)\n",
""
]
},
@@ -989,7 +1029,7 @@
" hold=hold_steps,\n",
")\n",
"optimizer = optimizers.SGD(\n",
" decay=5e-4,\n",
" weight_decay=5e-4,\n",
" learning_rate=schedule,\n",
" momentum=0.9,\n",
")"
@@ -1015,7 +1055,7 @@
},
"outputs": [],
"source": [
"backbone = keras_cv.models.EfficientNetV2B1Backbone()\n",
"backbone = keras_cv.models.EfficientNetV2B0Backbone()\n",
"model = keras.Sequential(\n",
" [\n",
" backbone,\n",
@@ -1166,4 +1206,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
77 changes: 55 additions & 22 deletions guides/keras_cv/classification_with_keras_cv.py
@@ -21,25 +21,43 @@
- Fine-tuning a pretrained backbone
- Training an image classifier from scratch

KerasCV uses Keras 3 to work with any of TensorFlow, PyTorch, or JAX. In the
guide below, we will use the `jax` backend. This guide runs on the
TensorFlow or PyTorch backends with zero changes; simply update the
`KERAS_BACKEND` below.

We use Professor Keras, the official Keras mascot, as a
visual reference for the complexity of the material:

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_evolution.png)
"""

"""shell
pip install -q --upgrade keras-cv
pip install -q --upgrade keras # Upgrade to Keras 3.
"""

import os

os.environ["KERAS_BACKEND"] = "jax" # @param ["tensorflow", "jax", "torch"]

import json
import math
import keras_cv
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

import keras
from keras import losses
import numpy as np
from keras import ops
from keras import optimizers
from tensorflow.keras.optimizers import schedules
from keras.optimizers import schedules
from keras import metrics

import keras_cv

# Import tensorflow for `tf.data` and its preprocessing functions
import tensorflow as tf
import tensorflow_datasets as tfds
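
Since the note above is the only place the backend switch is explained, here is a
minimal illustrative sketch (mine, not part of the PR) of how the selection works:
the environment variable must be set before `keras` is imported, after which the
same `keras.ops` code runs unchanged on TensorFlow, JAX, or PyTorch, and `tf.data`
pipelines remain valid `model.fit` inputs on any backend.

# Illustrative sketch, not part of the diff: choosing the backend.
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" / "torch"

import keras
from keras import ops

print(keras.backend.backend())  # -> "jax"

# keras.ops mirrors the NumPy API and dispatches to the active backend,
# so these lines are identical under all three backends.
x = ops.arange(6, dtype="float32")
print(ops.cos(x))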


"""
## Inference with a pretrained classifier
@@ -77,11 +95,11 @@
Now that our classifier is built, let's apply it to this cute cat picture!
"""

filepath = tf.keras.utils.get_file(origin="https://i.imgur.com/9i63gLN.jpg")
filepath = keras.utils.get_file(origin="https://i.imgur.com/9i63gLN.jpg")
image = keras.utils.load_img(filepath)
image = np.array(image)
keras_cv.visualization.plot_image_gallery(
[image], rows=1, cols=1, value_range=(0, 255), show=True, scale=4
np.array([image]), rows=1, cols=1, value_range=(0, 255), show=True, scale=4
)
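
The actual prediction step sits outside this hunk; as a hedged sketch (the
`from_preset` identifier below is an assumption based on the rest of the guide,
not something shown in this diff), applying the pretrained classifier to the cat
image looks roughly like this:

# Hedged sketch: `classifier` is the pretrained model built earlier in the guide,
# e.g. (preset name assumed):
#     classifier = keras_cv.models.ImageClassifier.from_preset(
#         "efficientnetv2_b0_imagenet_classifier"
#     )
predictions = classifier.predict(np.array([image]))
top_class = int(np.argmax(predictions[0]))  # index of the highest-scoring ImageNet class
print(top_class)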

"""
@@ -187,7 +205,7 @@ def preprocess_inputs(image, label):
)
model.compile(
loss="categorical_crossentropy",
optimizer=tf.optimizers.SGD(learning_rate=0.01),
optimizer=keras.optimizers.SGD(learning_rate=0.01),
metrics=["accuracy"],
)

@@ -371,8 +389,10 @@ def package_inputs(image, label):
"""
rand_augment = keras_cv.layers.RandAugment(
augmentations_per_image=3,
magnitude=0.3,
value_range=(0, 255),
magnitude=0.3,
magnitude_stddev=0.2,
rate=1.0,
)
augmenters += [rand_augment]
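
As a quick illustration (mine, not part of the diff), the layer can also be called
eagerly on one batch to preview its effect; per the KerasCV documentation,
`magnitude_stddev` adds per-image noise around `magnitude`, and `rate` is the
probability that each sampled augmentation is actually applied.

# Illustrative preview, assuming `rand_augment`, `train_ds`, and `keras_cv` as above.
sample_images = next(iter(train_ds.take(1)))["images"]
augmented_images = rand_augment(sample_images)
keras_cv.visualization.plot_image_gallery(
    augmented_images, rows=3, cols=3, value_range=(0, 255), show=True, scale=3
)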

@@ -471,8 +491,18 @@ def package_inputs(image, label):
Now let's apply our final augmenter to the training data:
"""

augmenter = keras.Sequential(augmenters)
train_ds = train_ds.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)

def create_augmenter_fn(augmenters):
def augmenter_fn(inputs):
for augmenter in augmenters:
inputs = augmenter(inputs)
return inputs

return augmenter_fn


augmenter_fn = create_augmenter_fn(augmenters)
train_ds = train_ds.map(augmenter_fn, num_parallel_calls=tf.data.AUTOTUNE)

image_batch = next(iter(train_ds.take(1)))["images"]
keras_cv.visualization.plot_image_gallery(
@@ -548,23 +578,26 @@ def lr_warmup_cosine_decay(
* target_lr
* (
1
+ tf.cos(
tf.constant(math.pi)
* tf.cast(global_step - warmup_steps - hold, tf.float32)
/ float(total_steps - warmup_steps - hold)
+ ops.cos(
math.pi
* ops.convert_to_tensor(
global_step - warmup_steps - hold, dtype="float32"
)
/ ops.convert_to_tensor(
total_steps - warmup_steps - hold, dtype="float32"
)
)
)
)

warmup_lr = tf.cast(target_lr * (global_step / warmup_steps), tf.float32)
target_lr = tf.cast(target_lr, tf.float32)
warmup_lr = target_lr * (global_step / warmup_steps)

if hold > 0:
learning_rate = tf.where(
learning_rate = ops.where(
global_step > warmup_steps + hold, learning_rate, target_lr
)

learning_rate = tf.where(global_step < warmup_steps, warmup_lr, learning_rate)
learning_rate = ops.where(global_step < warmup_steps, warmup_lr, learning_rate)
return learning_rate


@@ -587,7 +620,7 @@ def __call__(self, step):
hold=self.hold,
)

return tf.where(step > self.total_steps, 0.0, lr, name="learning_rate")
return ops.where(step > self.total_steps, 0.0, lr)
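
To make the three phases of this schedule concrete, here is a small sketch (mine,
not from the PR; only the keyword arguments visible in the hunk are used, and any
other parameters in the real signature are assumed to have defaults) that samples
the learning rate at a few steps:

# Hedged sketch: expect a linear ramp to target_lr over the first 100 steps,
# a flat hold for the next 100, then a cosine decay towards 0 by step 1000
# (the wrapping schedule class additionally clamps to 0 past total_steps).
for step in [0, 50, 100, 150, 200, 500, 999]:
    lr = lr_warmup_cosine_decay(
        global_step=step,
        warmup_steps=100,
        hold=100,
        total_steps=1000,
        target_lr=1e-2,
    )
    print(step, float(lr))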


"""
@@ -610,7 +643,7 @@ def __call__(self, step):
hold=hold_steps,
)
optimizer = optimizers.SGD(
decay=5e-4,
weight_decay=5e-4,
learning_rate=schedule,
momentum=0.9,
)
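
A short note on the `decay` to `weight_decay` change (my summary, not text from the
PR): Keras 3 optimizers drop the legacy `decay` argument and instead take decoupled
weight decay via `weight_decay`; specific variables can be excluded from it if needed.

# Illustrative sketch of the Keras 3 optimizer API referenced above,
# reusing the `schedule` defined earlier in the guide.
optimizer = keras.optimizers.SGD(
    learning_rate=schedule,
    momentum=0.9,
    weight_decay=5e-4,  # replaces the removed `decay` argument
)
# Optionally skip weight decay for some variables, e.g. biases:
optimizer.exclude_from_weight_decay(var_names=["bias"])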
@@ -621,7 +654,7 @@ def __call__(self, step):
Note that this preset does not come with any pretrained weights.
"""

backbone = keras_cv.models.EfficientNetV2B1Backbone()
backbone = keras_cv.models.EfficientNetV2B0Backbone()
model = keras.Sequential(
[
backbone,