diff --git a/nitransforms/base.py b/nitransforms/base.py index 9a1600a0..87bda079 100644 --- a/nitransforms/base.py +++ b/nitransforms/base.py @@ -303,6 +303,10 @@ def apply( return resampled + def __matmul__(self, b): + """Compose with a transform on the right.""" + return b + def map(self, x, inverse=False): r""" Apply :math:`y = f(x)`. diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 0c9bb371..b2bd5ad8 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -20,11 +20,12 @@ ImageGrid, SpatialReference, _as_homogeneous, + EQUALITY_TOL, ) -class DisplacementsFieldTransform(TransformBase): - """Represents a dense field of displacements (one vector per voxel).""" +class DeformationFieldTransform(TransformBase): + """Represents a dense field of deformed locations (corresponding to each voxel).""" __slots__ = ["_field"] @@ -34,8 +35,8 @@ def __init__(self, field, reference=None): Example ------- - >>> DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz") - + >>> DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz") + """ super().__init__() @@ -59,13 +60,13 @@ def __init__(self, field, reference=None): ndim = self._field.ndim - 1 if self._field.shape[-1] != ndim: raise TransformError( - "The number of components of the displacements (%d) does not " + "The number of components of the displacements (%d) does not match " "the number of dimensions (%d)" % (self._field.shape[-1], ndim) ) def __repr__(self): """Beautify the python representation.""" - return f"" + return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>" def map(self, x, inverse=False): r""" @@ -92,12 +93,12 @@ def map(self, x, inverse=False): Examples -------- - >>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz") + >>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz") >>> xfm.map([-6.5, -36., -19.5]).tolist() - [[-6.5, -36.475167989730835, -19.5]] + [[0.0, -0.47516798973083496, 0.0]] >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist() - [[-6.5, -36.475167989730835, -19.5], [-1.0, -42.038356602191925, -11.25]] + [[0.0, -0.47516798973083496, 0.0], [0.0, -0.538356602191925, 0.0]] """ @@ -108,7 +109,76 @@ def map(self, x, inverse=False): if np.any(np.abs(ijk - indexes) > 0.05): warnings.warn("Some coordinates are off-grid of the displacements field.") indexes = tuple(tuple(i) for i in indexes.T) - return x + self._field[indexes] + return self._field[indexes] + + def __matmul__(self, b): + """ + Compose with a transform on the right. + + Examples + -------- + >>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz") + >>> xfm2 = xfm @ TransformBase() + >>> xfm == xfm2 + True + + """ + retval = b.map( + self._field.reshape((-1, self._field.shape[-1])) + ).reshape(self._field.shape) + return DeformationFieldTransform(retval, reference=self.reference) + + def __eq__(self, other): + """ + Overload equals operator. + + Examples + -------- + >>> xfm1 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz") + >>> xfm2 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz") + >>> xfm1 == xfm2 + True + + """ + _eq = np.allclose(self._field, other._field, rtol=EQUALITY_TOL) + if _eq and self._reference != other._reference: + warnings.warn("Fields are equal, but references do not match.") + return _eq + + +class DisplacementsFieldTransform(DeformationFieldTransform): + """ + Represents a dense field of displacements (one vector per voxel). + + Converting to a field of deformations is straightforward by just adding the corresponding + displacement to the :math:`(x, y, z)` coordinates of each voxel. + Numerically, deformation fields are less susceptible to rounding errors + than displacements fields. + SPM generally prefers deformations for that reason. + + """ + + __slots__ = ["_displacements"] + + def __init__(self, field, reference=None): + """ + Create a transform supported by a field of voxel-wise displacements. + + Example + ------- + >>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz") + >>> xfm + + + >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist() + [[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]] + + """ + super().__init__(field, reference=reference) + self._displacements = self._field + # Convert from displacements to deformations fields + # (just add the origin to the displacements vector) + self._field += self.reference.ndcoords.T.reshape(self._field.shape) @classmethod def from_filename(cls, filename, fmt="X5"):