From 27362331f16be77768e7c4c6bedade677ef34594 Mon Sep 17 00:00:00 2001
From: Sarvagya Malaviya <45961148+quantumalaviya@users.noreply.github.com>
Date: Wed, 15 Jun 2022 02:19:34 +0530
Subject: [PATCH] Add AugMix (#407)

* adding augmix
* add augmix with demo
* add augmix test
* add augmix
* Update aug_mix.py
* Update aug_mix.py
* Update aug_mix_demo.py
* updating augmix with changes
* init commit with review comments
* fix linting issues
* removing RandomShear assignment in constructor
* review comments part 2
* review comments part 2 fix
* review comments part 3
* change variable names
* fix lint issues
* edit serialization test
* edit demo to use utils
* Update augmix

Co-authored-by: Luke Wood
---
 examples/layers/preprocessing/aug_mix_demo.py |  35 ++
 keras_cv/layers/__init__.py                   |   1 +
 keras_cv/layers/preprocessing/__init__.py     |   1 +
 keras_cv/layers/preprocessing/aug_mix.py      | 313 ++++++++++++++++++
 keras_cv/layers/preprocessing/aug_mix_test.py |  79 +++++
 keras_cv/layers/preprocessing/fourier_mix.py  |   2 +-
 keras_cv/layers/preprocessing/mix_up.py       |   2 +-
 keras_cv/layers/serialization_test.py         |  12 +
 keras_cv/utils/preprocessing.py               |  73 ++++
 9 files changed, 516 insertions(+), 2 deletions(-)
 create mode 100644 examples/layers/preprocessing/aug_mix_demo.py
 create mode 100644 keras_cv/layers/preprocessing/aug_mix.py
 create mode 100644 keras_cv/layers/preprocessing/aug_mix_test.py

diff --git a/examples/layers/preprocessing/aug_mix_demo.py b/examples/layers/preprocessing/aug_mix_demo.py
new file mode 100644
index 0000000000..627fd9b7c3
--- /dev/null
+++ b/examples/layers/preprocessing/aug_mix_demo.py
@@ -0,0 +1,35 @@
+# Copyright 2022 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""aug_mix_demo.py shows how to use the AugMix preprocessing layer.
+
+It loads the oxford_flowers102 dataset, passes the flowers through the
+AugMix preprocessing layer, and finally visualizes them using matplotlib.
+"""
+
+import demo_utils
+import tensorflow as tf
+
+from keras_cv import layers
+
+
+def main():
+    augmix = layers.AugMix(value_range=[0, 255])
+    ds = demo_utils.load_oxford_dataset()
+    ds = ds.map(augmix, num_parallel_calls=tf.data.AUTOTUNE)
+    demo_utils.visualize_dataset(ds)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/keras_cv/layers/__init__.py b/keras_cv/layers/__init__.py
index 27a848d777..3d0cf62286 100644
--- a/keras_cv/layers/__init__.py
+++ b/keras_cv/layers/__init__.py
@@ -26,6 +26,7 @@
 from tensorflow.keras.layers import Rescaling
 from tensorflow.keras.layers import Resizing
 
+from keras_cv.layers.preprocessing.aug_mix import AugMix
 from keras_cv.layers.preprocessing.auto_contrast import AutoContrast
 from keras_cv.layers.preprocessing.channel_shuffle import ChannelShuffle
 from keras_cv.layers.preprocessing.cut_mix import CutMix
diff --git a/keras_cv/layers/preprocessing/__init__.py b/keras_cv/layers/preprocessing/__init__.py
index 1985336677..e852ef62f8 100644
--- a/keras_cv/layers/preprocessing/__init__.py
+++ b/keras_cv/layers/preprocessing/__init__.py
@@ -29,6 +29,7 @@
 from tensorflow.keras.layers import Rescaling
 from tensorflow.keras.layers import Resizing
 
+from keras_cv.layers.preprocessing.aug_mix import AugMix
 from keras_cv.layers.preprocessing.auto_contrast import AutoContrast
 from keras_cv.layers.preprocessing.channel_shuffle import ChannelShuffle
 from keras_cv.layers.preprocessing.cut_mix import CutMix
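A note on the math implemented by the new layer below: AugMix mixes several augmented copies of each image with weights drawn from a Dirichlet distribution, then blends that mixture with the original image using a Beta-distributed weight. A minimal sketch of that mixing step, assuming only TensorFlow (the `augment_chain` helper is a hypothetical stand-in for the `chain_depth` random ops each chain applies, and the Beta sample uses the textbook gamma-ratio construction rather than the layer's exact `_sample_from_beta` parameterization):

```python
import tensorflow as tf

num_chains, alpha = 3, 1.0
image = tf.random.uniform((32, 32, 3), maxval=255.0)

# Dirichlet weights over the chains, via normalized gamma samples.
gamma = tf.random.gamma(shape=(), alpha=tf.ones([num_chains]) * alpha)
chain_weights = gamma / tf.reduce_sum(gamma)


def augment_chain(img):
    # Hypothetical stand-in for `chain_depth` randomly sampled ops.
    return tf.image.flip_left_right(img)


# Convex combination of the augmented copies.
mixture = tf.add_n([w * augment_chain(image) for w in tf.unstack(chain_weights)])

# Beta(alpha, alpha) sample blends the original image with the mixture.
a = tf.random.gamma((), alpha)
b = tf.random.gamma((), alpha)
skip_weight = a / (a + b)
result = skip_weight * image + (1.0 - skip_weight) * mixture
```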
diff --git a/keras_cv/layers/preprocessing/aug_mix.py b/keras_cv/layers/preprocessing/aug_mix.py
new file mode 100644
index 0000000000..4d81a11f38
--- /dev/null
+++ b/keras_cv/layers/preprocessing/aug_mix.py
@@ -0,0 +1,313 @@
+# Copyright 2022 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+
+from keras_cv import layers
+from keras_cv.utils import preprocessing
+
+
+@tf.keras.utils.register_keras_serializable(package="keras_cv")
+class AugMix(tf.keras.__internal__.layers.BaseImageAugmentationLayer):
+    """Performs the AugMix data augmentation technique.
+
+    AugMix aims to produce images with variety while preserving the image
+    semantics and local statistics. During the augmentation process, each
+    image is augmented `num_chains` different ways, each way consisting of
+    `chain_depth` augmentations. Augmentations are sampled from the list:
+    translation, shearing, rotation, posterization, histogram equalization,
+    solarization and auto contrast. The results of each chain are then mixed
+    together with the original image based on random samples from a
+    Dirichlet distribution.
+
+    Args:
+        value_range: the range of values the incoming images will have.
+            Represented as a two-number tuple written (low, high).
+            This is typically either `(0, 1)` or `(0, 255)` depending
+            on how your preprocessing pipeline is set up.
+        severity: a tuple of two floats, a single float, or a
+            `keras_cv.FactorSampler`. A value is sampled from the provided
+            range. If a float is passed, the range is interpreted as
+            `(0, severity)`. This value controls the strength of the
+            augmentations and lies in the range [0, 1]. Defaults to 0.3.
+        num_chains: an integer representing the number of different chains
+            to be mixed. Defaults to 3.
+        chain_depth: an integer or range representing the number of
+            transformations in each chain. Defaults to `[1, 3]`.
+        alpha: a float value used as the concentration parameter for the
+            Beta and Dirichlet distributions. Defaults to 1.0.
+        seed: Integer. Used to create a random seed.
+
+    References:
+        [AugMix paper](https://arxiv.org/pdf/1912.02781)
+        [Official Code](https://github.com/google-research/augmix)
+        [Unofficial TF Code](https://github.com/szacho/augmix-tf)
+
+    Sample Usage:
+    ```python
+    (images, labels), _ = tf.keras.datasets.cifar10.load_data()
+    augmix = keras_cv.layers.AugMix([0, 255])
+    augmented_images = augmix(images[:100])
+    ```
+    """
+
+    def __init__(
+        self,
+        value_range,
+        severity=0.3,
+        num_chains=3,
+        chain_depth=[1, 3],
+        alpha=1.0,
+        seed=None,
+        **kwargs,
+    ):
+        super().__init__(seed=seed, **kwargs)
+        self.value_range = value_range
+        self.num_chains = num_chains
+        self.chain_depth = chain_depth
+
+        if isinstance(self.chain_depth, int):
+            self.chain_depth = [self.chain_depth, self.chain_depth]
+
+        self.alpha = alpha
+        self.seed = seed
+        self.auto_vectorize = False
+        self.severity = severity
+        self.severity_factor = preprocessing.parse_factor(
+            self.severity,
+            min_value=0.01,
+            max_value=1.0,
+            param_name="severity",
+            seed=self.seed,
+        )
+
+        # initialize layers
+        self.auto_contrast = layers.AutoContrast(value_range=self.value_range)
+        self.equalize = layers.Equalization(value_range=self.value_range)
+
+    @staticmethod
+    def _sample_from_dirichlet(alpha):
+        # Normalized gamma samples follow a Dirichlet distribution.
+        gamma_sample = tf.random.gamma(shape=(), alpha=alpha)
+        return gamma_sample / tf.reduce_sum(gamma_sample, axis=-1, keepdims=True)
+
+    @staticmethod
+    def _sample_from_beta(alpha, beta):
+        sample_alpha = tf.random.gamma((), 1.0, beta=alpha)
+        sample_beta = tf.random.gamma((), 1.0, beta=beta)
+        return sample_alpha / (sample_alpha + sample_beta)
+
+    def _sample_depth(self):
+        return self._random_generator.random_uniform(
+            shape=(),
+            minval=self.chain_depth[0],
+            maxval=self.chain_depth[1] + 1,
+            dtype=tf.int32,
+        )
+
+    def _loop_on_depth(self, depth_level, image_aug):
+        # Sample one of the nine ops dispatched in `_apply_op` (indices 0-8).
+        op_index = self._random_generator.random_uniform(
+            shape=(), minval=0, maxval=9, dtype=tf.int32
+        )
+        image_aug = self._apply_op(image_aug, op_index)
+        depth_level += 1
+        return depth_level, image_aug
+
+    def _loop_on_width(self, image, chain_mixing_weights, curr_chain, result):
+        image_aug = tf.identity(image)
+        chain_depth = self._sample_depth()
+
+        depth_level = tf.constant([0], dtype=tf.int32)
+        depth_level, image_aug = tf.while_loop(
+            lambda depth_level, image_aug: tf.less(depth_level, chain_depth),
+            self._loop_on_depth,
+            [depth_level, image_aug],
+        )
+        result += tf.gather(chain_mixing_weights, curr_chain) * image_aug
+        curr_chain += 1
+        return image, chain_mixing_weights, curr_chain, result
+
+    def _auto_contrast(self, image):
+        return self.auto_contrast(image)
+
+    def _equalize(self, image):
+        return self.equalize(image)
+
+    def _posterize(self, image):
+        image = preprocessing.transform_value_range(
+            images=image,
+            original_range=self.value_range,
+            target_range=[0, 255],
+        )
+
+        bits = tf.cast(self.severity_factor() * 3, tf.int32)
+        shift = tf.cast(4 - bits + 1, tf.uint8)
+        image = tf.cast(image, tf.uint8)
+        image = tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)
+        image = tf.cast(image, self.compute_dtype)
+
+        return preprocessing.transform_value_range(
+            images=image,
+            original_range=[0, 255],
+            target_range=self.value_range,
+        )
+
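+    # Each geometric op below scales its magnitude by `severity_factor()`,
+    # following the ranges used by the AugMix paper: rotation up to 30
+    # degrees, shear up to 0.3, and translation up to a third of the image.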
+    def _rotate(self, image):
+        # The severity factor scales the rotation up to 30 degrees. Convert
+        # degrees to radians, since `get_rotation_matrix` applies
+        # trigonometric functions to the raw angle.
+        degrees = tf.cast(self.severity_factor() * 30, tf.float32)
+        angle = tf.expand_dims(degrees * (3.14159265 / 180.0), axis=0)
+        shape = tf.cast(tf.shape(image), tf.float32)
+
+        return preprocessing.transform(
+            tf.expand_dims(image, 0),
+            preprocessing.get_rotation_matrix(angle, shape[0], shape[1]),
+        )[0]
+
+    def _solarize(self, image):
+        threshold = tf.cast(
+            tf.cast(self.severity_factor() * 255, tf.int32), tf.float32
+        )
+
+        image = preprocessing.transform_value_range(
+            image, original_range=self.value_range, target_range=(0, 255)
+        )
+        result = tf.clip_by_value(image, 0, 255)
+        result = tf.where(result < threshold, result, 255 - result)
+        return preprocessing.transform_value_range(
+            result, original_range=(0, 255), target_range=self.value_range
+        )
+
+    def _shear_x(self, image):
+        x = tf.cast(self.severity_factor() * 0.3, tf.float32)
+        x *= preprocessing.random_inversion(self._random_generator)
+        transform_x = layers.RandomShear._format_transform(
+            [1.0, x, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
+        )
+        return preprocessing.transform(
+            images=tf.expand_dims(image, 0), transforms=transform_x
+        )[0]
+
+    def _shear_y(self, image):
+        y = tf.cast(self.severity_factor() * 0.3, tf.float32)
+        y *= preprocessing.random_inversion(self._random_generator)
+        transform_y = layers.RandomShear._format_transform(
+            [1.0, 0.0, 0.0, y, 1.0, 0.0, 0.0, 0.0]
+        )
+        return preprocessing.transform(
+            images=tf.expand_dims(image, 0), transforms=transform_y
+        )[0]
+
+    def _translate_x(self, image):
+        shape = tf.cast(tf.shape(image), tf.float32)
+        x = tf.cast(self.severity_factor() * shape[1] / 3, tf.float32)
+        x = tf.expand_dims(tf.expand_dims(x, axis=0), axis=0)
+        x *= preprocessing.random_inversion(self._random_generator)
+        x = tf.cast(x, tf.int32)
+
+        translations = tf.cast(
+            tf.concat([x, tf.zeros_like(x)], axis=1), dtype=tf.float32
+        )
+        return preprocessing.transform(
+            tf.expand_dims(image, 0),
+            preprocessing.get_translation_matrix(translations),
+        )[0]
+
+    def _translate_y(self, image):
+        shape = tf.cast(tf.shape(image), tf.float32)
+        y = tf.cast(self.severity_factor() * shape[0] / 3, tf.float32)
+        y = tf.expand_dims(tf.expand_dims(y, axis=0), axis=0)
+        y *= preprocessing.random_inversion(self._random_generator)
+        y = tf.cast(y, tf.int32)
+
+        translations = tf.cast(
+            tf.concat([tf.zeros_like(y), y], axis=1), dtype=tf.float32
+        )
+        return preprocessing.transform(
+            tf.expand_dims(image, 0),
+            preprocessing.get_translation_matrix(translations),
+        )[0]
+
+    def _apply_op(self, image, op_index):
+        augmented = image
+        augmented = tf.cond(
+            op_index == tf.constant([0], dtype=tf.int32),
+            lambda: self._auto_contrast(augmented),
+            lambda: augmented,
+        )
+        augmented = tf.cond(
+            op_index == tf.constant([1], dtype=tf.int32),
+            lambda: self._equalize(augmented),
+            lambda: augmented,
+        )
+        augmented = tf.cond(
+            op_index == tf.constant([2], dtype=tf.int32),
+            lambda: self._posterize(augmented),
+            lambda: augmented,
+        )
+        augmented = tf.cond(
+            op_index == tf.constant([3], dtype=tf.int32),
+            lambda: self._rotate(augmented),
+            lambda: augmented,
+        )
+        augmented = tf.cond(
+            op_index == tf.constant([4], dtype=tf.int32),
+            lambda: self._solarize(augmented),
+            lambda: augmented,
+        )
+        augmented = tf.cond(
+            op_index == tf.constant([5], dtype=tf.int32),
+            lambda: self._shear_x(augmented),
+            lambda: augmented,
+        )
+        augmented = tf.cond(
+            op_index == tf.constant([6], dtype=tf.int32),
+            lambda: self._shear_y(augmented),
+            lambda: 
augmented, + ) + augmented = tf.cond( + op_index == tf.constant([7], dtype=tf.int32), + lambda: self._translate_x(augmented), + lambda: augmented, + ) + augmented = tf.cond( + op_index == tf.constant([8], dtype=tf.int32), + lambda: self._translate_y(augmented), + lambda: augmented, + ) + return augmented + + def augment_image(self, image, transformation=None): + chain_mixing_weights = AugMix._sample_from_dirichlet( + tf.ones([self.num_chains]) * self.alpha + ) + weight_sample = AugMix._sample_from_beta(self.alpha, self.alpha) + + result = tf.zeros_like(image) + curr_chain = tf.constant([0], dtype=tf.int32) + + image, chain_mixing_weights, curr_chain, result = tf.while_loop( + lambda image, chain_mixing_weights, curr_chain, result: tf.less( + curr_chain, self.num_chains + ), + self._loop_on_width, + [image, chain_mixing_weights, curr_chain, result], + ) + result = weight_sample * image + (1 - weight_sample) * result + return result + + def augment_label(self, label, transformation=None): + return label + + def get_config(self): + config = { + "value_range": self.value_range, + "severity": self.severity, + "num_chains": self.num_chains, + "chain_depth": self.chain_depth, + "alpha": self.alpha, + "seed": self.seed, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_cv/layers/preprocessing/aug_mix_test.py b/keras_cv/layers/preprocessing/aug_mix_test.py new file mode 100644 index 0000000000..bbeda192c1 --- /dev/null +++ b/keras_cv/layers/preprocessing/aug_mix_test.py @@ -0,0 +1,79 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tensorflow as tf + +from keras_cv.layers import preprocessing + + +class AugMixTest(tf.test.TestCase): + def test_return_shapes(self): + layer = preprocessing.AugMix([0, 255]) + + # RGB + xs = tf.ones((2, 512, 512, 3)) + xs = layer(xs) + self.assertEqual(xs.shape, [2, 512, 512, 3]) + + # greyscale + xs = tf.ones((2, 512, 512, 1)) + xs = layer(xs) + self.assertEqual(xs.shape, [2, 512, 512, 1]) + + def test_in_single_image(self): + layer = preprocessing.AugMix([0, 255]) + + # RGB + xs = tf.cast( + tf.ones((512, 512, 3)), + dtype=tf.float32, + ) + + xs = layer(xs) + self.assertEqual(xs.shape, [512, 512, 3]) + + # greyscale + xs = tf.cast( + tf.ones((512, 512, 1)), + dtype=tf.float32, + ) + + xs = layer(xs) + self.assertEqual(xs.shape, [512, 512, 1]) + + def test_non_square_images(self): + layer = preprocessing.AugMix([0, 255]) + + # RGB + xs = tf.ones((2, 256, 512, 3)) + xs = layer(xs) + self.assertEqual(xs.shape, [2, 256, 512, 3]) + + # greyscale + xs = tf.ones((2, 256, 512, 1)) + xs = layer(xs) + self.assertEqual(xs.shape, [2, 256, 512, 1]) + + def test_single_input_args(self): + layer = preprocessing.AugMix([0, 255]) + + # RGB + xs = tf.ones((2, 512, 512, 3)) + xs = layer(xs) + self.assertEqual(xs.shape, [2, 512, 512, 3]) + + # greyscale + xs = tf.ones((2, 512, 512, 1)) + xs = layer(xs) + self.assertEqual(xs.shape, [2, 512, 512, 1]) diff --git a/keras_cv/layers/preprocessing/fourier_mix.py b/keras_cv/layers/preprocessing/fourier_mix.py index d64611e897..38b2e2842e 100644 --- a/keras_cv/layers/preprocessing/fourier_mix.py +++ b/keras_cv/layers/preprocessing/fourier_mix.py @@ -36,7 +36,7 @@ class FourierMix(BaseImageAugmentationLayer): Sample usage: ```python (images, labels), _ = tf.keras.datasets.cifar10.load_data() - fourier_mix = keras_cv.layers.preprocessing.mix_up.FourierMix(0.5) + fourier_mix = keras_cv.layers.preprocessing.FourierMix(0.5) augmented_images, updated_labels = fourier_mix({'images': images, 'labels': labels}) # output == {'images': updated_images, 'labels': updated_labels} ``` diff --git a/keras_cv/layers/preprocessing/mix_up.py b/keras_cv/layers/preprocessing/mix_up.py index 451f5b465c..f4314e8f69 100644 --- a/keras_cv/layers/preprocessing/mix_up.py +++ b/keras_cv/layers/preprocessing/mix_up.py @@ -36,7 +36,7 @@ class MixUp(BaseImageAugmentationLayer): Sample usage: ```python (images, labels), _ = tf.keras.datasets.cifar10.load_data() - mixup = keras_cv.layers.preprocessing.mix_up.MixUp(10) + mixup = keras_cv.layers.preprocessing.MixUp(10) augmented_images, updated_labels = mixup({'images': images, 'labels': labels}) # output == {'images': updated_images, 'labels': updated_labels} ``` diff --git a/keras_cv/layers/serialization_test.py b/keras_cv/layers/serialization_test.py index d2221d514a..0bc9b294d5 100644 --- a/keras_cv/layers/serialization_test.py +++ b/keras_cv/layers/serialization_test.py @@ -136,6 +136,18 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase): "seed": 1234, }, ), + ( + "AugMix", + preprocessing.AugMix, + { + "value_range": (0, 255), + "severity": 0.3, + "num_chains": 3, + "chain_depth": -1, + "alpha": 1.0, + "seed": 1, + }, + ), ) def test_layer_serialization(self, layer_cls, init_args): layer = layer_cls(**init_args) diff --git a/keras_cv/utils/preprocessing.py b/keras_cv/utils/preprocessing.py index cca8346313..13a8b30cd9 100644 --- a/keras_cv/utils/preprocessing.py +++ b/keras_cv/utils/preprocessing.py @@ -137,6 +137,79 @@ def random_inversion(random_generator): return negate +def get_rotation_matrix(angles, 
image_height, image_width, name=None): + """Returns projective transform(s) for the given angle(s). + Args: + angles: A scalar angle to rotate all images by, or (for batches of images) a + vector with an angle to rotate each image in the batch. The rank must be + statically known (the shape is not `TensorShape(None)`). + image_height: Height of the image(s) to be transformed. + image_width: Width of the image(s) to be transformed. + name: The name of the op. + Returns: + A tensor of shape (num_images, 8). Projective transforms which can be given + to operation `image_projective_transform_v2`. If one row of transforms is + [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point + `(x, y)` to a transformed *input* point + `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, + where `k = c0 x + c1 y + 1`. + """ + with backend.name_scope(name or "rotation_matrix"): + x_offset = ( + (image_width - 1) + - (tf.cos(angles) * (image_width - 1) - tf.sin(angles) * (image_height - 1)) + ) / 2.0 + y_offset = ( + (image_height - 1) + - (tf.sin(angles) * (image_width - 1) + tf.cos(angles) * (image_height - 1)) + ) / 2.0 + num_angles = tf.shape(angles)[0] + return tf.concat( + values=[ + tf.cos(angles)[:, None], + -tf.sin(angles)[:, None], + x_offset[:, None], + tf.sin(angles)[:, None], + tf.cos(angles)[:, None], + y_offset[:, None], + tf.zeros((num_angles, 2), tf.float32), + ], + axis=1, + ) + + +def get_translation_matrix(translations, name=None): + """Returns projective transform(s) for the given translation(s). + Args: + translations: A matrix of 2-element lists representing `[dx, dy]` + to translate for each image (for a batch of images). + name: The name of the op. + Returns: + A tensor of shape `(num_images, 8)` projective transforms which can be given + to `transform`. + """ + with backend.name_scope(name or "translation_matrix"): + num_translations = tf.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # Translation matrices are always float32. + return tf.concat( + values=[ + tf.ones((num_translations, 1), tf.float32), + tf.zeros((num_translations, 1), tf.float32), + -translations[:, 0, None], + tf.zeros((num_translations, 1), tf.float32), + tf.ones((num_translations, 1), tf.float32), + -translations[:, 1, None], + tf.zeros((num_translations, 2), tf.float32), + ], + axis=1, + ) + + def transform( images, transforms,
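A closing note on the transform encoding used throughout these utilities: below is a minimal sketch, assuming only TensorFlow, of the 8-vector that `get_translation_matrix` builds. For a translation `[dx, dy]` the vector is `[1, 0, -dx, 0, 1, -dy, 0, 0]`, which maps each *output* pixel `(x, y)` back to the *input* pixel `(x - dx, y - dy)` and therefore shifts the image content by `(dx, dy)`:

```python
import tensorflow as tf

# One image, translated by dx=5, dy=3.
translations = tf.constant([[5.0, 3.0]])
num_translations = tf.shape(translations)[0]

# Mirrors the concat in `get_translation_matrix` above.
matrix = tf.concat(
    [
        tf.ones((num_translations, 1), tf.float32),   # a0 = 1
        tf.zeros((num_translations, 1), tf.float32),  # a1 = 0
        -translations[:, 0, None],                    # a2 = -dx
        tf.zeros((num_translations, 1), tf.float32),  # b0 = 0
        tf.ones((num_translations, 1), tf.float32),   # b1 = 1
        -translations[:, 1, None],                    # b2 = -dy
        tf.zeros((num_translations, 2), tf.float32),  # c0 = c1 = 0
    ],
    axis=1,
)
# matrix == [[1., 0., -5., 0., 1., -3., 0., 0.]]
# An output pixel (x, y) samples the input pixel
# ((a0*x + a1*y + a2) / k, (b0*x + b1*y + b2) / k) with k = c0*x + c1*y + 1,
# i.e. (x - 5, y - 3) here.
```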