From a71870113de5880cb49832bf8f767aee674d99c4 Mon Sep 17 00:00:00 2001 From: quantumalaviya <45961148+quantumalaviya@users.noreply.github.com> Date: Wed, 6 Apr 2022 00:07:35 +0530 Subject: [PATCH] Vectorize Equalization (#201) * Port Equalization to BaseImageAugmentationLayer (non-vectorized) * Vectorizing Equalization * Vectorize Equalization * Vectorize Equalization * added changes * Update equalization.py * introduce equalization correctness tests * Reformat equalization Co-authored-by: Luke Wood --- keras_cv/layers/preprocessing/equalization.py | 35 ++++++++++++------- .../layers/preprocessing/equalization_test.py | 30 +++++++++++++++- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/keras_cv/layers/preprocessing/equalization.py b/keras_cv/layers/preprocessing/equalization.py index 3ab159ae19..a26cdebd94 100644 --- a/keras_cv/layers/preprocessing/equalization.py +++ b/keras_cv/layers/preprocessing/equalization.py @@ -60,9 +60,17 @@ def equalize_channel(self, image, channel_index): histogram = tf.histogram_fixed_width(image, [0, 255], nbins=self.bins) # For the purposes of computing the step, filter out the nonzeros. - nonzero = tf.where(tf.not_equal(histogram, 0)) - nonzero_histogram = tf.reshape(tf.gather(histogram, nonzero), [-1]) - step = (tf.reduce_sum(nonzero_histogram) - nonzero_histogram[-1]) // ( + # Zeroes are replaced by a big number while calculating min to keep shape + # constant across input sizes for compatibility with vectorized_map + + big_number = 1410065408 + histogram_without_zeroes = tf.where( + tf.equal(histogram, 0), + big_number, + histogram, + ) + + step = (tf.reduce_sum(histogram) - tf.reduce_min(histogram_without_zeroes)) // ( self.bins - 1 ) @@ -81,20 +89,23 @@ def build_mapping(histogram, step): result = tf.cond( tf.equal(step, 0), lambda: image, - lambda: tf.cast( - tf.gather(build_mapping(histogram, step), tf.cast(image, tf.int32)), - self.compute_dtype, - ), + lambda: tf.gather(build_mapping(histogram, step), image), ) return result def augment_image(self, image, transformation=None): - image = preprocessing.transform_value_range(image, self.value_range, (0, 255)) - r = self.equalize_channel(image, 0) - g = self.equalize_channel(image, 1) - b = self.equalize_channel(image, 2) - image = tf.stack([r, g, b], axis=-1) + image = preprocessing.transform_value_range( + image, self.value_range, (0, 255), dtype=image.dtype + ) + image = tf.cast(image, tf.int32) + image = tf.vectorized_map( + lambda channel: self.equalize_channel(image, channel), + tf.range(tf.shape(image)[-1]), + ) + + image = tf.transpose(image, [1, 2, 0]) + image = tf.cast(image, tf.float32) image = preprocessing.transform_value_range(image, (0, 255), self.value_range) return image diff --git a/keras_cv/layers/preprocessing/equalization_test.py b/keras_cv/layers/preprocessing/equalization_test.py index ac6da6f89d..b67201d2b0 100644 --- a/keras_cv/layers/preprocessing/equalization_test.py +++ b/keras_cv/layers/preprocessing/equalization_test.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf +from absl.testing import parameterized from keras_cv.layers.preprocessing.equalization import Equalization -class EqualizationTest(tf.test.TestCase): +class EqualizationTest(tf.test.TestCase, parameterized.TestCase): def test_return_shapes(self): xs = 255 * tf.ones((2, 512, 512, 3), dtype=tf.int32) layer = Equalization(value_range=(0, 255)) @@ -24,3 +25,30 @@ def test_return_shapes(self): self.assertEqual(xs.shape, [2, 512, 512, 3]) self.assertAllEqual(xs, 255 * tf.ones((2, 512, 512, 3))) + + def test_equalizes_to_all_bins(self): + xs = tf.random.uniform((2, 512, 512, 3), 0, 255, dtype=tf.float32) + layer = Equalization(value_range=(0, 255)) + xs = layer(xs) + + for i in range(0, 256): + self.assertTrue(tf.math.reduce_any(xs == i)) + + @parameterized.named_parameters( + ("float32", tf.float32), ("int32", tf.int32), ("int64", tf.int64) + ) + def test_input_dtypes(self, dtype): + xs = tf.random.uniform((2, 512, 512, 3), 0, 255, dtype=dtype) + layer = Equalization(value_range=(0, 255)) + xs = layer(xs) + + for i in range(0, 256): + self.assertTrue(tf.math.reduce_any(xs == i)) + self.assertAllInRange(xs, 0, 255) + + @parameterized.named_parameters(("0_255", 0, 255), ("0_1", 0, 1)) + def test_output_range(self, lower, upper): + xs = tf.random.uniform((2, 512, 512, 3), lower, upper, dtype=tf.float32) + layer = Equalization(value_range=(lower, upper)) + xs = layer(xs) + self.assertAllInRange(xs, lower, upper)