Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduces MaybeApply layer. #435

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras_cv/layers/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from keras_cv.layers.preprocessing.fourier_mix import FourierMix
from keras_cv.layers.preprocessing.grayscale import Grayscale
from keras_cv.layers.preprocessing.grid_mask import GridMask
from keras_cv.layers.preprocessing.maybe_apply import MaybeApply
from keras_cv.layers.preprocessing.mix_up import MixUp
from keras_cv.layers.preprocessing.posterization import Posterization
from keras_cv.layers.preprocessing.rand_augment import RandAugment
Expand Down
105 changes: 105 additions & 0 deletions keras_cv/layers/preprocessing/maybe_apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class MaybeApply(tf.keras.__internal__.layers.BaseImageAugmentationLayer):
"""Apply provided layer to random elements in a batch.

Args:
layer: a keras Layer or BaseImageAugmentationLayer. This layer will be applied
LukeWood marked this conversation as resolved.
Show resolved Hide resolved
to randomly chosen samples in a batch.
rate: controls the frequency of applying the layer. 1.0 means all elements in
a batch will be modified. 0.0 means no elements will be modified.
Defaults to 0.5.
auto_vectorize: bool, whether to use tf.vectorized_map or tf.map_fn for
batched input. Setting this to True might give better performance but
currently doesn't work with XLA. Defaults to False.
seed: integer, controls random behaviour.

Example usage:
LukeWood marked this conversation as resolved.
Show resolved Hide resolved
# Let's declare an example layer that will set all image pixels to zero.
zero_out = tf.keras.layers.Lambda(lambda x: {"images": 0 * x["images"]})

# Create a small batch of random, single-channel, 2x2 images:
images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1])
print(images[..., 0])
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
# array([[[0.08216608, 0.40928006],
# [0.39318466, 0.3162533 ]],
#
# [[0.34717774, 0.73199546],
# [0.56369007, 0.9769211 ]],
#
# [[0.55243933, 0.13101244],
# [0.2941643 , 0.5130266 ]],
#
# [[0.38977218, 0.80855536],
# [0.6040567 , 0.10502195]],
#
# [[0.51828027, 0.12730157],
# [0.288486 , 0.252975 ]]], dtype=float32)>

# Apply the layer with 50% probability:
maybe_apply = MaybeApply(layer=zero_out, rate=0.5, seed=1234)
outputs = maybe_apply(images)
print(outputs[..., 0])
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
# array([[[0. , 0. ],
# [0. , 0. ]],
#
# [[0.34717774, 0.73199546],
# [0.56369007, 0.9769211 ]],
#
# [[0.55243933, 0.13101244],
# [0.2941643 , 0.5130266 ]],
#
# [[0.38977218, 0.80855536],
# [0.6040567 , 0.10502195]],
#
# [[0. , 0. ],
# [0. , 0. ]]], dtype=float32)>

# We can observe that the layer has been randomly applied to 2 out of 5 batches.
"""

def __init__(self, layer, rate=0.5, auto_vectorize=False, seed=None, **kwargs):
super().__init__(seed=seed, **kwargs)

if not (0 <= rate <= 1.0):
raise ValueError(f"rate must be in range [0, 1]. Received rate: {rate}")

self._layer = layer
self._rate = rate
self.auto_vectorize = auto_vectorize
self.seed = seed

def _augment(self, inputs):
if self._random_generator.random_uniform(shape=()) > 1.0 - self._rate:
return self._layer(inputs)
else:
return inputs

def get_config(self):
config = super().get_config()
config.update(
{
"rate": self._rate,
"layer": self._layer,
"seed": self.seed,
"auto_vectorize": self.auto_vectorize,
}
)
return config
125 changes: 125 additions & 0 deletions keras_cv/layers/preprocessing/maybe_apply_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from absl.testing import parameterized

from keras_cv.layers.preprocessing.maybe_apply import MaybeApply


