Skip to content

Commit

Permalink
Merge pull request #200 from nipreps/enh/pepolar-double-check
Browse files Browse the repository at this point in the history
ENH: Double-check conversion from TOPUP to standardized fieldmaps
  • Loading branch information
oesteban authored May 7, 2021
2 parents 169db48 + f2f183b commit 4909f5f
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 53 deletions.
83 changes: 53 additions & 30 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Filtering of :math:`B_0` field mappings with B-Splines."""
from pathlib import Path
from functools import partial
import numpy as np
import nibabel as nb
from nibabel.affines import apply_affine
Expand Down Expand Up @@ -188,58 +189,80 @@ def _run_interface(self, runtime):
return runtime


class _Coefficients2WarpInputSpec(BaseInterfaceInputSpec):
in_target = File(exist=True, mandatory=True, desc="input EPI data to be corrected")
class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
in_target = InputMultiObject(
File(exist=True, mandatory=True, desc="input EPI data to be corrected")
)
in_coeff = InputMultiObject(
File(exists=True),
mandatory=True,
desc="input coefficients, after alignment to the EPI data",
)
ro_time = traits.Float(mandatory=True, desc="EPI readout time (s).")
pe_dir = traits.Enum(
"i",
"i-",
"j",
"j-",
"k",
"k-",
mandatory=True,
desc="the phase-encoding direction corresponding to in_target",
ro_time = InputMultiObject(
traits.Float(mandatory=True, desc="EPI readout time (s).")
)
low_mem = traits.Bool(
False, usedefault=True, desc="perform on low-mem fingerprint regime"
pe_dir = InputMultiObject(
traits.Enum(
"i",
"i-",
"j",
"j-",
"k",
"k-",
mandatory=True,
desc="the phase-encoding direction corresponding to in_target",
)
)


class _Coefficients2WarpOutputSpec(TraitedSpec):
class _ApplyCoeffsFieldOutputSpec(TraitedSpec):
out_corrected = OutputMultiObject(File(exists=True))
out_field = File(exists=True)
out_warp = File(exists=True)
out_warp = OutputMultiObject(File(exists=True))


class Coefficients2Warp(SimpleInterface):
class ApplyCoeffsField(SimpleInterface):
"""Convert a set of B-Spline coefficients to a full displacements map."""

input_spec = _Coefficients2WarpInputSpec
output_spec = _Coefficients2WarpOutputSpec
input_spec = _ApplyCoeffsFieldInputSpec
output_spec = _ApplyCoeffsFieldOutputSpec

def _run_interface(self, runtime):
# Prepare output names
self._results["out_field"] = fname_presuffix(
self.inputs.in_target, suffix="_field", newpath=runtime.cwd
)
self._results["out_warp"] = self._results["out_field"].replace(
"_field.nii", "_xfm.nii"
)
filename = partial(fname_presuffix, newpath=runtime.cwd)

self._results["out_field"] = filename(self.inputs.in_coeff[0], suffix="_field")
self._results["out_warp"] = []
self._results["out_corrected"] = []

xfm = B0FieldTransform(
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff]
)
xfm.fit(self.inputs.in_target)
xfm.fit(self.inputs.in_target[0])
xfm.shifts.to_filename(self._results["out_field"])
xfm.to_displacements(
ro_time=self.inputs.ro_time,
pe_dir=self.inputs.pe_dir,
).to_filename(self._results["out_warp"])

n_inputs = len(self.inputs.in_target)
ro_time = self.inputs.ro_time
if len(ro_time) == 1:
ro_time = [ro_time[0]] * n_inputs

pe_dir = self.inputs.pe_dir
if len(pe_dir) == 1:
pe_dir = [pe_dir[0]] * n_inputs

for fname, pe, ro in zip(self.inputs.in_target, pe_dir, ro_time):
xfm.fit(fname)

# Generate warpfield
warp_name = filename(fname, suffix="_xfm")
xfm.to_displacements(ro_time=ro, pe_dir=pe).to_filename(warp_name)
self._results["out_warp"].append(warp_name)

# Generate resampled
out_name = filename(fname, suffix="_unwarped")
xfm.apply(nb.load(fname), ro_time=ro, pe_dir=pe).to_filename(out_name)
self._results["out_corrected"].append(out_name)

return runtime


Expand Down
8 changes: 4 additions & 4 deletions sdcflows/interfaces/tests/test_bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..bspline import (
bspline_grid,
Coefficients2Warp,
ApplyCoeffsField,
BSplineApprox,
TOPUPCoeffReorient,
_fix_topup_fieldcoeff,
Expand Down Expand Up @@ -40,7 +40,7 @@ def test_bsplines(tmp_path, testnum):

os.chdir(tmp_path)
# Check that we can interpolate the coefficients on a target
test1 = Coefficients2Warp(
test1 = ApplyCoeffsField(
in_target=str(tmp_path / "target.nii.gz"),
in_coeff=str(tmp_path / "coeffs.nii.gz"),
pe_dir="j-",
Expand Down Expand Up @@ -91,8 +91,8 @@ def test_topup_coeffs(tmpdir, testdata_dir):
def test_topup_coeffs_interpolation(tmpdir, testdata_dir):
"""Check that our interpolation is not far away from TOPUP's."""
tmpdir.chdir()
result = Coefficients2Warp(
in_target=str(testdata_dir / "epi.nii.gz"),
result = ApplyCoeffsField(
in_target=[str(testdata_dir / "epi.nii.gz")] * 2,
in_coeff=str(testdata_dir / "topup-coeff-fixed.nii.gz"),
pe_dir="j-",
ro_time=1.0,
Expand Down
2 changes: 1 addition & 1 deletion sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def apply(
voxcoords[pe_axis, ...] += vsm * ro_time

# Prepare data
data = np.asanyarray(spatialimage.dataobj)
data = np.squeeze(np.asanyarray(spatialimage.dataobj))
output_dtype = output_dtype or data.dtype

# Resample
Expand Down
17 changes: 6 additions & 11 deletions sdcflows/workflows/apply/correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"):
a fast mask calculated from the corrected EPI reference.
"""
from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms
from ...interfaces.epi import GetReadoutTime
from ...interfaces.bspline import Coefficients2Warp
from ...interfaces.bspline import ApplyCoeffsField
from ..ancillary import init_brainextraction_wf

workflow = Workflow(name=name)
Expand All @@ -65,10 +64,8 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"):

rotime = pe.Node(GetReadoutTime(), name="rotime")
rotime.interface._always_run = debug
resample = pe.Node(Coefficients2Warp(low_mem=debug), name="resample")
unwarp = pe.Node(
ApplyTransforms(dimension=3, interpolation="BSpline"), name="unwarp"
)
resample = pe.Node(ApplyCoeffsField(), name="resample")

brainextraction_wf = init_brainextraction_wf()

# fmt:off
Expand All @@ -79,11 +76,9 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"):
("fmap_coeff", "in_coeff")]),
(rotime, resample, [("readout_time", "ro_time"),
("pe_direction", "pe_dir")]),
(inputnode, unwarp, [("distorted", "reference_image"),
("distorted", "input_image")]),
(resample, unwarp, [("out_warp", "transforms")]),
(resample, outputnode, [("out_field", "fieldmap")]),
(unwarp, brainextraction_wf, [("output_image", "inputnode.in_file")]),
(resample, outputnode, [("out_field", "fieldmap"),
("out_warp", "transforms")]),
(resample, brainextraction_wf, [("out_corrected", "inputnode.in_file")]),
(brainextraction_wf, outputnode, [
("outputnode.out_file", "corrected"),
("outputnode.out_mask", "corrected_mask"),
Expand Down
48 changes: 41 additions & 7 deletions sdcflows/workflows/fit/pepolar.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
echo-planar imaging (EPI) references """


def init_topup_wf(omp_nthreads=1, sloppy=False, debug=False, name="pepolar_estimate_wf"):
def init_topup_wf(
omp_nthreads=1, sloppy=False, debug=False, name="pepolar_estimate_wf"
):
"""
Create the PEPOLAR field estimation workflow based on FSL's ``topup``.
Expand Down Expand Up @@ -112,7 +114,7 @@ def init_topup_wf(omp_nthreads=1, sloppy=False, debug=False, name="pepolar_estim

topup = pe.Node(
TOPUP(
config=_pkg_fname("sdcflows", f"data/flirtsch/b02b0{'_quick' * debug}.cnf")
config=_pkg_fname("sdcflows", f"data/flirtsch/b02b0{'_quick' * sloppy}.cnf")
),
name="topup",
)
Expand Down Expand Up @@ -144,18 +146,50 @@ def _getpe(in_meta):
(flatten, fix_coeff, [(("out_data", _front), "fmap_ref"),
(("out_meta", _getpe), "pe_dir")]),
(topup, fix_coeff, [("out_fieldcoef", "in_coeff")]),
(topup, merge_corrected, [("out_corrected", "in_files")]),
(topup, outputnode, [("out_field", "fmap"),
("out_jacs", "jacobians"),
("out_mats", "xfms"),
("out_warps", "out_warps")]),
(topup, outputnode, [("out_jacs", "jacobians"),
("out_mats", "xfms")]),
(merge_corrected, brainextraction_wf, [("out_avg", "inputnode.in_file")]),
(merge_corrected, outputnode, [("out_avg", "fmap_ref")]),
(brainextraction_wf, outputnode, [("outputnode.out_mask", "fmap_mask")]),
(fix_coeff, outputnode, [("out_coeff", "fmap_coeff")]),
])
# fmt: on

if not debug:
# fmt: off
workflow.connect([
(topup, merge_corrected, [("out_corrected", "in_files")]),
(topup, outputnode, [("out_field", "fmap"),
("out_warps", "out_warps")]),
])
# fmt: on
return workflow

from ...interfaces.bspline import ApplyCoeffsField

unwarp = pe.Node(
ApplyCoeffsField(ro_time=1.0),
name="unwarp",
)

def _getpe(inlist):
if isinstance(inlist, dict):
inlist = [inlist]

return [m["PhaseEncodingDirection"] for m in inlist]

# fmt:off
workflow.connect([
(fix_coeff, unwarp, [("out_coeff", "in_coeff")]),
(flatten, unwarp, [("out_data", "in_target"),
(("out_meta", _getpe), "pe_dir")]),
(readout_time, unwarp, [("readout_time", "ro_time")]),
(unwarp, outputnode, [("out_warp", "out_warps"),
("out_field", "fmap")]),
(unwarp, merge_corrected, [("out_corrected", "in_files")]),
])
# fmt:on

return workflow


Expand Down

0 comments on commit 4909f5f

Please sign in to comment.