diff --git a/keras_cv/layers/preprocessing/__init__.py b/keras_cv/layers/preprocessing/__init__.py index 4eb2b5868d..1985336677 100644 --- a/keras_cv/layers/preprocessing/__init__.py +++ b/keras_cv/layers/preprocessing/__init__.py @@ -36,6 +36,7 @@ from keras_cv.layers.preprocessing.fourier_mix import FourierMix from keras_cv.layers.preprocessing.grayscale import Grayscale from keras_cv.layers.preprocessing.grid_mask import GridMask +from keras_cv.layers.preprocessing.maybe_apply import MaybeApply from keras_cv.layers.preprocessing.mix_up import MixUp from keras_cv.layers.preprocessing.posterization import Posterization from keras_cv.layers.preprocessing.rand_augment import RandAugment diff --git a/keras_cv/layers/preprocessing/maybe_apply.py b/keras_cv/layers/preprocessing/maybe_apply.py new file mode 100644 index 0000000000..2494f97402 --- /dev/null +++ b/keras_cv/layers/preprocessing/maybe_apply.py @@ -0,0 +1,108 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tensorflow as tf + + +@tf.keras.utils.register_keras_serializable(package="keras_cv") +class MaybeApply(tf.keras.__internal__.layers.BaseImageAugmentationLayer): + """Apply provided layer to random elements in a batch. + + Args: + layer: a keras `Layer` or `BaseImageAugmentationLayer`. This layer will be + applied to randomly chosen samples in a batch. Layer should not modify the + size of provided inputs. + rate: controls the frequency of applying the layer. 1.0 means all elements in + a batch will be modified. 0.0 means no elements will be modified. + Defaults to 0.5. + auto_vectorize: bool, whether to use tf.vectorized_map or tf.map_fn for + batched input. Setting this to True might give better performance but + currently doesn't work with XLA. Defaults to False. + seed: integer, controls random behaviour. + + Example usage: + ``` + # Let's declare an example layer that will set all image pixels to zero. + zero_out = tf.keras.layers.Lambda(lambda x: {"images": 0 * x["images"]}) + + # Create a small batch of random, single-channel, 2x2 images: + images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1]) + print(images[..., 0]) + # + + # Apply the layer with 50% probability: + maybe_apply = MaybeApply(layer=zero_out, rate=0.5, seed=1234) + outputs = maybe_apply(images) + print(outputs[..., 0]) + # + + # We can observe that the layer has been randomly applied to 2 out of 5 samples. + ``` + """ + + def __init__(self, layer, rate=0.5, auto_vectorize=False, seed=None, **kwargs): + super().__init__(seed=seed, **kwargs) + + if not (0 <= rate <= 1.0): + raise ValueError(f"rate must be in range [0, 1]. Received rate: {rate}") + + self._layer = layer + self._rate = rate + self.auto_vectorize = auto_vectorize + self.seed = seed + + def _augment(self, inputs): + if self._random_generator.random_uniform(shape=()) > 1.0 - self._rate: + return self._layer(inputs) + else: + return inputs + + def get_config(self): + config = super().get_config() + config.update( + { + "rate": self._rate, + "layer": self._layer, + "seed": self.seed, + "auto_vectorize": self.auto_vectorize, + } + ) + return config diff --git a/keras_cv/layers/preprocessing/maybe_apply_test.py b/keras_cv/layers/preprocessing/maybe_apply_test.py new file mode 100644 index 0000000000..e1bce47b82 --- /dev/null +++ b/keras_cv/layers/preprocessing/maybe_apply_test.py @@ -0,0 +1,125 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tensorflow as tf +from absl.testing import parameterized + +from keras_cv.layers.preprocessing.maybe_apply import MaybeApply + + +class ZeroOut(tf.keras.__internal__.layers.BaseImageAugmentationLayer): + """Zero out all entries, for testing purposes.""" + + def __init__(self): + super(ZeroOut, self).__init__() + + def augment_image(self, image, transformation=None): + return 0 * image + + def augment_label(self, label, transformation=None): + return 0 * label + + def augment_bounding_box(self, bounding_box, transformation=None): + return 0 * bounding_box + + +class MaybeApplyTest(tf.test.TestCase, parameterized.TestCase): + rng = tf.random.Generator.from_seed(seed=1234) + + @parameterized.parameters([-0.5, 1.7]) + def test_raises_error_on_invalid_rate_parameter(self, invalid_rate): + with self.assertRaises(ValueError): + MaybeApply(rate=invalid_rate, layer=ZeroOut()) + + def test_works_with_batched_input(self): + batch_size = 32 + dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3)) + layer = MaybeApply(rate=0.5, layer=ZeroOut(), seed=1234) + + outputs = layer(dummy_inputs) + num_zero_inputs = self._num_zero_batches(dummy_inputs) + num_zero_outputs = self._num_zero_batches(outputs) + + self.assertEqual(num_zero_inputs, 0) + self.assertLess(num_zero_outputs, batch_size) + self.assertGreater(num_zero_outputs, 0) + + @staticmethod + def _num_zero_batches(images): + num_batches = tf.shape(images)[0] + num_non_zero_batches = tf.math.count_nonzero( + tf.math.count_nonzero(images, axis=[1, 2, 3]), dtype=tf.int32 + ) + return num_batches - num_non_zero_batches + + def test_inputs_unchanged_with_zero_rate(self): + dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) + layer = MaybeApply(rate=0.0, layer=ZeroOut()) + + outputs = layer(dummy_inputs) + + self.assertAllClose(outputs, dummy_inputs) + + def test_all_inputs_changed_with_rate_equal_to_one(self): + dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) + layer = MaybeApply(rate=1.0, layer=ZeroOut()) + + outputs = layer(dummy_inputs) + + self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs)) + + def test_works_with_single_image(self): + dummy_inputs = self.rng.uniform(shape=(224, 224, 3)) + layer = MaybeApply(rate=1.0, layer=ZeroOut()) + + outputs = layer(dummy_inputs) + + self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs)) + + def test_can_modify_label(self): + dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) + dummy_labels = tf.ones(shape=(32, 2)) + layer = MaybeApply(rate=1.0, layer=ZeroOut()) + + outputs = layer({"images": dummy_inputs, "labels": dummy_labels}) + + self.assertAllEqual(outputs["labels"], tf.zeros_like(dummy_labels)) + + def test_can_modify_bounding_box(self): + dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) + dummy_boxes = tf.ones(shape=(32, 4)) + layer = MaybeApply(rate=1.0, layer=ZeroOut()) + + outputs = layer({"images": dummy_inputs, "bounding_boxes": dummy_boxes}) + + self.assertAllEqual(outputs["bounding_boxes"], tf.zeros_like(dummy_boxes)) + + def test_works_with_native_keras_layers(self): + dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) + zero_out = tf.keras.layers.Lambda(lambda x: {"images": 0 * x["images"]}) + layer = MaybeApply(rate=1.0, layer=zero_out) + + outputs = layer(dummy_inputs) + + self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs)) + + def test_works_with_xla(self): + dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) + # auto_vectorize=True will crash XLA + layer = MaybeApply(rate=0.5, layer=ZeroOut(), auto_vectorize=False) + + @tf.function(jit_compile=True) + def apply(x): + return layer(x) + + apply(dummy_inputs) diff --git a/keras_cv/layers/serialization_test.py b/keras_cv/layers/serialization_test.py index b14f122ba8..c8fafbb7ce 100644 --- a/keras_cv/layers/serialization_test.py +++ b/keras_cv/layers/serialization_test.py @@ -134,6 +134,15 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase): "seed": 1234, }, ), + ( + "MaybeApply", + preprocessing.MaybeApply, + { + "rate": 0.5, + "layer": None, + "seed": 1234, + }, + ), ) def test_layer_serialization(self, layer_cls, init_args): layer = layer_cls(**init_args)