Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Postpone coordinate mapping on linear array transforms #187

Merged
merged 7 commits into from
Nov 17, 2023
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 51 additions & 38 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from nibabel.loadsave import load as _nbload
from nibabel.affines import from_matvec
from nibabel.arrayproxy import get_obj_dtype

from nitransforms.base import (
ImageGrid,
Expand Down Expand Up @@ -216,14 +217,13 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None):
is_array = cls != Affine
errors = []
for potential_fmt in fmtlist:
if (potential_fmt == "itk" and Path(filename).suffix == ".mat"):
if potential_fmt == "itk" and Path(filename).suffix == ".mat":
is_array = False
cls = Affine

try:
struct = get_linear_factory(
potential_fmt,
is_array=is_array
potential_fmt, is_array=is_array
).from_filename(filename)
except (TransformFileError, FileNotFoundError) as err:
errors.append((potential_fmt, err))
Expand Down Expand Up @@ -316,6 +316,11 @@ def __init__(self, transforms, reference=None):
)
self._inverse = np.linalg.inv(self._matrix)

def __iter__(self):
"""Enable iterating over the series of transforms."""
for _m in self.matrix:
yield Affine(_m, reference=self._reference)

def __getitem__(self, i):
"""Enable indexed access to the series of matrices."""
return Affine(self.matrix[i, ...], reference=self._reference)
Expand Down Expand Up @@ -436,6 +441,7 @@ def apply(
The data imaged after resampling to reference space.

"""

if reference is not None and isinstance(reference, (str, Path)):
reference = _nbload(str(reference))

Expand All @@ -446,40 +452,49 @@ def apply(
if isinstance(spatialimage, (str, Path)):
spatialimage = _nbload(str(spatialimage))

data = np.squeeze(np.asanyarray(spatialimage.dataobj))
output_dtype = output_dtype or data.dtype
# Avoid opening the data array just yet
input_dtype = get_obj_dtype(spatialimage.dataobj)
output_dtype = output_dtype or input_dtype

ycoords = self.map(_ref.ndcoords.T)
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(np.vstack(ycoords), dim=_ref.ndim)
)
# Prepare physical coordinates of input (grid, points)
xcoords = _ref.ndcoords.astype("f4").T

if data.ndim == 4:
if len(self) != data.shape[-1]:
raise ValueError(
"Attempting to apply %d transforms on a file with "
"%d timepoints" % (len(self), data.shape[-1])
)
targets = targets.reshape((len(self), -1, targets.shape[-1]))
resampled = np.stack(
[
ndi.map_coordinates(
data[..., t],
targets[t, ..., : _ref.ndim].T,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)
for t in range(data.shape[-1])
],
axis=0,
# Invert target's (moving) affine once
ras2vox = ~Affine(spatialimage.affine)

if spatialimage.ndim == 4 and (len(self) != spatialimage.shape[-1]):
raise ValueError(
"Attempting to apply %d transforms on a file with "
"%d timepoints" % (len(self), spatialimage.shape[-1])
)
elif data.ndim in (2, 3):
resampled = ndi.map_coordinates(
data,
targets[..., : _ref.ndim].T,

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(xcoords.shape[0], len(self)), dtype=output_dtype, order="F"
)

dataobj = (
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
if spatialimage.ndim in (2, 3)
else None
)

for t, xfm_t in enumerate(self):
# Map the input coordinates on to timepoint t of the target (moving)
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]

# Calculate corresponding voxel coordinates
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]

# Interpolate
resampled[..., t] = ndi.map_coordinates(
(
dataobj
if dataobj is not None
else np.asanyarray(spatialimage.dataobj[..., t], dtype=input_dtype)
oesteban marked this conversation as resolved.
Show resolved Hide resolved
),
yvoxels.T,
output=output_dtype,
order=order,
mode=mode,
Expand All @@ -488,10 +503,8 @@ def apply(
)

if isinstance(_ref, ImageGrid): # If reference is grid, reshape
newdata = resampled.reshape((len(self), *_ref.shape))
moved = spatialimage.__class__(
np.moveaxis(newdata, 0, -1), _ref.affine, spatialimage.header
)
newdata = resampled.reshape(_ref.shape + (len(self),))
moved = spatialimage.__class__(newdata, _ref.affine, spatialimage.header)
moved.header.set_data_dtype(output_dtype)
return moved

Expand Down
Loading