Skip to content

Commit

Permalink
enh: convert Coefficients2Warp -> ApplyCoeffsField
Browse files Browse the repository at this point in the history
Considering that we have all the ingredients, it is natural to also
resample the target.
If the target is a list of files, then all of them are unwarped.
  • Loading branch information
oesteban committed May 7, 2021
1 parent 08b9df2 commit f2f183b
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 68 deletions.
83 changes: 53 additions & 30 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Filtering of :math:`B_0` field mappings with B-Splines."""
from pathlib import Path
from functools import partial
import numpy as np
import nibabel as nb
from nibabel.affines import apply_affine
Expand Down Expand Up @@ -188,58 +189,80 @@ def _run_interface(self, runtime):
return runtime


class _Coefficients2WarpInputSpec(BaseInterfaceInputSpec):
in_target = File(exist=True, mandatory=True, desc="input EPI data to be corrected")
class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
in_target = InputMultiObject(
File(exist=True, mandatory=True, desc="input EPI data to be corrected")
)
in_coeff = InputMultiObject(
File(exists=True),
mandatory=True,
desc="input coefficients, after alignment to the EPI data",
)
ro_time = traits.Float(mandatory=True, desc="EPI readout time (s).")
pe_dir = traits.Enum(
"i",
"i-",
"j",
"j-",
"k",
"k-",
mandatory=True,
desc="the phase-encoding direction corresponding to in_target",
ro_time = InputMultiObject(
traits.Float(mandatory=True, desc="EPI readout time (s).")
)
low_mem = traits.Bool(
False, usedefault=True, desc="perform on low-mem fingerprint regime"
pe_dir = InputMultiObject(
traits.Enum(
"i",
"i-",
"j",
"j-",
"k",
"k-",
mandatory=True,
desc="the phase-encoding direction corresponding to in_target",
)
)


class _Coefficients2WarpOutputSpec(TraitedSpec):
class _ApplyCoeffsFieldOutputSpec(TraitedSpec):
out_corrected = OutputMultiObject(File(exists=True))
out_field = File(exists=True)
out_warp = File(exists=True)
out_warp = OutputMultiObject(File(exists=True))


class Coefficients2Warp(SimpleInterface):
class ApplyCoeffsField(SimpleInterface):
"""Convert a set of B-Spline coefficients to a full displacements map."""

input_spec = _Coefficients2WarpInputSpec
output_spec = _Coefficients2WarpOutputSpec
input_spec = _ApplyCoeffsFieldInputSpec
output_spec = _ApplyCoeffsFieldOutputSpec

def _run_interface(self, runtime):
# Prepare output names
self._results["out_field"] = fname_presuffix(
self.inputs.in_target, suffix="_field", newpath=runtime.cwd
)
self._results["out_warp"] = self._results["out_field"].replace(
"_field.nii", "_xfm.nii"
)
filename = partial(fname_presuffix, newpath=runtime.cwd)

self._results["out_field"] = filename(self.inputs.in_coeff[0], suffix="_field")
self._results["out_warp"] = []
self._results["out_corrected"] = []

xfm = B0FieldTransform(
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff]
)
xfm.fit(self.inputs.in_target)
xfm.fit(self.inputs.in_target[0])
xfm.shifts.to_filename(self._results["out_field"])
xfm.to_displacements(
ro_time=self.inputs.ro_time,
pe_dir=self.inputs.pe_dir,
).to_filename(self._results["out_warp"])

n_inputs = len(self.inputs.in_target)
ro_time = self.inputs.ro_time
if len(ro_time) == 1:
ro_time = [ro_time[0]] * n_inputs

pe_dir = self.inputs.pe_dir
if len(pe_dir) == 1:
pe_dir = [pe_dir[0]] * n_inputs

for fname, pe, ro in zip(self.inputs.in_target, pe_dir, ro_time):
xfm.fit(fname)

# Generate warpfield
warp_name = filename(fname, suffix="_xfm")
xfm.to_displacements(ro_time=ro, pe_dir=pe).to_filename(warp_name)
self._results["out_warp"].append(warp_name)

# Generate resampled
out_name = filename(fname, suffix="_unwarped")
xfm.apply(nb.load(fname), ro_time=ro, pe_dir=pe).to_filename(out_name)
self._results["out_corrected"].append(out_name)

return runtime


Expand Down
8 changes: 4 additions & 4 deletions sdcflows/interfaces/tests/test_bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..bspline import (
bspline_grid,
Coefficients2Warp,
ApplyCoeffsField,
BSplineApprox,
TOPUPCoeffReorient,
_fix_topup_fieldcoeff,
Expand Down Expand Up @@ -40,7 +40,7 @@ def test_bsplines(tmp_path, testnum):

os.chdir(tmp_path)
# Check that we can interpolate the coefficients on a target
test1 = Coefficients2Warp(
test1 = ApplyCoeffsField(
in_target=str(tmp_path / "target.nii.gz"),
in_coeff=str(tmp_path / "coeffs.nii.gz"),
pe_dir="j-",
Expand Down Expand Up @@ -91,8 +91,8 @@ def test_topup_coeffs(tmpdir, testdata_dir):
def test_topup_coeffs_interpolation(tmpdir, testdata_dir):
"""Check that our interpolation is not far away from TOPUP's."""
tmpdir.chdir()
result = Coefficients2Warp(
in_target=str(testdata_dir / "epi.nii.gz"),
result = ApplyCoeffsField(
in_target=[str(testdata_dir / "epi.nii.gz")] * 2,
in_coeff=str(testdata_dir / "topup-coeff-fixed.nii.gz"),
pe_dir="j-",
ro_time=1.0,
Expand Down
2 changes: 1 addition & 1 deletion sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def apply(
voxcoords[pe_axis, ...] += vsm * ro_time

# Prepare data
data = np.asanyarray(spatialimage.dataobj)
data = np.squeeze(np.asanyarray(spatialimage.dataobj))
output_dtype = output_dtype or data.dtype

# Resample
Expand Down
17 changes: 6 additions & 11 deletions sdcflows/workflows/apply/correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"):
a fast mask calculated from the corrected EPI reference.
"""
from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms
from ...interfaces.epi import GetReadoutTime
from ...interfaces.bspline import Coefficients2Warp
from ...interfaces.bspline import ApplyCoeffsField
from ..ancillary import init_brainextraction_wf

workflow = Workflow(name=name)
Expand All @@ -65,10 +64,8 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"):

rotime = pe.Node(GetReadoutTime(), name="rotime")
rotime.interface._always_run = debug
resample = pe.Node(Coefficients2Warp(low_mem=debug), name="resample")
unwarp = pe.Node(
ApplyTransforms(dimension=3, interpolation="BSpline"), name="unwarp"
)
resample = pe.Node(ApplyCoeffsField(), name="resample")

brainextraction_wf = init_brainextraction_wf()

# fmt:off
Expand All @@ -79,11 +76,9 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"):
("fmap_coeff", "in_coeff")]),
(rotime, resample, [("readout_time", "ro_time"),
("pe_direction", "pe_dir")]),
(inputnode, unwarp, [("distorted", "reference_image"),
("distorted", "input_image")]),
(resample, unwarp, [("out_warp", "transforms")]),
(resample, outputnode, [("out_field", "fieldmap")]),
(unwarp, brainextraction_wf, [("output_image", "inputnode.in_file")]),
(resample, outputnode, [("out_field", "fieldmap"),
("out_warp", "transforms")]),
(resample, brainextraction_wf, [("out_corrected", "inputnode.in_file")]),
(brainextraction_wf, outputnode, [
("outputnode.out_file", "corrected"),
("outputnode.out_mask", "corrected_mask"),
Expand Down
36 changes: 14 additions & 22 deletions sdcflows/workflows/fit/pepolar.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,36 +165,28 @@ def _getpe(in_meta):
# fmt: on
return workflow

from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms
from ...utils.misc import front as _front
from ...interfaces.bspline import Coefficients2Warp
from ...interfaces.bspline import ApplyCoeffsField

coeff2xfm = pe.Node(
Coefficients2Warp(ro_time=1.0, low_mem=sloppy),
name="coeff2xfm",
)
unwarp = pe.Node(
ApplyTransforms(dimension=3, interpolation="BSpline"),
ApplyCoeffsField(ro_time=1.0),
name="unwarp",
)

def _getpe(in_meta):
if isinstance(in_meta, list):
in_meta = in_meta[0]
return in_meta["PhaseEncodingDirection"]
def _getpe(inlist):
if isinstance(inlist, dict):
inlist = [inlist]

return [m["PhaseEncodingDirection"] for m in inlist]

# fmt:off
workflow.connect([
(fix_coeff, coeff2xfm, [("out_coeff", "in_coeff")]),
(flatten, coeff2xfm, [(("out_data", _front), "in_target"),
(("out_meta", _getpe), "pe_dir")]),
(readout_time, coeff2xfm, [(("readout_time", _front), "ro_time")]),
(coeff2xfm, unwarp, [("out_warp", "transforms")]),
(flatten, unwarp, [(("out_data", _front), "reference_image"),
(("out_data", _front), "input_image")]),
(coeff2xfm, outputnode, [("out_warp", "out_warps"),
("out_field", "fmap")]),
(unwarp, merge_corrected, [("output_image", "in_files")]),
(fix_coeff, unwarp, [("out_coeff", "in_coeff")]),
(flatten, unwarp, [("out_data", "in_target"),
(("out_meta", _getpe), "pe_dir")]),
(readout_time, unwarp, [("readout_time", "ro_time")]),
(unwarp, outputnode, [("out_warp", "out_warps"),
("out_field", "fmap")]),
(unwarp, merge_corrected, [("out_corrected", "in_files")]),
])
# fmt:on

Expand Down

0 comments on commit f2f183b

Please sign in to comment.