Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Removed SciPy dependency. #334

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 86 additions & 35 deletions keras_preprocessing/image/affine_transformations.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
"""Utilities for performing affine transformations on image data.
"""
import numpy as np
import tensorflow as tf

from .utils import array_to_img, img_to_array

try:
import scipy
# scipy.ndimage cannot be accessed until explicitly imported
from scipy import ndimage
except ImportError:
scipy = None

try:
from PIL import Image as pil_image
from PIL import ImageEnhance
Expand Down Expand Up @@ -41,8 +35,8 @@ def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0,
(one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
cval: Value used for points outside the boundaries
of the input if `mode='constant'`.
interpolation_order: int, order of spline interpolation.
see `ndimage.interpolation.affine_transform`
interpolation_order: int (one of `{0, 1}`) order of interpolation.
see `tfa.image.transform`

# Returns
Rotated Numpy image tensor.
Expand Down Expand Up @@ -75,8 +69,8 @@ def random_shift(x, wrg, hrg, row_axis=1, col_axis=2, channel_axis=0,
(one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
cval: Value used for points outside the boundaries
of the input if `mode='constant'`.
interpolation_order: int, order of spline interpolation.
see `ndimage.interpolation.affine_transform`
interpolation_order: int (one of `{0, 1}`) order of interpolation.
see `tfa.image.transform`

# Returns
Shifted Numpy image tensor.
Expand Down Expand Up @@ -111,8 +105,8 @@ def random_shear(x, intensity, row_axis=1, col_axis=2, channel_axis=0,
(one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
cval: Value used for points outside the boundaries
of the input if `mode='constant'`.
interpolation_order: int, order of spline interpolation.
see `ndimage.interpolation.affine_transform`
interpolation_order: int (one of `{0, 1}`) order of interpolation.
see `tfa.image.transform`

# Returns
Sheared Numpy image tensor.
Expand Down Expand Up @@ -144,8 +138,8 @@ def random_zoom(x, zoom_range, row_axis=1, col_axis=2, channel_axis=0,
(one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
cval: Value used for points outside the boundaries
of the input if `mode='constant'`.
interpolation_order: int, order of spline interpolation.
see `ndimage.interpolation.affine_transform`
interpolation_order: int (one of `{0, 1}`) order of interpolation.
see `tfa.image.transform`

# Returns
Zoomed Numpy image tensor.
Expand Down Expand Up @@ -299,14 +293,23 @@ def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
(one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
cval: Value used for points outside the boundaries
of the input if `mode='constant'`.
order: int, order of interpolation
order: int (one of `{0, 1}`) order of interpolation.
see `tfa.image.transform`

# Raises
ValueError if `raw_axis`, `col_axis` and `channel_axis` are misconfigured.

# Returns
The transformed version of the input.
"""
if scipy is None:
raise ImportError('Image transformations require SciPy. '
'Install SciPy.')

# Convert interpolation order into textual values used by tfa.image.transform.
if order == 0:
interpolation = "NEAREST"
elif order == 1:
interpolation = "BILINEAR"
else:
raise ValueError("Interpolation order can only be 0 or 1")

# Input sanity checks:
# 1. x must 2D image with one or more channels (i.e., a 3D tensor)
Expand Down Expand Up @@ -367,30 +370,78 @@ def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
h, w = x.shape[row_axis], x.shape[col_axis]
transform_matrix = transform_matrix_offset_center(
transform_matrix, h, w)
x = np.rollaxis(x, channel_axis, 0)
x = np.moveaxis(x, channel_axis, -1)

# Matrix construction assumes that coordinates are x, y (in that order).
# However, regular numpy arrays use y,x (aka i,j) indexing.
# Possible solution is:
# However, users may reverse that order by setting `col_axis=0`,
# `row_axis=1`. In this case, one possible solution is:
# 1. Swap the x and y axes.
# 2. Apply transform.
# 3. Swap the x and y axes again to restore image-like data ordering.
# Mathematically, it is equivalent to the following transformation:
# M' = PMP, where P is the permutation matrix, M is the original
# transformation matrix.
if col_axis > row_axis:
if col_axis < row_axis:
transform_matrix[:, [0, 1]] = transform_matrix[:, [1, 0]]
transform_matrix[[0, 1]] = transform_matrix[[1, 0]]
final_affine_matrix = transform_matrix[:2, :2]
final_offset = transform_matrix[:2, 2]

channel_images = [ndimage.interpolation.affine_transform(
x_channel,
final_affine_matrix,
final_offset,
order=order,
mode=fill_mode,
cval=cval) for x_channel in x]
x = np.stack(channel_images, axis=0)
x = np.rollaxis(x, 0, channel_axis + 1)
w, h = h, w

transform = matrix_to_transform(transform_matrix)
image = to_4D_tensor(x)

image = tf.raw_ops.ImageProjectiveTransformV3(
images=image,
transforms=transform,
output_shape=(h, w),
interpolation=interpolation,
fill_mode=fill_mode.upper(),
fill_value=cval,
)
x = from_4D_image(image, x.ndim)
x = np.moveaxis(x, -1, channel_axis)
return x

def matrix_to_transform(matrix):
transform = matrix.ravel()[0:8]
transform = tf.convert_to_tensor(transform, dtype=tf.dtypes.float32)
return transform[None]


def to_4D_tensor(image):
"""Convert 2/3/4D image to 4D image.

# Arguments
image: 2/3/4D `Tensor`.

# Returns
4D `Tensor` with the same type.
"""
image = tf.convert_to_tensor(image)
ndims = image.get_shape().ndims

if ndims == 2:
return image[None, :, :, None]
elif ndims == 3:
return image[None, :, :, :]
else:
return image


def from_4D_image(image, ndims):
"""Convert back to an image with `ndims` rank.

# Arguments
image: 4D `Tensor`.
ndims: The original rank of the image.

# Returns
`ndims`-D `numpy.array` with the same type.
"""

if ndims == 2:
res = tf.squeeze(image, [0, 3])
elif ndims == 3:
res = tf.squeeze(image, [0])
else:
res = image
return res.numpy()
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
'''

setup(name='Keras_Preprocessing',
version='1.1.2',
version='1.1.3',
description='Easy data preprocessing and data augmentation '
'for deep learning models',
long_description=long_description,
Expand All @@ -33,13 +33,13 @@
extras_require={
'tests': ['pandas',
'Pillow',
'tensorflow', # CPU version
'tensorflow>=2.4.0', # CPU version
'keras',
'pytest',
'pytest-xdist',
'pytest-cov'],
'pep8': ['flake8'],
'image': ['scipy>=0.14',
'image': ['tensorflow>=2.4.0'
'Pillow>=5.2.0'],
},
classifiers=[
Expand Down
6 changes: 0 additions & 6 deletions tests/image/affine_transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,3 @@ def test_random_brightness_scale_outside_range_negative():
assert np.array_equal(img, must_be_neg_1024)
must_be_0 = affine_transformations.random_brightness(img, [1, 1], True)
assert np.array_equal(zeros, must_be_0)


def test_apply_affine_transform_error(monkeypatch):
monkeypatch.setattr(affine_transformations, 'scipy', None)
with pytest.raises(ImportError):
affine_transformations.apply_affine_transform(0)