diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 2070623084..8ee2eec6a8 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -2,6 +2,7 @@ # vi: set ft=python sts=4 ts=4 sw=4 et: """Filtering of :math:`B_0` field mappings with B-Splines.""" from pathlib import Path +from functools import partial import numpy as np import nibabel as nb from nibabel.affines import apply_affine @@ -188,58 +189,80 @@ def _run_interface(self, runtime): return runtime -class _Coefficients2WarpInputSpec(BaseInterfaceInputSpec): - in_target = File(exist=True, mandatory=True, desc="input EPI data to be corrected") +class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec): + in_target = InputMultiObject( + File(exist=True, mandatory=True, desc="input EPI data to be corrected") + ) in_coeff = InputMultiObject( File(exists=True), mandatory=True, desc="input coefficients, after alignment to the EPI data", ) - ro_time = traits.Float(mandatory=True, desc="EPI readout time (s).") - pe_dir = traits.Enum( - "i", - "i-", - "j", - "j-", - "k", - "k-", - mandatory=True, - desc="the phase-encoding direction corresponding to in_target", + ro_time = InputMultiObject( + traits.Float(mandatory=True, desc="EPI readout time (s).") ) - low_mem = traits.Bool( - False, usedefault=True, desc="perform on low-mem fingerprint regime" + pe_dir = InputMultiObject( + traits.Enum( + "i", + "i-", + "j", + "j-", + "k", + "k-", + mandatory=True, + desc="the phase-encoding direction corresponding to in_target", + ) ) -class _Coefficients2WarpOutputSpec(TraitedSpec): +class _ApplyCoeffsFieldOutputSpec(TraitedSpec): + out_corrected = OutputMultiObject(File(exists=True)) out_field = File(exists=True) - out_warp = File(exists=True) + out_warp = OutputMultiObject(File(exists=True)) -class Coefficients2Warp(SimpleInterface): +class ApplyCoeffsField(SimpleInterface): """Convert a set of B-Spline coefficients to a full displacements map.""" - input_spec = _Coefficients2WarpInputSpec - output_spec = _Coefficients2WarpOutputSpec + input_spec = _ApplyCoeffsFieldInputSpec + output_spec = _ApplyCoeffsFieldOutputSpec def _run_interface(self, runtime): # Prepare output names - self._results["out_field"] = fname_presuffix( - self.inputs.in_target, suffix="_field", newpath=runtime.cwd - ) - self._results["out_warp"] = self._results["out_field"].replace( - "_field.nii", "_xfm.nii" - ) + filename = partial(fname_presuffix, newpath=runtime.cwd) + + self._results["out_field"] = filename(self.inputs.in_coeff[0], suffix="_field") + self._results["out_warp"] = [] + self._results["out_corrected"] = [] xfm = B0FieldTransform( coeffs=[nb.load(cname) for cname in self.inputs.in_coeff] ) - xfm.fit(self.inputs.in_target) + xfm.fit(self.inputs.in_target[0]) xfm.shifts.to_filename(self._results["out_field"]) - xfm.to_displacements( - ro_time=self.inputs.ro_time, - pe_dir=self.inputs.pe_dir, - ).to_filename(self._results["out_warp"]) + + n_inputs = len(self.inputs.in_target) + ro_time = self.inputs.ro_time + if len(ro_time) == 1: + ro_time = [ro_time[0]] * n_inputs + + pe_dir = self.inputs.pe_dir + if len(pe_dir) == 1: + pe_dir = [pe_dir[0]] * n_inputs + + for fname, pe, ro in zip(self.inputs.in_target, pe_dir, ro_time): + xfm.fit(fname) + + # Generate warpfield + warp_name = filename(fname, suffix="_xfm") + xfm.to_displacements(ro_time=ro, pe_dir=pe).to_filename(warp_name) + self._results["out_warp"].append(warp_name) + + # Generate resampled + out_name = filename(fname, suffix="_unwarped") + xfm.apply(nb.load(fname), ro_time=ro, pe_dir=pe).to_filename(out_name) + self._results["out_corrected"].append(out_name) + return runtime diff --git a/sdcflows/interfaces/tests/test_bspline.py b/sdcflows/interfaces/tests/test_bspline.py index 8233000e88..887ccedc46 100644 --- a/sdcflows/interfaces/tests/test_bspline.py +++ b/sdcflows/interfaces/tests/test_bspline.py @@ -6,7 +6,7 @@ from ..bspline import ( bspline_grid, - Coefficients2Warp, + ApplyCoeffsField, BSplineApprox, TOPUPCoeffReorient, _fix_topup_fieldcoeff, @@ -40,7 +40,7 @@ def test_bsplines(tmp_path, testnum): os.chdir(tmp_path) # Check that we can interpolate the coefficients on a target - test1 = Coefficients2Warp( + test1 = ApplyCoeffsField( in_target=str(tmp_path / "target.nii.gz"), in_coeff=str(tmp_path / "coeffs.nii.gz"), pe_dir="j-", @@ -91,8 +91,8 @@ def test_topup_coeffs(tmpdir, testdata_dir): def test_topup_coeffs_interpolation(tmpdir, testdata_dir): """Check that our interpolation is not far away from TOPUP's.""" tmpdir.chdir() - result = Coefficients2Warp( - in_target=str(testdata_dir / "epi.nii.gz"), + result = ApplyCoeffsField( + in_target=[str(testdata_dir / "epi.nii.gz")] * 2, in_coeff=str(testdata_dir / "topup-coeff-fixed.nii.gz"), pe_dir="j-", ro_time=1.0, diff --git a/sdcflows/transform.py b/sdcflows/transform.py index 7cc6551071..bec071e936 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -118,7 +118,7 @@ def apply( voxcoords[pe_axis, ...] += vsm * ro_time # Prepare data - data = np.asanyarray(spatialimage.dataobj) + data = np.squeeze(np.asanyarray(spatialimage.dataobj)) output_dtype = output_dtype or data.dtype # Resample diff --git a/sdcflows/workflows/apply/correction.py b/sdcflows/workflows/apply/correction.py index 7a9ca5e311..78ccaa1a77 100644 --- a/sdcflows/workflows/apply/correction.py +++ b/sdcflows/workflows/apply/correction.py @@ -48,9 +48,8 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): a fast mask calculated from the corrected EPI reference. """ - from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms from ...interfaces.epi import GetReadoutTime - from ...interfaces.bspline import Coefficients2Warp + from ...interfaces.bspline import ApplyCoeffsField from ..ancillary import init_brainextraction_wf workflow = Workflow(name=name) @@ -65,10 +64,8 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): rotime = pe.Node(GetReadoutTime(), name="rotime") rotime.interface._always_run = debug - resample = pe.Node(Coefficients2Warp(low_mem=debug), name="resample") - unwarp = pe.Node( - ApplyTransforms(dimension=3, interpolation="BSpline"), name="unwarp" - ) + resample = pe.Node(ApplyCoeffsField(), name="resample") + brainextraction_wf = init_brainextraction_wf() # fmt:off @@ -79,11 +76,9 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): ("fmap_coeff", "in_coeff")]), (rotime, resample, [("readout_time", "ro_time"), ("pe_direction", "pe_dir")]), - (inputnode, unwarp, [("distorted", "reference_image"), - ("distorted", "input_image")]), - (resample, unwarp, [("out_warp", "transforms")]), - (resample, outputnode, [("out_field", "fieldmap")]), - (unwarp, brainextraction_wf, [("output_image", "inputnode.in_file")]), + (resample, outputnode, [("out_field", "fieldmap"), + ("out_warp", "transforms")]), + (resample, brainextraction_wf, [("out_corrected", "inputnode.in_file")]), (brainextraction_wf, outputnode, [ ("outputnode.out_file", "corrected"), ("outputnode.out_mask", "corrected_mask"), diff --git a/sdcflows/workflows/fit/pepolar.py b/sdcflows/workflows/fit/pepolar.py index 54966a63fd..6aa836fc92 100644 --- a/sdcflows/workflows/fit/pepolar.py +++ b/sdcflows/workflows/fit/pepolar.py @@ -165,36 +165,28 @@ def _getpe(in_meta): # fmt: on return workflow - from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms - from ...utils.misc import front as _front - from ...interfaces.bspline import Coefficients2Warp + from ...interfaces.bspline import ApplyCoeffsField - coeff2xfm = pe.Node( - Coefficients2Warp(ro_time=1.0, low_mem=sloppy), - name="coeff2xfm", - ) unwarp = pe.Node( - ApplyTransforms(dimension=3, interpolation="BSpline"), + ApplyCoeffsField(ro_time=1.0), name="unwarp", ) - def _getpe(in_meta): - if isinstance(in_meta, list): - in_meta = in_meta[0] - return in_meta["PhaseEncodingDirection"] + def _getpe(inlist): + if isinstance(inlist, dict): + inlist = [inlist] + + return [m["PhaseEncodingDirection"] for m in inlist] # fmt:off workflow.connect([ - (fix_coeff, coeff2xfm, [("out_coeff", "in_coeff")]), - (flatten, coeff2xfm, [(("out_data", _front), "in_target"), - (("out_meta", _getpe), "pe_dir")]), - (readout_time, coeff2xfm, [(("readout_time", _front), "ro_time")]), - (coeff2xfm, unwarp, [("out_warp", "transforms")]), - (flatten, unwarp, [(("out_data", _front), "reference_image"), - (("out_data", _front), "input_image")]), - (coeff2xfm, outputnode, [("out_warp", "out_warps"), - ("out_field", "fmap")]), - (unwarp, merge_corrected, [("output_image", "in_files")]), + (fix_coeff, unwarp, [("out_coeff", "in_coeff")]), + (flatten, unwarp, [("out_data", "in_target"), + (("out_meta", _getpe), "pe_dir")]), + (readout_time, unwarp, [("readout_time", "ro_time")]), + (unwarp, outputnode, [("out_warp", "out_warps"), + ("out_field", "fmap")]), + (unwarp, merge_corrected, [("out_corrected", "in_files")]), ]) # fmt:on