Skip to content

Commit

Permalink
Merge pull request #187 from nipy/fix/memory-issues-173
Browse files Browse the repository at this point in the history
FIX: Postpone coordinate mapping on linear array transforms
  • Loading branch information
oesteban authored Nov 17, 2023
2 parents 6e70c02 + d148e85 commit 28737f4
Showing 1 changed file with 51 additions and 38 deletions.
89 changes: 51 additions & 38 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from nibabel.loadsave import load as _nbload
from nibabel.affines import from_matvec
from nibabel.arrayproxy import get_obj_dtype

from nitransforms.base import (
ImageGrid,
Expand Down Expand Up @@ -216,14 +217,13 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None):
is_array = cls != Affine
errors = []
for potential_fmt in fmtlist:
if (potential_fmt == "itk" and Path(filename).suffix == ".mat"):
if potential_fmt == "itk" and Path(filename).suffix == ".mat":
is_array = False
cls = Affine

try:
struct = get_linear_factory(
potential_fmt,
is_array=is_array
potential_fmt, is_array=is_array
).from_filename(filename)
except (TransformFileError, FileNotFoundError) as err:
errors.append((potential_fmt, err))
Expand Down Expand Up @@ -316,6 +316,11 @@ def __init__(self, transforms, reference=None):
)
self._inverse = np.linalg.inv(self._matrix)

def __iter__(self):
"""Enable iterating over the series of transforms."""
for _m in self.matrix:
yield Affine(_m, reference=self._reference)

def __getitem__(self, i):
"""Enable indexed access to the series of matrices."""
return Affine(self.matrix[i, ...], reference=self._reference)
Expand Down Expand Up @@ -436,6 +441,7 @@ def apply(
The data imaged after resampling to reference space.
"""

if reference is not None and isinstance(reference, (str, Path)):
reference = _nbload(str(reference))

Expand All @@ -446,40 +452,49 @@ def apply(
if isinstance(spatialimage, (str, Path)):
spatialimage = _nbload(str(spatialimage))

data = np.squeeze(np.asanyarray(spatialimage.dataobj))
output_dtype = output_dtype or data.dtype
# Avoid opening the data array just yet
input_dtype = get_obj_dtype(spatialimage.dataobj)
output_dtype = output_dtype or input_dtype

ycoords = self.map(_ref.ndcoords.T)
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(np.vstack(ycoords), dim=_ref.ndim)
)
# Prepare physical coordinates of input (grid, points)
xcoords = _ref.ndcoords.astype("f4").T

if data.ndim == 4:
if len(self) != data.shape[-1]:
raise ValueError(
"Attempting to apply %d transforms on a file with "
"%d timepoints" % (len(self), data.shape[-1])
)
targets = targets.reshape((len(self), -1, targets.shape[-1]))
resampled = np.stack(
[
ndi.map_coordinates(
data[..., t],
targets[t, ..., : _ref.ndim].T,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)
for t in range(data.shape[-1])
],
axis=0,
# Invert target's (moving) affine once
ras2vox = ~Affine(spatialimage.affine)

if spatialimage.ndim == 4 and (len(self) != spatialimage.shape[-1]):
raise ValueError(
"Attempting to apply %d transforms on a file with "
"%d timepoints" % (len(self), spatialimage.shape[-1])
)
elif data.ndim in (2, 3):
resampled = ndi.map_coordinates(
data,
targets[..., : _ref.ndim].T,

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(xcoords.shape[0], len(self)), dtype=output_dtype, order="F"
)

dataobj = (
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
if spatialimage.ndim in (2, 3)
else None
)

for t, xfm_t in enumerate(self):
# Map the input coordinates on to timepoint t of the target (moving)
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]

# Calculate corresponding voxel coordinates
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]

# Interpolate
resampled[..., t] = ndi.map_coordinates(
(
dataobj
if dataobj is not None
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
),
yvoxels.T,
output=output_dtype,
order=order,
mode=mode,
Expand All @@ -488,10 +503,8 @@ def apply(
)

if isinstance(_ref, ImageGrid): # If reference is grid, reshape
newdata = resampled.reshape((len(self), *_ref.shape))
moved = spatialimage.__class__(
np.moveaxis(newdata, 0, -1), _ref.affine, spatialimage.header
)
newdata = resampled.reshape(_ref.shape + (len(self),))
moved = spatialimage.__class__(newdata, _ref.affine, spatialimage.header)
moved.header.set_data_dtype(output_dtype)
return moved

Expand Down

0 comments on commit 28737f4

Please sign in to comment.