forked from keras-team/keras-cv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* adding augmix * add augmix with demo * add augmix test * add augmix * Update aug_mix.py * Update aug_mix.py * Update aug_mix_demo.py * updating augmix with changes * init commit with review comments * fix linting issues * removing RandomShear assignment in constructor * review comments part 2 * review comments part 2 fix * review comments part 3 * change variable names * fix lint issues * edit serialization test * edit demo to use utils * Update augmix Co-authored-by: Luke Wood <[email protected]>
- Loading branch information
1 parent
3a2cd12
commit 2736233
Showing
9 changed files
with
516 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright 2022 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""aug_mix_demo.py shows how to use the AugMix preprocessing layer. | ||
Operates on the oxford_flowers102 dataset. In this script the flowers | ||
are loaded, then are passed through the preprocessing layers. | ||
Finally, they are shown using matplotlib. | ||
""" | ||
|
||
import demo_utils | ||
import tensorflow as tf | ||
|
||
from keras_cv import layers | ||
|
||
|
||
def main():
    """Runs AugMix over the Oxford flowers dataset and displays the result."""
    mixer = layers.AugMix(value_range=[0, 255])
    flowers = demo_utils.load_oxford_dataset()
    # Augment the dataset lazily, letting tf.data choose the parallelism.
    augmented = flowers.map(mixer, num_parallel_calls=tf.data.AUTOTUNE)
    demo_utils.visualize_dataset(augmented)


# Only run the demo when this file is executed as a script.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,313 @@ | ||
# Copyright 2022 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import tensorflow as tf | ||
|
||
from keras_cv import layers | ||
from keras_cv.utils import preprocessing | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable(package="keras_cv")
class AugMix(tf.keras.__internal__.layers.BaseImageAugmentationLayer):
    """Performs the AugMix data augmentation technique.

    AugMix aims to produce images with variety while preserving the
    image semantics and local statistics. During the augmentation process, each
    image is augmented `num_chains` different ways, each way consisting of
    `chain_depth` augmentations. Augmentations are sampled from the list:
    translation, shearing, rotation, posterization, histogram equalization,
    solarization and auto contrast. The results of each chain are then mixed
    together with the original image based on random samples from a Dirichlet
    distribution.

    Args:
        value_range: the range of values the incoming images will have.
            Represented as a two number tuple written (low, high).
            This is typically either `(0, 1)` or `(0, 255)` depending
            on how your preprocessing pipeline is set up.
        severity: A tuple of two floats, a single float or a
            `keras_cv.FactorSampler`. A value is sampled from the provided
            range. If a float is passed, the range is interpreted as
            `(0, severity)`. This value represents the level of strength of
            augmentations and is in the range [0, 1]. Defaults to 0.3.
        num_chains: an integer representing the number of different chains to
            be mixed. Defaults to 3.
        chain_depth: an integer or a `[min, max]` range (both ends inclusive)
            representing the number of transformations in the chains. An
            integer `d` is interpreted as `[d, d]`. Defaults to [1, 3].
        alpha: a float value used as the concentration parameter of the
            Beta and Dirichlet distributions. Defaults to 1.0.
        seed: Integer. Used to create a random seed.

    References:
        [AugMix paper](https://arxiv.org/pdf/1912.02781)
        [Official Code](https://github.com/google-research/augmix)
        [Unofficial TF Code](https://github.com/szacho/augmix-tf)

    Sample Usage:
    ```python
    (images, labels), _ = tf.keras.datasets.cifar10.load_data()
    augmix = keras_cv.layers.AugMix([0, 255])
    augmented_images = augmix(images[:100])
    ```
    """

    def __init__(
        self,
        value_range,
        severity=0.3,
        num_chains=3,
        chain_depth=[1, 3],
        alpha=1.0,
        seed=None,
        **kwargs,
    ):
        super().__init__(seed=seed, **kwargs)
        self.value_range = value_range
        self.num_chains = num_chains

        if isinstance(chain_depth, int):
            chain_depth = [chain_depth, chain_depth]
        elif isinstance(chain_depth, list):
            # Copy so that the shared default list (kept in the signature for
            # backward compatibility) is never aliased by an instance.
            chain_depth = list(chain_depth)
        self.chain_depth = chain_depth

        self.alpha = alpha
        self.seed = seed
        # Images are processed one at a time through `augment_image`.
        self.auto_vectorize = False
        self.severity = severity
        self.severity_factor = preprocessing.parse_factor(
            self.severity,
            min_value=0.01,
            max_value=1.0,
            param_name="severity",
            seed=self.seed,
        )

        # initialize layers
        self.auto_contrast = layers.AutoContrast(value_range=self.value_range)
        self.equalize = layers.Equalization(value_range=self.value_range)

    @staticmethod
    def _sample_from_dirichlet(alpha):
        """Samples Dirichlet(alpha) weights by normalizing gamma samples.

        Args:
            alpha: tensor of concentration parameters, one per component.

        Returns:
            A tensor shaped like `alpha` whose last axis sums to 1.
        """
        # NOTE(review): this draw does not use the layer's `seed`.
        gamma_sample = tf.random.gamma(shape=(), alpha=alpha)
        return gamma_sample / tf.reduce_sum(gamma_sample, axis=-1, keepdims=True)

    @staticmethod
    def _sample_from_beta(alpha, beta):
        """Samples Beta(alpha, beta) as X / (X + Y) from two gamma samples.

        X ~ Gamma(alpha) and Y ~ Gamma(beta), where `alpha` and `beta` are the
        gamma *concentration* parameters. Passing them as the gamma rate
        instead would yield a uniform sample whenever `alpha == beta`.
        """
        # NOTE(review): these draws do not use the layer's `seed`.
        sample_alpha = tf.random.gamma((), alpha=alpha)
        sample_beta = tf.random.gamma((), alpha=beta)
        return sample_alpha / (sample_alpha + sample_beta)

    def _sample_depth(self):
        """Samples the number of ops for one chain from `chain_depth`."""
        return self._random_generator.random_uniform(
            shape=(),
            minval=self.chain_depth[0],
            # The integer upper bound is exclusive; +1 makes the range inclusive.
            maxval=self.chain_depth[1] + 1,
            dtype=tf.int32,
        )

    def _augmentation_ops(self):
        """Returns the candidate ops, ordered by the index that selects them."""
        return (
            self._auto_contrast,
            self._equalize,
            self._posterize,
            self._rotate,
            self._solarize,
            self._shear_x,
            self._shear_y,
            self._translate_x,
            self._translate_y,
        )

    def _loop_on_depth(self, depth_level, image_aug):
        """`tf.while_loop` body: applies one randomly chosen op to the image."""
        op_index = self._random_generator.random_uniform(
            shape=(),
            minval=0,
            # Exclusive upper bound: must equal the number of ops so that the
            # last op (`_translate_y`) can actually be selected.
            maxval=len(self._augmentation_ops()),
            dtype=tf.int32,
        )
        image_aug = self._apply_op(image_aug, op_index)
        depth_level += 1
        return depth_level, image_aug

    def _loop_on_width(self, image, chain_mixing_weights, curr_chain, result):
        """`tf.while_loop` body: runs one chain and adds it to `result`.

        The chain starts from a copy of `image`, applies a sampled number of
        random ops, and is accumulated into `result` scaled by the mixing
        weight of the current chain.
        """
        image_aug = tf.identity(image)
        chain_depth = self._sample_depth()

        depth_level = tf.constant([0], dtype=tf.int32)
        depth_level, image_aug = tf.while_loop(
            lambda depth_level, image_aug: tf.less(depth_level, chain_depth),
            self._loop_on_depth,
            [depth_level, image_aug],
        )
        result += tf.gather(chain_mixing_weights, curr_chain) * image_aug
        curr_chain += 1
        return image, chain_mixing_weights, curr_chain, result

    @staticmethod
    def _transform_single(image, transforms):
        """Applies a projective transform to one image via a batch of size 1."""
        return preprocessing.transform(
            images=tf.expand_dims(image, 0), transforms=transforms
        )[0]

    def _auto_contrast(self, image):
        """Applies the `AutoContrast` layer built in the constructor."""
        return self.auto_contrast(image)

    def _equalize(self, image):
        """Applies the `Equalization` layer built in the constructor."""
        return self.equalize(image)

    def _posterize(self, image):
        """Zeroes the low-order bits of each pixel, computed in [0, 255]."""
        image = preprocessing.transform_value_range(
            images=image,
            original_range=self.value_range,
            target_range=[0, 255],
        )

        bits = tf.cast(self.severity_factor() * 3, tf.int32)
        shift = tf.cast(4 - bits + 1, tf.uint8)
        image = tf.cast(image, tf.uint8)
        # Shifting right then left by `shift` clears the lowest `shift` bits.
        image = tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)
        image = tf.cast(image, self.compute_dtype)
        return preprocessing.transform_value_range(
            images=image,
            original_range=[0, 255],
            target_range=self.value_range,
        )

    def _rotate(self, image):
        """Rotates the image by an angle of `severity * 30`."""
        # NOTE(review): confirm which unit `get_rotation_matrix` expects; if it
        # is radians, `severity * 30` is much larger than a 30 degree rotation.
        angle = tf.expand_dims(tf.cast(self.severity_factor() * 30, tf.float32), axis=0)
        shape = tf.cast(tf.shape(image), tf.float32)
        return self._transform_single(
            image, preprocessing.get_rotation_matrix(angle, shape[0], shape[1])
        )

    def _solarize(self, image):
        """Inverts every pixel at or above a severity-based threshold."""
        threshold = tf.cast(tf.cast(self.severity_factor() * 255, tf.int32), tf.float32)

        image = preprocessing.transform_value_range(
            image, original_range=self.value_range, target_range=(0, 255)
        )
        result = tf.clip_by_value(image, 0, 255)
        result = tf.where(result < threshold, result, 255 - result)
        return preprocessing.transform_value_range(
            result, original_range=(0, 255), target_range=self.value_range
        )

    def _shear_x(self, image):
        """Shears the image by `severity * 0.3` via the x-shear coefficient."""
        x = tf.cast(self.severity_factor() * 0.3, tf.float32)
        # presumably `random_inversion` yields +1 or -1 to randomize the direction
        x *= preprocessing.random_inversion(self._random_generator)
        transform_x = layers.RandomShear._format_transform(
            [1.0, x, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
        )
        return self._transform_single(image, transform_x)

    def _shear_y(self, image):
        """Shears the image by `severity * 0.3` via the y-shear coefficient."""
        y = tf.cast(self.severity_factor() * 0.3, tf.float32)
        # presumably `random_inversion` yields +1 or -1 to randomize the direction
        y *= preprocessing.random_inversion(self._random_generator)
        transform_y = layers.RandomShear._format_transform(
            [1.0, 0.0, 0.0, y, 1.0, 0.0, 0.0, 0.0]
        )
        return self._transform_single(image, transform_y)

    def _translate_x(self, image):
        """Translates along x by up to a third of `shape[1]`, scaled by severity."""
        # assumes `shape[1]` is the image width (HWC layout) - TODO confirm
        shape = tf.cast(tf.shape(image), tf.float32)
        x = tf.cast(self.severity_factor() * shape[1] / 3, tf.float32)
        x = tf.expand_dims(tf.expand_dims(x, axis=0), axis=0)
        x *= preprocessing.random_inversion(self._random_generator)
        # Truncate to a whole number of pixels.
        x = tf.cast(x, tf.int32)

        translations = tf.cast(
            tf.concat([x, tf.zeros_like(x)], axis=1), dtype=tf.float32
        )
        return self._transform_single(
            image, preprocessing.get_translation_matrix(translations)
        )

    def _translate_y(self, image):
        """Translates along y by up to a third of `shape[0]`, scaled by severity."""
        # assumes `shape[0]` is the image height (HWC layout) - TODO confirm
        shape = tf.cast(tf.shape(image), tf.float32)
        y = tf.cast(self.severity_factor() * shape[0] / 3, tf.float32)
        y = tf.expand_dims(tf.expand_dims(y, axis=0), axis=0)
        y *= preprocessing.random_inversion(self._random_generator)
        # Truncate to a whole number of pixels.
        y = tf.cast(y, tf.int32)

        translations = tf.cast(
            tf.concat([tf.zeros_like(y), y], axis=1), dtype=tf.float32
        )
        return self._transform_single(
            image, preprocessing.get_translation_matrix(translations)
        )

    @staticmethod
    def _apply_if_selected(image, op_index, index, op):
        """Returns `op(image)` when `op_index == index`, else `image` unchanged."""
        return tf.cond(
            op_index == tf.constant([index], dtype=tf.int32),
            lambda: op(image),
            lambda: image,
        )

    def _apply_op(self, image, op_index):
        """Applies the op selected by `op_index` to `image`.

        Builds one `tf.cond` per candidate op; only the op whose position in
        `_augmentation_ops()` equals `op_index` changes the image.
        """
        augmented = image
        for index, op in enumerate(self._augmentation_ops()):
            augmented = self._apply_if_selected(augmented, op_index, index, op)
        return augmented

    def augment_image(self, image, transformation=None):
        """Mixes `num_chains` augmented copies of `image` with the original.

        Args:
            image: a single image tensor.
            transformation: unused.

        Returns:
            `w * image + (1 - w) * mix`, where `mix` is the Dirichlet-weighted
            sum of the augmentation chains and `w` is a Beta sample.
        """
        chain_mixing_weights = AugMix._sample_from_dirichlet(
            tf.ones([self.num_chains]) * self.alpha
        )
        weight_sample = AugMix._sample_from_beta(self.alpha, self.alpha)

        result = tf.zeros_like(image)
        curr_chain = tf.constant([0], dtype=tf.int32)

        image, chain_mixing_weights, curr_chain, result = tf.while_loop(
            lambda image, chain_mixing_weights, curr_chain, result: tf.less(
                curr_chain, self.num_chains
            ),
            self._loop_on_width,
            [image, chain_mixing_weights, curr_chain, result],
        )
        result = weight_sample * image + (1 - weight_sample) * result
        return result

    def augment_label(self, label, transformation=None):
        """Returns `label` unchanged."""
        return label

    def get_config(self):
        """Returns the layer config: the base config plus constructor arguments."""
        config = {
            "value_range": self.value_range,
            "severity": self.severity,
            "num_chains": self.num_chains,
            "chain_depth": self.chain_depth,
            "alpha": self.alpha,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}
Oops, something went wrong.