Skip to content

Commit

Permalink
sty: add typing annotations and run black
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed May 2, 2023
1 parent 88405a1 commit f437f69
Showing 1 changed file with 59 additions and 36 deletions.
95 changes: 59 additions & 36 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#
"""The :math:`B_0` unwarping transform formalism."""
from pathlib import Path
from typing import Sequence, Union

import attr
import numpy as np
Expand Down Expand Up @@ -50,7 +51,12 @@ class B0FieldTransform:
target image we want to correct).
"""

def fit(self, target_reference, affine=None, approx=True):
def fit(
self,
target_reference: nb.spatialimages.SpatialImage,
affine: np.array = None,
approx: bool = True,
) -> bool:
r"""
Generate the interpolation matrix (and the VSM with it).
Expand All @@ -59,7 +65,7 @@ def fit(self, target_reference, affine=None, approx=True):
Parameters
----------
target_reference : `spatialimage`
target_reference : :obj:`~nibabel.spatialimages.SpatialImage`
The image object containing a reference grid (same as that of the data
to be resampled). If a 4D dataset is provided, then the fourth dimension
will be dropped.
Expand All @@ -83,7 +89,9 @@ def fit(self, target_reference, affine=None, approx=True):
if isinstance(target_reference, (str, bytes, Path)):
target_reference = nb.load(target_reference)

approx = approx if affine is not None else False # Approximate iff affine is defined
approx = (
approx if affine is not None else False
) # Approximate iff affine is defined
affine = affine if affine is not None else np.eye(4)
target_affine = target_reference.affine.copy()

Expand Down Expand Up @@ -141,22 +149,22 @@ def fit(self, target_reference, affine=None, approx=True):
hdr.set_intent("estimate", name="fieldmap Hz")
hdr.set_data_dtype("float32")
hdr["cal_max"] = max((abs(fmap.min()), fmap.max()))
hdr["cal_min"] = - hdr["cal_max"]
hdr["cal_min"] = -hdr["cal_max"]
self.mapped = nb.Nifti1Image(fmap, target_affine, hdr)
return True

def apply(
self,
moving,
pe_dir,
ro_time,
xfms=None,
order=3,
mode="constant",
cval=0.0,
prefilter=True,
output_dtype=None,
num_threads=None,
moving: nb.spatialimages.SpatialImage,
pe_dir: str,
ro_time: float,
xfms: Sequence[np.array] = None,
order: int = 3,
mode: str = "constant",
cval: float = 0.0,
prefilter: bool = True,
output_dtype: Union[str, np.dtype] = None,
num_threads: int = None,
):
"""
Apply a transformation to an image, resampling on the reference spatial object.
Expand All @@ -165,32 +173,41 @@ def apply(
Parameters
----------
moving : `spatialimage`
moving : :obj:`~nibabel.spatialimages.SpatialImage`
The image object containing the data to be resampled in reference
space
xfms : `None` or :obj:`list`
pe_dir : :obj:`str`
A valid ``PhaseEncodingDirection`` metadata value.
ro_time : :obj:`float`
The total readout time in seconds.
xfms : :obj:`None` or :obj:`list`
A list of rigid-body transformations previously estimated that will
realign the dataset (that is, compensate for head motion) after resampling.
order : int, optional
order : :obj:`int`, optional
The order of the spline interpolation, default is 3.
The order has to be in the range 0-5.
mode : {'constant', 'reflect', 'nearest', 'mirror', 'wrap'}, optional
Determines how the input image is extended when the resamplings overflows
a border. Default is 'constant'.
cval : float, optional
Constant value for ``mode='constant'``. Default is 0.0.
prefilter: bool, optional
prefilter : :obj:`bool`, optional
Determines if the image's data array is prefiltered with
a spline filter before interpolation. The default is ``True``,
which will create a temporary *float64* array of filtered values
if *order > 1*. If setting this to ``False``, the output will be
slightly blurred if *order > 1*, unless the input is prefiltered,
i.e. it is the result of calling the spline filter on the original
input.
output_dtype : :obj:`str` or :obj:`~numpy.dtype`
Override the output data type, instead of propagating it from the
moving image.
num_threads : :obj:`int`
Number of CPUs resampling can be parallelized on.
Returns
-------
resampled : `spatialimage` or ndarray
resampled : :obj:`~nibabel.spatialimages.SpatialImage`
The data imaged after resampling to reference space.
"""
Expand Down Expand Up @@ -249,9 +266,7 @@ def apply(
prefilter=prefilter,
).reshape(moving.shape)

moved = moving.__class__(
resampled, moving.affine, moving.header
)
moved = moving.__class__(resampled, moving.affine, moving.header)
moved.header.set_data_dtype(output_dtype)
return reorient_image(moved, axcodes)

Expand Down Expand Up @@ -368,11 +383,13 @@ def disp_to_fmap(xyz_nii, ro_time, pe_dir, itk_format=True):
fmap_nii = nb.Nifti1Image(vsm / scale_factor, xyz_nii.affine)
fmap_nii.header.set_intent("estimate", name="Delta_B0 [Hz]")
fmap_nii.header.set_xyzt_units("mm")
fmap_nii.header["cal_max"] = max((
abs(np.asanyarray(fmap_nii.dataobj).min()),
np.asanyarray(fmap_nii.dataobj).max(),
))
fmap_nii.header["cal_min"] = - fmap_nii.header["cal_max"]
fmap_nii.header["cal_max"] = max(
(
abs(np.asanyarray(fmap_nii.dataobj).min()),
np.asanyarray(fmap_nii.dataobj).max(),
)
)
fmap_nii.header["cal_min"] = -fmap_nii.header["cal_max"]
return fmap_nii


Expand Down Expand Up @@ -426,10 +443,14 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
knots_shape = ctrl_nii.shape[:3]

# Ensure the cross-product of affines is near zero (i.e., both coordinate systems are aligned)
if not np.allclose(np.linalg.norm(
np.cross(ctrl_nii.affine[:-1, :-1].T, target_nii.affine[:-1, :-1].T),
axis=1,
), 0, atol=1e-3):
if not np.allclose(
np.linalg.norm(
np.cross(ctrl_nii.affine[:-1, :-1].T, target_nii.affine[:-1, :-1].T),
axis=1,
),
0,
atol=1e-3,
):
warn("Image's and B-Spline's grids are not aligned.")

target_to_grid = np.linalg.inv(ctrl_nii.affine) @ target_nii.affine
Expand Down Expand Up @@ -481,9 +502,11 @@ def _move_coeff(in_coeff, fmap_ref, transform, fmap_target=None):
hdr.set_sform(newaff, code=1)

# Make it easy on viz software to render proper range
hdr["cal_max"] = max((
abs(np.asanyarray(coeff.dataobj).min()),
np.asanyarray(coeff.dataobj).max(),
))
hdr["cal_min"] = - hdr["cal_max"]
hdr["cal_max"] = max(
(
abs(np.asanyarray(coeff.dataobj).min()),
np.asanyarray(coeff.dataobj).max(),
)
)
hdr["cal_min"] = -hdr["cal_max"]
return coeff.__class__(coeff.dataobj, newaff, hdr)

0 comments on commit f437f69

Please sign in to comment.