diff --git a/keras_preprocessing/image/affine_transformations.py b/keras_preprocessing/image/affine_transformations.py index c2103d8c..d5cd8e79 100644 --- a/keras_preprocessing/image/affine_transformations.py +++ b/keras_preprocessing/image/affine_transformations.py @@ -1,16 +1,10 @@ """Utilities for performing affine transformations on image data. """ import numpy as np +import tensorflow as tf from .utils import array_to_img, img_to_array -try: - import scipy - # scipy.ndimage cannot be accessed until explicitly imported - from scipy import ndimage -except ImportError: - scipy = None - try: from PIL import Image as pil_image from PIL import ImageEnhance @@ -41,8 +35,8 @@ def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0, (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). cval: Value used for points outside the boundaries of the input if `mode='constant'`. - interpolation_order: int, order of spline interpolation. - see `ndimage.interpolation.affine_transform` + interpolation_order: int (one of `{0, 1}`) order of interpolation. + see `tfa.image.transform` # Returns Rotated Numpy image tensor. @@ -75,8 +69,8 @@ def random_shift(x, wrg, hrg, row_axis=1, col_axis=2, channel_axis=0, (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). cval: Value used for points outside the boundaries of the input if `mode='constant'`. - interpolation_order: int, order of spline interpolation. - see `ndimage.interpolation.affine_transform` + interpolation_order: int (one of `{0, 1}`) order of interpolation. + see `tfa.image.transform` # Returns Shifted Numpy image tensor. @@ -111,8 +105,8 @@ def random_shear(x, intensity, row_axis=1, col_axis=2, channel_axis=0, (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). cval: Value used for points outside the boundaries of the input if `mode='constant'`. - interpolation_order: int, order of spline interpolation. 
- see `ndimage.interpolation.affine_transform` + interpolation_order: int (one of `{0, 1}`) order of interpolation. + see `tfa.image.transform` # Returns Sheared Numpy image tensor. @@ -144,8 +138,8 @@ def random_zoom(x, zoom_range, row_axis=1, col_axis=2, channel_axis=0, (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). cval: Value used for points outside the boundaries of the input if `mode='constant'`. - interpolation_order: int, order of spline interpolation. - see `ndimage.interpolation.affine_transform` + interpolation_order: int (one of `{0, 1}`) order of interpolation. + see `tfa.image.transform` # Returns Zoomed Numpy image tensor. @@ -299,14 +293,23 @@ def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1, (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). cval: Value used for points outside the boundaries of the input if `mode='constant'`. - order: int, order of interpolation + order: int (one of `{0, 1}`) order of interpolation. + see `tfa.image.transform` + + # Raises + ValueError if `row_axis`, `col_axis` and `channel_axis` are misconfigured. # Returns The transformed version of the input. """ - if scipy is None: - raise ImportError('Image transformations require SciPy. ' - 'Install SciPy.') + + # Convert interpolation order into textual values used by tfa.image.transform. + if order == 0: + interpolation = "NEAREST" + elif order == 1: + interpolation = "BILINEAR" + else: + raise ValueError("Interpolation order can only be 0 or 1") # Input sanity checks: # 1. x must 2D image with one or more channels (i.e., a 3D tensor) @@ -367,30 +370,78 @@ def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1, h, w = x.shape[row_axis], x.shape[col_axis] transform_matrix = transform_matrix_offset_center( transform_matrix, h, w) - x = np.rollaxis(x, channel_axis, 0) + x = np.moveaxis(x, channel_axis, -1) # Matrix construction assumes that coordinates are x, y (in that order). 
- # However, regular numpy arrays use y,x (aka i,j) indexing. - # Possible solution is: + # However, users may reverse that order by setting `col_axis=0`, + # `row_axis=1`. In this case, one possible solution is: # 1. Swap the x and y axes. # 2. Apply transform. # 3. Swap the x and y axes again to restore image-like data ordering. # Mathematically, it is equivalent to the following transformation: # M' = PMP, where P is the permutation matrix, M is the original # transformation matrix. - if col_axis > row_axis: + if col_axis < row_axis: transform_matrix[:, [0, 1]] = transform_matrix[:, [1, 0]] transform_matrix[[0, 1]] = transform_matrix[[1, 0]] - final_affine_matrix = transform_matrix[:2, :2] - final_offset = transform_matrix[:2, 2] - - channel_images = [ndimage.interpolation.affine_transform( - x_channel, - final_affine_matrix, - final_offset, - order=order, - mode=fill_mode, - cval=cval) for x_channel in x] - x = np.stack(channel_images, axis=0) - x = np.rollaxis(x, 0, channel_axis + 1) + w, h = h, w + + transform = matrix_to_transform(transform_matrix) + image = to_4D_tensor(x) + + image = tf.raw_ops.ImageProjectiveTransformV3( + images=image, + transforms=transform, + output_shape=(h, w), + interpolation=interpolation, + fill_mode=fill_mode.upper(), + fill_value=cval, + ) + x = from_4D_image(image, x.ndim) + x = np.moveaxis(x, -1, channel_axis) return x + +def matrix_to_transform(matrix): + transform = matrix.ravel()[0:8] + transform = tf.convert_to_tensor(transform, dtype=tf.dtypes.float32) + return transform[None] + + +def to_4D_tensor(image): + """Convert 2/3/4D image to 4D image. + + # Arguments + image: 2/3/4D `Tensor`. + + # Returns + 4D `Tensor` with the same type. + """ + image = tf.convert_to_tensor(image) + ndims = image.get_shape().ndims + + if ndims == 2: + return image[None, :, :, None] + elif ndims == 3: + return image[None, :, :, :] + else: + return image + + +def from_4D_image(image, ndims): + """Convert back to an image with `ndims` rank. 
+ + # Arguments + image: 4D `Tensor`. + ndims: The original rank of the image. + + # Returns + `ndims`-D `numpy.array` with the same type. + """ + + if ndims == 2: + res = tf.squeeze(image, [0, 3]) + elif ndims == 3: + res = tf.squeeze(image, [0]) + else: + res = image + return res.numpy() diff --git a/setup.py b/setup.py index bdbb88c3..9751845b 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ ''' setup(name='Keras_Preprocessing', - version='1.1.2', + version='1.1.3', description='Easy data preprocessing and data augmentation ' 'for deep learning models', long_description=long_description, @@ -33,13 +33,13 @@ extras_require={ 'tests': ['pandas', 'Pillow', - 'tensorflow', # CPU version + 'tensorflow>=2.4.0', # CPU version 'keras', 'pytest', 'pytest-xdist', 'pytest-cov'], 'pep8': ['flake8'], - 'image': ['scipy>=0.14', + 'image': ['tensorflow>=2.4.0', 'Pillow>=5.2.0'], }, classifiers=[ diff --git a/tests/image/affine_transformations_test.py b/tests/image/affine_transformations_test.py index 9a9079f4..df307e90 100644 --- a/tests/image/affine_transformations_test.py +++ b/tests/image/affine_transformations_test.py @@ -199,9 +199,3 @@ def test_random_brightness_scale_outside_range_negative(): assert np.array_equal(img, must_be_neg_1024) must_be_0 = affine_transformations.random_brightness(img, [1, 1], True) assert np.array_equal(zeros, must_be_0) - - -def test_apply_affine_transform_error(monkeypatch): - monkeypatch.setattr(affine_transformations, 'scipy', None) - with pytest.raises(ImportError): - affine_transformations.apply_affine_transform(0)