Skip to content

Commit

Permalink
enh: extend the nonlinear transforms API
Browse files Browse the repository at this point in the history
This PR lays the ground for future work on #56, and #89, by defining the
matrix multiplication operator on field-based transforms.
  • Loading branch information
oesteban committed Jul 19, 2022
1 parent e5a6b41 commit 33694fb
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 10 deletions.
4 changes: 4 additions & 0 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ def apply(

return resampled

def __matmul__(self, b):
"""Compose with a transform on the right."""
return b

def map(self, x, inverse=False):
r"""
Apply :math:`y = f(x)`.
Expand Down
90 changes: 80 additions & 10 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
ImageGrid,
SpatialReference,
_as_homogeneous,
EQUALITY_TOL,
)


class DisplacementsFieldTransform(TransformBase):
"""Represents a dense field of displacements (one vector per voxel)."""
class DeformationFieldTransform(TransformBase):
"""Represents a dense field of deformed locations (corresponding to each voxel)."""

__slots__ = ["_field"]

Expand All @@ -34,8 +35,8 @@ def __init__(self, field, reference=None):
Example
-------
>>> DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
<DisplacementFieldTransform[3D] (57, 67, 56)>
>>> DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
<DeformationFieldTransform[3D] (57, 67, 56)>
"""
super().__init__()
Expand All @@ -59,13 +60,13 @@ def __init__(self, field, reference=None):
ndim = self._field.ndim - 1
if self._field.shape[-1] != ndim:
raise TransformError(
"The number of components of the displacements (%d) does not "
"The number of components of the displacements (%d) does not match "
"the number of dimensions (%d)" % (self._field.shape[-1], ndim)
)

def __repr__(self):
"""Beautify the python representation."""
return f"<DisplacementFieldTransform[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"

def map(self, x, inverse=False):
r"""
Expand All @@ -92,12 +93,12 @@ def map(self, x, inverse=False):
Examples
--------
>>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> xfm.map([-6.5, -36., -19.5]).tolist()
[[-6.5, -36.475167989730835, -19.5]]
[[0.0, -0.47516798973083496, 0.0]]
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
[[-6.5, -36.475167989730835, -19.5], [-1.0, -42.038356602191925, -11.25]]
[[0.0, -0.47516798973083496, 0.0], [0.0, -0.538356602191925, 0.0]]
"""

Expand All @@ -108,7 +109,76 @@ def map(self, x, inverse=False):
if np.any(np.abs(ijk - indexes) > 0.05):
warnings.warn("Some coordinates are off-grid of the displacements field.")
indexes = tuple(tuple(i) for i in indexes.T)
return x + self._field[indexes]
return self._field[indexes]

def __matmul__(self, b):
"""
Compose with a transform on the right.
Examples
--------
>>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> xfm2 = xfm @ TransformBase()
>>> xfm == xfm2
True
"""
retval = b.map(
self._field.reshape((-1, self._field.shape[-1]))
).reshape(self._field.shape)
return DeformationFieldTransform(retval, reference=self.reference)

def __eq__(self, other):
"""
Overload equals operator.
Examples
--------
>>> xfm1 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> xfm2 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> xfm1 == xfm2
True
"""
_eq = np.allclose(self._field, other._field, rtol=EQUALITY_TOL)
if _eq and self._reference != other._reference:
warnings.warn("Fields are equal, but references do not match.")
return _eq


class DisplacementsFieldTransform(DeformationFieldTransform):
"""
Represents a dense field of displacements (one vector per voxel).
Converting to a field of deformations is straightforward by just adding the corresponding
displacement to the :math:`(x, y, z)` coordinates of each voxel.
Numerically, deformation fields are less susceptible to rounding errors
than displacements fields.
SPM generally prefers deformations for that reason.
"""

__slots__ = ["_displacements"]

def __init__(self, field, reference=None):
"""
Create a transform supported by a field of voxel-wise displacements.
Example
-------
>>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> xfm
<DisplacementsFieldTransform[3D] (57, 67, 56)>
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
[[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
"""
super().__init__(field, reference=reference)
self._displacements = self._field
# Convert from displacements to deformations fields
# (just add the origin to the displacements vector)
self._field += self.reference.ndcoords.T.reshape(self._field.shape)

@classmethod
def from_filename(cls, filename, fmt="X5"):
Expand Down

0 comments on commit 33694fb

Please sign in to comment.