class ZeroOut(tf.keras.__internal__.layers.BaseImageAugmentationLayer):
"""Zero out all entries, for testing purposes."""

def __init__(self):
super(ZeroOut, self).__init__()

def augment_image(self, image, transformation=None):
return 0 * image

def augment_label(self, label, transformation=None):
return 0 * label

def augment_bounding_box(self, bounding_box, transformation=None):
return 0 * bounding_box


class MaybeApplyTest(tf.test.TestCase, parameterized.TestCase):
rng = tf.random.Generator.from_seed(seed=1234)

@parameterized.parameters([-0.5, 1.7])
def test_raises_error_on_invalid_rate_parameter(self, invalid_rate):
with self.assertRaises(ValueError):
MaybeApply(rate=invalid_rate, layer=ZeroOut())

def test_works_with_batched_input(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you pass a seed so that this test is not potentially flaky? Given, it is 1/2^32 flakiness, but still may as well seed it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added seed to rng on line 37.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! does this seed the layer too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. Added seed param to layer as well.

batch_size = 32
dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3))
layer = MaybeApply(rate=0.5, layer=ZeroOut(), seed=1234)

outputs = layer(dummy_inputs)
num_zero_inputs = self._num_zero_batches(dummy_inputs)
num_zero_outputs = self._num_zero_batches(outputs)

self.assertEqual(num_zero_inputs, 0)
self.assertLess(num_zero_outputs, batch_size)
self.assertGreater(num_zero_outputs, 0)

@staticmethod
def _num_zero_batches(images):
num_batches = tf.shape(images)[0]
num_non_zero_batches = tf.math.count_nonzero(
tf.math.count_nonzero(images, axis=[1, 2, 3]), dtype=tf.int32
)
return num_batches - num_non_zero_batches

def test_inputs_unchanged_with_zero_rate(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
layer = MaybeApply(rate=0.0, layer=ZeroOut())

outputs = layer(dummy_inputs)

self.assertAllClose(outputs, dummy_inputs)

def test_all_inputs_changed_with_rate_equal_to_one(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
layer = MaybeApply(rate=1.0, layer=ZeroOut())

outputs = layer(dummy_inputs)

self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs))

def test_works_with_single_image(self):
dummy_inputs = self.rng.uniform(shape=(224, 224, 3))
layer = MaybeApply(rate=1.0, layer=ZeroOut())

outputs = layer(dummy_inputs)

self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs))

def test_can_modify_label(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
dummy_labels = tf.ones(shape=(32, 2))
layer = MaybeApply(rate=1.0, layer=ZeroOut())

outputs = layer({"images": dummy_inputs, "labels": dummy_labels})

self.assertAllEqual(outputs["labels"], tf.zeros_like(dummy_labels))

def test_can_modify_bounding_box(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
dummy_boxes = tf.ones(shape=(32, 4))
layer = MaybeApply(rate=1.0, layer=ZeroOut())

outputs = layer({"images": dummy_inputs, "bounding_boxes": dummy_boxes})

self.assertAllEqual(outputs["bounding_boxes"], tf.zeros_like(dummy_boxes))

def test_works_with_native_keras_layers(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
zero_out = tf.keras.layers.Lambda(lambda x: {"images": 0 * x["images"]})
layer = MaybeApply(rate=1.0, layer=zero_out)

outputs = layer(dummy_inputs)

self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs))

def test_works_with_xla(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
# auto_vectorize=True will crash XLA
layer = MaybeApply(rate=0.5, layer=ZeroOut(), auto_vectorize=False)

@tf.function(jit_compile=True)
def apply(x):
return layer(x)

apply(dummy_inputs)
9 changes: 9 additions & 0 deletions keras_cv/layers/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase):
"seed": 1234,
},
),
(
"MaybeApply",
preprocessing.MaybeApply,
{
"rate": 0.5,
"layer": None,
"seed": 1234,
},
),
)
def test_layer_serialization(self, layer_cls, init_args):
layer = layer_cls(**init_args)
Expand Down