Skip to content

Commit

Permalink
Vectorize Equalization (#201)
Browse files Browse the repository at this point in the history
* Port Equalization to BaseImageAugmentationLayer (non-vectorized)

* Vectorizing Equalization

* Vectorize Equalization

* Vectorize Equalization

* added changes

* Update equalization.py

* introduce equalization correctness tests

* Reformat equalization

Co-authored-by: Luke Wood <[email protected]>
  • Loading branch information
quantumalaviya and LukeWood authored Apr 5, 2022
1 parent 0131d55 commit 72a774d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 13 deletions.
35 changes: 23 additions & 12 deletions keras_cv/layers/preprocessing/equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,17 @@ def equalize_channel(self, image, channel_index):
histogram = tf.histogram_fixed_width(image, [0, 255], nbins=self.bins)

# For the purposes of computing the step, filter out the nonzeros.
nonzero = tf.where(tf.not_equal(histogram, 0))
nonzero_histogram = tf.reshape(tf.gather(histogram, nonzero), [-1])
step = (tf.reduce_sum(nonzero_histogram) - nonzero_histogram[-1]) // (
# Zeroes are replaced by a big number while calculating min to keep shape
# constant across input sizes for compatibility with vectorized_map

big_number = 1410065408
histogram_without_zeroes = tf.where(
tf.equal(histogram, 0),
big_number,
histogram,
)

step = (tf.reduce_sum(histogram) - tf.reduce_min(histogram_without_zeroes)) // (
self.bins - 1
)

Expand All @@ -81,20 +89,23 @@ def build_mapping(histogram, step):
result = tf.cond(
tf.equal(step, 0),
lambda: image,
lambda: tf.cast(
tf.gather(build_mapping(histogram, step), tf.cast(image, tf.int32)),
self.compute_dtype,
),
lambda: tf.gather(build_mapping(histogram, step), image),
)

return result

def augment_image(self, image, transformation=None):
image = preprocessing.transform_value_range(image, self.value_range, (0, 255))
r = self.equalize_channel(image, 0)
g = self.equalize_channel(image, 1)
b = self.equalize_channel(image, 2)
image = tf.stack([r, g, b], axis=-1)
image = preprocessing.transform_value_range(
image, self.value_range, (0, 255), dtype=image.dtype
)
image = tf.cast(image, tf.int32)
image = tf.vectorized_map(
lambda channel: self.equalize_channel(image, channel),
tf.range(tf.shape(image)[-1]),
)

image = tf.transpose(image, [1, 2, 0])
image = tf.cast(image, tf.float32)
image = preprocessing.transform_value_range(image, (0, 255), self.value_range)
return image

Expand Down
30 changes: 29 additions & 1 deletion keras_cv/layers/preprocessing/equalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from absl.testing import parameterized

from keras_cv.layers.preprocessing.equalization import Equalization


class EqualizationTest(tf.test.TestCase):
class EqualizationTest(tf.test.TestCase, parameterized.TestCase):
def test_return_shapes(self):
xs = 255 * tf.ones((2, 512, 512, 3), dtype=tf.int32)
layer = Equalization(value_range=(0, 255))
xs = layer(xs)

self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertAllEqual(xs, 255 * tf.ones((2, 512, 512, 3)))

def test_equalizes_to_all_bins(self):
xs = tf.random.uniform((2, 512, 512, 3), 0, 255, dtype=tf.float32)
layer = Equalization(value_range=(0, 255))
xs = layer(xs)

for i in range(0, 256):
self.assertTrue(tf.math.reduce_any(xs == i))

@parameterized.named_parameters(
("float32", tf.float32), ("int32", tf.int32), ("int64", tf.int64)
)
def test_input_dtypes(self, dtype):
xs = tf.random.uniform((2, 512, 512, 3), 0, 255, dtype=dtype)
layer = Equalization(value_range=(0, 255))
xs = layer(xs)

for i in range(0, 256):
self.assertTrue(tf.math.reduce_any(xs == i))
self.assertAllInRange(xs, 0, 255)

@parameterized.named_parameters(("0_255", 0, 255), ("0_1", 0, 1))
def test_output_range(self, lower, upper):
xs = tf.random.uniform((2, 512, 512, 3), lower, upper, dtype=tf.float32)
layer = Equalization(value_range=(lower, upper))
xs = layer(xs)
self.assertAllInRange(xs, lower, upper)

0 comments on commit 72a774d

Please sign in to comment.