Skip to content

Commit

Permalink
ENH: Collapse linear and nonlinear transforms chains
Browse files Browse the repository at this point in the history
Very undertested, but currently there is a test that uses a "collapsed"
transform on an ITK's .h5 file with one affine and one nonlinear.

BSpline transforms not currently supported.

Resolves #89.
  • Loading branch information
oesteban committed Jul 20, 2022
1 parent ef5a28f commit d25308d
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 21 deletions.
16 changes: 7 additions & 9 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,17 @@ def __matmul__(self, b):
True
>>> xfm1 = Affine([[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
>>> xfm1 @ np.eye(4) == xfm1
>>> xfm1 @ Affine() == xfm1
True
"""
if not isinstance(b, self.__class__):
_b = self.__class__(b)
else:
_b = b
if isinstance(b, self.__class__):
return self.__class__(
b.matrix @ self.matrix,
reference=b.reference,
)

retval = self.__class__(self.matrix.dot(_b.matrix))
if _b.reference:
retval.reference = _b.reference
return retval
return b @ self

@property
def matrix(self):
Expand Down
15 changes: 7 additions & 8 deletions nitransforms/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,17 @@ def map(self, x, inverse=False):

return x

def asaffine(self, indices=None):
def collapse(self):
"""
Combine a succession of linear transforms into one.
Combine a succession of transforms into one.
Example
------
>>> chain = TransformChain(transforms=[
... Affine.from_matvec(vec=(2, -10, 3)),
... Affine.from_matvec(vec=(-2, 10, -3)),
... ])
>>> chain.asaffine()
>>> chain.collapse()
array([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
Expand All @@ -160,15 +160,15 @@ def asaffine(self, indices=None):
... Affine.from_matvec(vec=(1, 2, 3)),
... Affine.from_matvec(mat=[[0, 1, 0], [0, 0, 1], [1, 0, 0]]),
... ])
>>> chain.asaffine()
>>> chain.collapse()
array([[0., 1., 0., 2.],
[0., 0., 1., 3.],
[1., 0., 0., 1.],
[0., 0., 0., 1.]])
>>> np.allclose(
... chain.map((4, -2, 1)),
... chain.asaffine().map((4, -2, 1)),
... chain.collapse().map((4, -2, 1)),
... )
True
Expand All @@ -178,9 +178,8 @@ def asaffine(self, indices=None):
The indices of the values to extract.
"""
affines = self.transforms if indices is None else np.take(self.transforms, indices)
retval = affines[0]
for xfm in affines[1:]:
retval = self.transforms[-1]
for xfm in reversed(self.transforms[:-1]):
retval = xfm @ retval
return retval

Expand Down
6 changes: 3 additions & 3 deletions nitransforms/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,10 @@ def test_mulmat_operator(testdata_path):
mat2 = from_matvec(np.eye(3), (4, 2, -1))
aff = nitl.Affine(mat1, reference=ref)

composed = aff @ mat2
composed = aff @ nitl.Affine(mat2)
assert composed.reference is None
assert composed == nitl.Affine(mat1.dot(mat2))
assert composed == nitl.Affine(mat2 @ mat1)

composed = nitl.Affine(mat2) @ aff
assert composed.reference == aff.reference
assert composed == nitl.Affine(mat2.dot(mat1), reference=ref)
assert composed == nitl.Affine(mat1 @ mat2, reference=ref)
8 changes: 7 additions & 1 deletion nitransforms/tests/test_manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def test_itk_h5(tmp_path, testdata_path):
# A certain tolerance is necessary because of resampling at borders
assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL

col_moved = xfm.collapse().apply(img_fname, order=0)
col_moved.to_filename("nt_collapse_resampled.nii.gz")
diff = sw_moved.get_fdata() - col_moved.get_fdata()
# A certain tolerance is necessary because of resampling at borders
assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL


@pytest.mark.parametrize("ext0", ["lta", "tfm"])
@pytest.mark.parametrize("ext1", ["lta", "tfm"])
Expand All @@ -81,7 +87,7 @@ def test_collapse_affines(tmp_path, data_path, ext0, ext1, ext2):
]
)
assert np.allclose(
chain.asaffine().matrix,
chain.collapse().matrix,
Affine.from_filename(
data_path / "regressions" / f"from-fsnative_to-bold_mode-image.{ext2}",
fmt=f"{FMT[ext2]}",
Expand Down

0 comments on commit d25308d

Please sign in to comment.