diff --git a/keras_cv/layers/preprocessing/auto_contrast.py b/keras_cv/layers/preprocessing/auto_contrast.py index def8f211f2..93d0c38f79 100644 --- a/keras_cv/layers/preprocessing/auto_contrast.py +++ b/keras_cv/layers/preprocessing/auto_contrast.py @@ -14,11 +14,14 @@ import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class AutoContrast(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class AutoContrast(BaseImageAugmentationLayer): """Performs the AutoContrast operation on an image. Auto contrast stretches the values of an image across the entire available diff --git a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py new file mode 100644 index 0000000000..847672bc7a --- /dev/null +++ b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py @@ -0,0 +1,288 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from tensorflow.tools.docs import doc_controls + +from keras_cv.layers.preprocessing import preprocessing_utils as utils + +H_AXIS = -3 +W_AXIS = -2 + +IMAGES = "images" +LABELS = "labels" +TARGETS = "targets" +BOUNDING_BOXES = "bounding_boxes" + + +@tf.keras.utils.register_keras_serializable(package="keras_cv") +class BaseImageAugmentationLayer(tf.keras.__internal__.layers.BaseRandomLayer): + """Abstract base layer for image augmentaion. + + This layer contains base functionalities for preprocessing layers which + augment image related data, eg. image and in future, label and bounding + boxes. The subclasses could avoid making certain mistakes and reduce code + duplications. + + This layer requires you to implement one method: `augment_image()`, which + augments one single image during the training. There are a few additional + methods that you can implement for added functionality on the layer: + + `augment_label()`, which handles label augmentation if the layer supports + that. + + `augment_bounding_boxes()`, which handles the bounding box augmentation, if + the layer supports that. + + `get_random_transformation()`, which should produce a random transformation + setting. The tranformation object, which could be any type, will be passed + to `augment_image`, `augment_label` and `augment_bounding_boxes`, to + coodinate the randomness behavior, eg, in the RandomFlip layer, the image + and bounding_boxes should be changed in the same way. + + The `call()` method support two formats of inputs: + 1. Single image tensor with 3D (HWC) or 4D (NHWC) format. + 2. A dict of tensors with stable keys. The supported keys are: + `"images"`, `"labels"` and `"bounding_boxes"` at the moment. We might add + more keys in future when we support more types of augmentation. + + The output of the `call()` will be in two formats, which will be the same + structure as the inputs. + + The `call()` will handle the logic detecting the training/inference mode, + unpack the inputs, forward to the correct function, and pack the output back + to the same structure as the inputs. + + By default the `call()` method leverages the `tf.vectorized_map()` function. + Auto-vectorization can be disabled by setting `self.auto_vectorize = False` + in your `__init__()` method. When disabled, `call()` instead relies + on `tf.map_fn()`. For example: + + ```python + class SubclassLayer(keras_cv.layers.BaseImageAugmentationLayer): + def __init__(self): + super().__init__() + self.auto_vectorize = False + ``` + + Example: + + ```python + class RandomContrast(keras_cv.layers.BaseImageAugmentationLayer): + + def __init__(self, factor=(0.5, 1.5), **kwargs): + super().__init__(**kwargs) + self._factor = factor + + def augment_image(self, image, transformation): + random_factor = tf.random.uniform([], self._factor[0], self._factor[1]) + mean = tf.math.reduced_mean(inputs, axis=-1, keep_dim=True) + return (inputs - mean) * random_factor + mean + ``` + + Note that since the randomness is also a common functionnality, this layer + also includes a tf.keras.backend.RandomGenerator, which can be used to + produce the random numbers. The random number generator is stored in the + `self._random_generator` attribute. + """ + + def __init__(self, seed=None, **kwargs): + super().__init__(seed=seed, **kwargs) + + @property + def auto_vectorize(self): + """Control whether automatic vectorization occurs. + + By default the `call()` method leverages the `tf.vectorized_map()` + function. Auto-vectorization can be disabled by setting + `self.auto_vectorize = False` in your `__init__()` method. When + disabled, `call()` instead relies on `tf.map_fn()`. For example: + + ```python + class SubclassLayer(BaseImageAugmentationLayer): + def __init__(self): + super().__init__() + self.auto_vectorize = False + ``` + """ + return getattr(self, "_auto_vectorize", True) + + @auto_vectorize.setter + def auto_vectorize(self, auto_vectorize): + self._auto_vectorize = auto_vectorize + + @property + def _map_fn(self): + if self.auto_vectorize: + return tf.vectorized_map + else: + return tf.map_fn + + @doc_controls.for_subclass_implementers + def augment_image(self, image, transformation): + """Augment a single image during training. + + Args: + image: 3D image input tensor to the layer. Forwarded from + `layer.call()`. + transformation: The transformation object produced by + `get_random_transformation`. Used to coordinate the randomness + between image, label and bounding box. + + Returns: + output 3D tensor, which will be forward to `layer.call()`. + """ + raise NotImplementedError() + + @doc_controls.for_subclass_implementers + def augment_label(self, label, transformation): + """Augment a single label during training. + + Args: + label: 1D label to the layer. Forwarded from `layer.call()`. + transformation: The transformation object produced by + `get_random_transformation`. Used to coordinate the randomness + between image, label and bounding box. + + Returns: + output 1D tensor, which will be forward to `layer.call()`. + """ + raise NotImplementedError() + + @doc_controls.for_subclass_implementers + def augment_target(self, target, transformation): + """Augment a single target during training. + + Args: + target: 1D label to the layer. Forwarded from `layer.call()`. + transformation: The transformation object produced by + `get_random_transformation`. Used to coordinate the randomness + between image, label and bounding box. + + Returns: + output 1D tensor, which will be forward to `layer.call()`. + """ + return self.augment_label(target, transformation) + + @doc_controls.for_subclass_implementers + def augment_bounding_boxes(self, image, bounding_boxes, transformation=None): + """Augment bounding boxes for one image during training. + + Args: + image: 3D image input tensor to the layer. Forwarded from + `layer.call()`. + bounding_boxes: 2D bounding boxes to the layer. Forwarded from + `call()`. + transformation: The transformation object produced by + `get_random_transformation`. Used to coordinate the randomness + between image, label and bounding box. + + Returns: + output 2D tensor, which will be forward to `layer.call()`. + """ + raise NotImplementedError() + + @doc_controls.for_subclass_implementers + def get_random_transformation(self, image=None, label=None, bounding_box=None): + """Produce random transformation config for one single input. + + This is used to produce same randomness between + image/label/bounding_box. + + Args: + image: 3D image tensor from inputs. + label: optional 1D label tensor from inputs. + bounding_box: optional 2D bounding boxes tensor from inputs. + + Returns: + Any type of object, which will be forwarded to `augment_image`, + `augment_label` and `augment_bounding_box` as the `transformation` + parameter. + """ + return None + + def call(self, inputs, training=True): + inputs = self._ensure_inputs_are_compute_dtype(inputs) + if training: + inputs, is_dict, use_targets = self._format_inputs(inputs) + images = inputs[IMAGES] + if images.shape.rank == 3: + return self._format_output(self._augment(inputs), is_dict, use_targets) + elif images.shape.rank == 4: + return self._format_output( + self._batch_augment(inputs), is_dict, use_targets + ) + else: + raise ValueError( + "Image augmentation layers are expecting inputs to be " + "rank 3 (HWC) or 4D (NHWC) tensors. Got shape: " + f"{images.shape}" + ) + else: + return inputs + + def _augment(self, inputs): + image = inputs.get(IMAGES, None) + label = inputs.get(LABELS, None) + bounding_box = inputs.get(BOUNDING_BOXES, None) + transformation = self.get_random_transformation( + image=image, label=label, bounding_box=bounding_box + ) + image = self.augment_image(image, transformation=transformation) + result = {IMAGES: image} + if label is not None: + label = self.augment_target(label, transformation=transformation) + result[LABELS] = label + if bounding_box is not None: + bounding_box = self.augment_bounding_boxes( + image, bounding_box, transformation=transformation + ) + result[BOUNDING_BOXES] = bounding_box + return result + + def _batch_augment(self, inputs): + return self._map_fn(self._augment, inputs) + + def _format_inputs(self, inputs): + if tf.is_tensor(inputs): + # single image input tensor + return {IMAGES: inputs}, False, False + elif isinstance(inputs, dict) and TARGETS in inputs: + # TODO(scottzhu): Check if it only contains the valid keys + inputs[LABELS] = inputs[TARGETS] + del inputs[TARGETS] + return inputs, True, True + elif isinstance(inputs, dict): + return inputs, True, False + else: + raise ValueError( + f"Expect the inputs to be image tensor or dict. Got {inputs}" + ) + + def _format_output(self, output, is_dict, use_targets): + if not is_dict: + return output[IMAGES] + elif use_targets: + output[TARGETS] = output[LABELS] + del output[LABELS] + return output + else: + return output + + def _ensure_inputs_are_compute_dtype(self, inputs): + if isinstance(inputs, dict): + inputs[IMAGES] = utils.ensure_tensor(inputs[IMAGES], self.compute_dtype) + else: + inputs = utils.ensure_tensor(inputs, self.compute_dtype) + return inputs diff --git a/keras_cv/layers/preprocessing/base_image_augmentation_layer_test.py b/keras_cv/layers/preprocessing/base_image_augmentation_layer_test.py new file mode 100644 index 0000000000..760423cf6f --- /dev/null +++ b/keras_cv/layers/preprocessing/base_image_augmentation_layer_test.py @@ -0,0 +1,112 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import tensorflow as tf + +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) + + +class RandomAddLayer(BaseImageAugmentationLayer): + def __init__(self, value_range=(0.0, 1.0), fixed_value=None, **kwargs): + super().__init__(**kwargs) + self.value_range = value_range + self.fixed_value = fixed_value + + def get_random_transformation(self, image=None, label=None, bounding_box=None): + if self.fixed_value: + return self.fixed_value + return self._random_generator.random_uniform( + [], minval=self.value_range[0], maxval=self.value_range[1] + ) + + def augment_image(self, image, transformation): + return image + transformation + + def augment_label(self, label, transformation): + return label + transformation + + +class VectorizeDisabledLayer(BaseImageAugmentationLayer): + def __init__(self, **kwargs): + self.auto_vectorize = False + super().__init__(**kwargs) + + +class BaseImageAugmentationLayerTest(tf.test.TestCase): + def test_augment_single_image(self): + add_layer = RandomAddLayer(fixed_value=2.0) + image = np.random.random(size=(8, 8, 3)).astype("float32") + output = add_layer(image) + + self.assertAllClose(image + 2.0, output) + + def test_augment_dict_return_type(self): + add_layer = RandomAddLayer(fixed_value=2.0) + image = np.random.random(size=(8, 8, 3)).astype("float32") + output = add_layer({"images": image}) + + self.assertIsInstance(output, dict) + + def test_auto_vectorize_disabled(self): + vectorize_disabled_layer = VectorizeDisabledLayer() + self.assertFalse(vectorize_disabled_layer.auto_vectorize) + self.assertEqual(vectorize_disabled_layer._map_fn, tf.map_fn) + + def test_augment_casts_dtypes(self): + add_layer = RandomAddLayer(fixed_value=2.0) + images = tf.ones((2, 8, 8, 3), dtype="uint8") + output = add_layer(images) + + self.assertAllClose(tf.ones((2, 8, 8, 3), dtype="float32") * 3.0, output) + + def test_augment_batch_images(self): + add_layer = RandomAddLayer() + images = np.random.random(size=(2, 8, 8, 3)).astype("float32") + output = add_layer(images) + + diff = output - images + # Make sure the first image and second image get different augmentation + self.assertNotAllClose(diff[0], diff[1]) + + def test_augment_image_and_label(self): + add_layer = RandomAddLayer(fixed_value=2.0) + image = np.random.random(size=(8, 8, 3)).astype("float32") + label = np.random.random(size=(1,)).astype("float32") + + output = add_layer({"images": image, "labels": label}) + expected_output = {"images": image + 2.0, "labels": label + 2.0} + self.assertAllClose(output, expected_output) + + def test_augment_image_and_target(self): + add_layer = RandomAddLayer(fixed_value=2.0) + image = np.random.random(size=(8, 8, 3)).astype("float32") + label = np.random.random(size=(1,)).astype("float32") + + output = add_layer({"images": image, "targets": label}) + expected_output = {"images": image + 2.0, "targets": label + 2.0} + self.assertAllClose(output, expected_output) + + def test_augment_batch_images_and_labels(self): + add_layer = RandomAddLayer() + images = np.random.random(size=(2, 8, 8, 3)).astype("float32") + labels = np.random.random(size=(2, 1)).astype("float32") + output = add_layer({"images": images, "labels": labels}) + + image_diff = output["images"] - images + label_diff = output["labels"] - labels + # Make sure the first image and second image get different augmentation + self.assertNotAllClose(image_diff[0], image_diff[1]) + self.assertNotAllClose(label_diff[0], label_diff[1]) diff --git a/keras_cv/layers/preprocessing/cut_mix.py b/keras_cv/layers/preprocessing/cut_mix.py index 8bde22c9f3..aaba4e254e 100644 --- a/keras_cv/layers/preprocessing/cut_mix.py +++ b/keras_cv/layers/preprocessing/cut_mix.py @@ -13,11 +13,14 @@ # limitations under the License. import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import fill_utils @tf.keras.utils.register_keras_serializable(package="keras_cv") -class CutMix(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class CutMix(BaseImageAugmentationLayer): """CutMix implements the CutMix data augmentation technique. Args: diff --git a/keras_cv/layers/preprocessing/equalization.py b/keras_cv/layers/preprocessing/equalization.py index 0e29099f43..e74a330087 100644 --- a/keras_cv/layers/preprocessing/equalization.py +++ b/keras_cv/layers/preprocessing/equalization.py @@ -13,11 +13,14 @@ # limitations under the License. import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class Equalization(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class Equalization(BaseImageAugmentationLayer): """Equalization performs histogram equalization on a channel-wise basis. Args: diff --git a/keras_cv/layers/preprocessing/fourier_mix.py b/keras_cv/layers/preprocessing/fourier_mix.py index 38b97c9c4c..d64611e897 100644 --- a/keras_cv/layers/preprocessing/fourier_mix.py +++ b/keras_cv/layers/preprocessing/fourier_mix.py @@ -13,9 +13,13 @@ # limitations under the License. import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) + @tf.keras.utils.register_keras_serializable(package="keras_cv") -class FourierMix(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class FourierMix(BaseImageAugmentationLayer): """FourierMix implements the FMix data augmentation technique. Args: diff --git a/keras_cv/layers/preprocessing/grayscale.py b/keras_cv/layers/preprocessing/grayscale.py index 8ee7083b4c..debb3b99aa 100644 --- a/keras_cv/layers/preprocessing/grayscale.py +++ b/keras_cv/layers/preprocessing/grayscale.py @@ -14,9 +14,13 @@ import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) + @tf.keras.utils.register_keras_serializable(package="keras_cv") -class Grayscale(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class Grayscale(BaseImageAugmentationLayer): """Grayscale is a preprocessing layer that transforms RGB images to Grayscale images. Input images should have values in the range of [0, 255]. diff --git a/keras_cv/layers/preprocessing/grid_mask.py b/keras_cv/layers/preprocessing/grid_mask.py index c780829063..f0831f588b 100644 --- a/keras_cv/layers/preprocessing/grid_mask.py +++ b/keras_cv/layers/preprocessing/grid_mask.py @@ -16,6 +16,9 @@ from tensorflow.keras import layers from keras_cv import core +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import fill_utils from keras_cv.utils import preprocessing @@ -31,7 +34,7 @@ def _center_crop(mask, width, height): @tf.keras.utils.register_keras_serializable(package="keras_cv") -class GridMask(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class GridMask(BaseImageAugmentationLayer): """GridMask class for grid-mask augmentation. diff --git a/keras_cv/layers/preprocessing/mix_up.py b/keras_cv/layers/preprocessing/mix_up.py index 4c7946b2a1..451f5b465c 100644 --- a/keras_cv/layers/preprocessing/mix_up.py +++ b/keras_cv/layers/preprocessing/mix_up.py @@ -13,9 +13,13 @@ # limitations under the License. import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) + @tf.keras.utils.register_keras_serializable(package="keras_cv") -class MixUp(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class MixUp(BaseImageAugmentationLayer): """MixUp implements the MixUp data augmentation technique. Args: diff --git a/keras_cv/layers/preprocessing/posterization.py b/keras_cv/layers/preprocessing/posterization.py index 1589afd8cb..a9cd10bd8a 100644 --- a/keras_cv/layers/preprocessing/posterization.py +++ b/keras_cv/layers/preprocessing/posterization.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf -from tensorflow.keras.__internal__.layers import BaseImageAugmentationLayer +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils.preprocessing import transform_value_range diff --git a/keras_cv/layers/preprocessing/preprocessing_utils.py b/keras_cv/layers/preprocessing/preprocessing_utils.py new file mode 100644 index 0000000000..103cec6380 --- /dev/null +++ b/keras_cv/layers/preprocessing/preprocessing_utils.py @@ -0,0 +1,26 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utils for preprocessing layers.""" + +import tensorflow as tf + + +def ensure_tensor(inputs, dtype=None): + """Ensures the input is a Tensor, SparseTensor or RaggedTensor.""" + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)): + inputs = tf.convert_to_tensor(inputs, dtype) + if dtype is not None and inputs.dtype != dtype: + inputs = tf.cast(inputs, dtype) + return inputs diff --git a/keras_cv/layers/preprocessing/random_channel_shift.py b/keras_cv/layers/preprocessing/random_channel_shift.py index 4867f1e707..61f6517cb5 100644 --- a/keras_cv/layers/preprocessing/random_channel_shift.py +++ b/keras_cv/layers/preprocessing/random_channel_shift.py @@ -14,11 +14,14 @@ import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class RandomChannelShift(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class RandomChannelShift(BaseImageAugmentationLayer): """Randomly shift values for each channel of the input image(s). The input images should have values in the `[0-255]` or `[0-1]` range. diff --git a/keras_cv/layers/preprocessing/random_color_degeneration.py b/keras_cv/layers/preprocessing/random_color_degeneration.py index 2f3b64fc32..792dd37bf7 100644 --- a/keras_cv/layers/preprocessing/random_color_degeneration.py +++ b/keras_cv/layers/preprocessing/random_color_degeneration.py @@ -13,11 +13,14 @@ # limitations under the License. import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class RandomColorDegeneration(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class RandomColorDegeneration(BaseImageAugmentationLayer): """Randomly performs the color degeneration operation on given images. The sharpness operation first converts an image to gray scale, then back to color. diff --git a/keras_cv/layers/preprocessing/random_color_jitter.py b/keras_cv/layers/preprocessing/random_color_jitter.py index 06fc7efb24..8d2cd5a8fa 100644 --- a/keras_cv/layers/preprocessing/random_color_jitter.py +++ b/keras_cv/layers/preprocessing/random_color_jitter.py @@ -13,9 +13,11 @@ # limitations under the License. import tensorflow as tf -from tensorflow.keras.__internal__.layers import BaseImageAugmentationLayer from keras_cv.layers import preprocessing +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing as preprocessing_utils diff --git a/keras_cv/layers/preprocessing/random_cutout.py b/keras_cv/layers/preprocessing/random_cutout.py index 2d456814ad..fd52f99d75 100644 --- a/keras_cv/layers/preprocessing/random_cutout.py +++ b/keras_cv/layers/preprocessing/random_cutout.py @@ -13,12 +13,15 @@ # limitations under the License. import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import fill_utils from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class RandomCutout(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class RandomCutout(BaseImageAugmentationLayer): """Randomly cut out rectangles from images and fill them. Args: diff --git a/keras_cv/layers/preprocessing/random_gaussian_blur.py b/keras_cv/layers/preprocessing/random_gaussian_blur.py index 305d016b17..581dff26d5 100644 --- a/keras_cv/layers/preprocessing/random_gaussian_blur.py +++ b/keras_cv/layers/preprocessing/random_gaussian_blur.py @@ -14,11 +14,14 @@ import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class RandomGaussianBlur(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class RandomGaussianBlur(BaseImageAugmentationLayer): """Applies a Gaussian Blur with random strength to an image. Args: diff --git a/keras_cv/layers/preprocessing/random_hue.py b/keras_cv/layers/preprocessing/random_hue.py index e9edec3203..6512b4507d 100644 --- a/keras_cv/layers/preprocessing/random_hue.py +++ b/keras_cv/layers/preprocessing/random_hue.py @@ -13,11 +13,14 @@ # limitations under the License. import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class RandomHue(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class RandomHue(BaseImageAugmentationLayer): """Randomly adjusts the hue on given images. This layer will randomly increase/reduce the hue for the input RGB diff --git a/keras_cv/layers/preprocessing/random_saturation.py b/keras_cv/layers/preprocessing/random_saturation.py index b313d8cf49..cba1cda46f 100644 --- a/keras_cv/layers/preprocessing/random_saturation.py +++ b/keras_cv/layers/preprocessing/random_saturation.py @@ -13,11 +13,14 @@ # limitations under the License. import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class RandomSaturation(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class RandomSaturation(BaseImageAugmentationLayer): """Randomly adjusts the saturation on given images. This layer will randomly increase/reduce the saturation for the input RGB diff --git a/keras_cv/layers/preprocessing/random_sharpness.py b/keras_cv/layers/preprocessing/random_sharpness.py index 79fa0e166e..91ce9c228c 100644 --- a/keras_cv/layers/preprocessing/random_sharpness.py +++ b/keras_cv/layers/preprocessing/random_sharpness.py @@ -13,11 +13,14 @@ # limitations under the License. import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class RandomSharpness(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class RandomSharpness(BaseImageAugmentationLayer): """Randomly performs the sharpness operation on given images. The sharpness operation first performs a blur operation, then blends between the diff --git a/keras_cv/layers/preprocessing/random_shear.py b/keras_cv/layers/preprocessing/random_shear.py index afde8cf21c..4fcbc4921f 100644 --- a/keras_cv/layers/preprocessing/random_shear.py +++ b/keras_cv/layers/preprocessing/random_shear.py @@ -15,11 +15,14 @@ import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class RandomShear(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class RandomShear(BaseImageAugmentationLayer): """Randomly shears an image. Args: diff --git a/keras_cv/layers/preprocessing/solarization.py b/keras_cv/layers/preprocessing/solarization.py index 57b8ed8639..1e842c014f 100644 --- a/keras_cv/layers/preprocessing/solarization.py +++ b/keras_cv/layers/preprocessing/solarization.py @@ -13,11 +13,14 @@ # limitations under the License. import tensorflow as tf +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) from keras_cv.utils import preprocessing @tf.keras.utils.register_keras_serializable(package="keras_cv") -class Solarization(tf.keras.__internal__.layers.BaseImageAugmentationLayer): +class Solarization(BaseImageAugmentationLayer): """Applies (max_value - pixel + min_value) for each pixel in the image. When created without `threshold` parameter, the layer performs solarization to