diff --git a/mriqc/interfaces/diffusion.py b/mriqc/interfaces/diffusion.py index 91aae38bc..58ba5e4c5 100644 --- a/mriqc/interfaces/diffusion.py +++ b/mriqc/interfaces/diffusion.py @@ -21,15 +21,20 @@ # https://www.nipreps.org/community/licensing/ # """Interfaces for manipulating DWI data.""" +from __future__ import annotations + import nibabel as nb import numpy as np +import scipy.ndimage as nd from nipype.interfaces.base import ( BaseInterfaceInputSpec as _BaseInterfaceInputSpec, ) from nipype.interfaces.base import ( File, + InputMultiObject, OutputMultiObject, SimpleInterface, + TraitedSpec, isdefined, traits, ) @@ -41,6 +46,19 @@ from sklearn.cluster import KMeans from sklearn.model_selection import GridSearchCV +__all__ = ( + 'CCSegmentation', + 'CorrectSignalDrift', + 'DiffusionQC', + 'DipyDTI', + 'ExtractOrientations', + 'FilterShells', + 'NumberOfShells', + 'ReadDWIMetadata', + 'SplitShells', + 'WeightedStat', +) + class _ReadDWIMetadataOutputSpec(_ReadSidecarJSONOutputSpec): out_bvec_file = File(desc='corresponding bvec file') @@ -214,20 +232,20 @@ def _run_interface(self, runtime): return runtime -class _ExtractB0InputSpec(_BaseInterfaceInputSpec): +class _ExtractOrientationsInputSpec(_BaseInterfaceInputSpec): in_file = File(exists=True, mandatory=True, desc='dwi file') b0_ixs = traits.List(traits.Int, mandatory=True, desc='Index of b0s') -class _ExtractB0OutputSpec(_TraitedSpec): +class _ExtractOrientationsOutputSpec(_TraitedSpec): out_file = File(exists=True, desc='output b0 file') -class ExtractB0(SimpleInterface): +class ExtractOrientations(SimpleInterface): """Extract all b=0 volumes from a dwi series.""" - input_spec = _ExtractB0InputSpec - output_spec = _ExtractB0OutputSpec + input_spec = _ExtractOrientationsInputSpec + output_spec = _ExtractOrientationsOutputSpec def _run_interface(self, runtime): from nipype.utils.filemanip import fname_presuffix @@ -238,9 +256,15 @@ def _run_interface(self, runtime): newpath=runtime.cwd, ) - self._results['out_file'] = _extract_b0( - self.inputs.in_file, self.inputs.b0_ixs, out_path=out_file - ) + self._results['out_file'] = out_file + + img = nb.load(self.inputs.in_file) + bzeros = np.squeeze(np.asanyarray(img.dataobj)[..., self.inputs.b0_ixs]) + + hdr = img.header.copy() + hdr.set_data_shape(bzeros.shape) + hdr.set_xyzt_units('mm') + nb.Nifti1Image(bzeros, img.affine, hdr).to_filename(out_file) return runtime @@ -424,10 +448,17 @@ class _DipyDTIInputSpec(_BaseInterfaceInputSpec): brainmask = File(exists=True, desc='brain mask file') free_water_model = traits.Bool(False, usedefault=True, desc='use free water model') b_threshold = traits.Float(1100, usedefault=True, desc='use only inner shells of the data') + decimals = traits.Int(3, usedefault=True, desc='round output maps for reliability') class _DipyDTIOutputSpec(_TraitedSpec): out_fa = File(exists=True, desc='output FA file') + out_fa_nans = File(exists=True, desc='binary mask of NaN values in the "raw" FA map') + out_fa_degenerate = File( + exists=True, + desc='binary mask of values outside [0, 1] in the "raw" FA map', + ) + out_cfa = File(exists=True, desc='output color FA file') out_md = File(exists=True, desc='output MD file') @@ -439,7 +470,7 @@ class DipyDTI(SimpleInterface): def _run_interface(self, runtime): from dipy.core.gradients import gradient_table_from_bvals_bvecs - from dipy.reconst.dti import TensorModel + from dipy.reconst.dti import TensorModel, color_fa, fractional_anisotropy from dipy.reconst.fwdti import FreeWaterTensorModel from nipype.utils.filemanip import fname_presuffix @@ -468,8 +499,14 @@ def _run_interface(self, runtime): ) # Extract the FA - fa_data = np.array(fwdtifit.fa, dtype='float32') - fa_data[np.isnan(fa_data)] = 0 + fa_data = fractional_anisotropy(fwdtifit.evals) + fa_nan_msk = np.isnan(fa_data) + fa_data[fa_nan_msk] = 0 + + # Round for stability + fa_data = np.round(fa_data, self.inputs.decimals) + degenerate_msk = (fa_data < 0) | (fa_data > 1.0) + # Clamp the FA to remove degenerate fa_data = np.clip(fa_data, 0, 1) fa_nii = nb.Nifti1Image( @@ -480,8 +517,6 @@ def _run_interface(self, runtime): fa_nii.header.set_xyzt_units('mm') fa_nii.header.set_intent('estimate', name='Fractional Anisotropy (FA)') - fa_nii.header['cal_max'] = 1.0 - fa_nii.header['cal_min'] = 0.0 self._results['out_fa'] = fname_presuffix( self.inputs.in_file, @@ -491,6 +526,66 @@ def _run_interface(self, runtime): fa_nii.to_filename(self._results['out_fa']) + # Write out degenerate and nans masks + fa_nan_nii = nb.Nifti1Image( + fa_nan_msk.astype(np.uint8), + img.affine, + None, + ) + + fa_nan_nii.header.set_xyzt_units('mm') + fa_nan_nii.header.set_intent('estimate', name='NaNs in the FA map mask') + fa_nan_nii.header['cal_max'] = 1 + fa_nan_nii.header['cal_min'] = 0 + + self._results['out_fa_nans'] = fname_presuffix( + self.inputs.in_file, + suffix='desc-fanans_mask', + newpath=runtime.cwd, + ) + fa_nan_nii.to_filename(self._results['out_fa_nans']) + + fa_degenerate_nii = nb.Nifti1Image( + degenerate_msk.astype(np.uint8), + img.affine, + None, + ) + + fa_degenerate_nii.header.set_xyzt_units('mm') + fa_degenerate_nii.header.set_intent( + 'estimate', + name='degenerate vectors in the FA map mask' + ) + fa_degenerate_nii.header['cal_max'] = 1 + fa_degenerate_nii.header['cal_min'] = 0 + + self._results['out_fa_degenerate'] = fname_presuffix( + self.inputs.in_file, + suffix='desc-fadegenerate_mask', + newpath=runtime.cwd, + ) + fa_degenerate_nii.to_filename(self._results['out_fa_degenerate']) + + # Extract the color FA + cfa_data = color_fa(fa_data, fwdtifit.evecs) + cfa_nii = nb.Nifti1Image( + cfa_data, + img.affine, + None, + ) + + cfa_nii.header.set_xyzt_units('mm') + cfa_nii.header.set_intent('estimate', name='Fractional Anisotropy (FA)') + cfa_nii.header['cal_max'] = 1.0 + cfa_nii.header['cal_min'] = 0.0 + + self._results['out_cfa'] = fname_presuffix( + self.inputs.in_file, + suffix='cfa', + newpath=runtime.cwd, + ) + cfa_nii.to_filename(self._results['out_cfa']) + # Extract the AD self._results['out_md'] = fname_presuffix( self.inputs.in_file, @@ -511,6 +606,182 @@ def _run_interface(self, runtime): return runtime +class _DiffusionQCInputSpec(_BaseInterfaceInputSpec): + in_b0 = File(exists=True, mandatory=True, desc='input b=0 average') + in_shells = InputMultiObject( + File(exists=True), + mandatory=True, + desc='DWI data after HMC and split by shells (indexed by in_bval)' + ) + in_bvec = File(exists=True, mandatory=True, desc='input motion corrected file') + in_bval = traits.List(traits.Float, minlen=1, desc='b-values') + in_bvec = traits.List( + traits.Tuple(traits.Float, traits.Float, traits.Float), + minlen=1, + desc='b-vectors', + ) + in_fa = File(exists=True, mandatory=True, desc='input FA map') + in_md = File(exists=True, mandatory=True, desc='input MD map') + in_brainmask = File(exists=True, mandatory=True, desc='input probabilistic brain mask') + in_wmmask = File(exists=True, mandatory=True, desc='input probabilistic brain mask') + direction = traits.Enum( + 'all', + 'x', + 'y', + '-x', + '-y', + usedefault=True, + desc='direction for GSR computation', + ) + in_fwhm = traits.List(traits.Float, mandatory=True, desc='smoothness estimated with AFNI') + + +class _DiffusionQCOutputSpec(TraitedSpec): + fber = traits.Float + efc = traits.Float + snr = traits.Float + gsr = traits.Dict + tsnr = traits.Float + fd = traits.Dict + fwhm = traits.Dict(desc='full width half-maximum measure') + size = traits.Dict + spacing = traits.Dict + summary = traits.Dict + + out_qc = traits.Dict(desc='output flattened dictionary with all measures') + + +class DiffusionQC(SimpleInterface): + """Computes :abbr:`QC (Quality Control)` measures on the input DWI EPI scan.""" + + input_spec = _DiffusionQCInputSpec + output_spec = _DiffusionQCOutputSpec + + def _run_interface(self, runtime): + from mriqc.qc import anatomical as aqc + from mriqc.qc import diffusion as dqc + + # Get the mean EPI data and get it ready + b0nii = nb.load(self.inputs.in_b0) + b0data = np.round( + np.nan_to_num(np.asanyarray(b0nii.dataobj)), + 2, + ) + b0data[b0data < 0] = 0 + + # Get the FA data and get it ready + fanii = nb.load(self.inputs.in_fa) + fadata = np.round( + np.nan_to_num(np.asanyarray(fanii.dataobj)), + 2, + ) + + # Get EPI data (with hmc done) and get it ready + hmcnii = nb.load(self.inputs.in_hmc) + hmcdata = np.round( + np.nan_to_num(np.asanyarray(hmcnii.dataobj)), + 2, + ) + hmcdata[hmcdata < 0] = 0 + + # Get brain mask data + msknii = nb.load(self.inputs.in_mask) + mskdata = np.round( # Protect the thresholding with a rounding for stability + np.asanyarray(msknii.dataobj), + 1, + ) > 0 + if np.sum(mskdata) < 100: + raise RuntimeError( + 'Detected less than 100 voxels belonging to the brain mask. ' + 'MRIQC failed to process this dataset.' + ) + + # Summary stats + rois = { + 'fg': mskdata, + 'bg': 1.0 - mskdata, + } + stats = aqc.summary_stats(b0data, rois) + self._results['summary'] = stats + + self._results['cc_snr'] = dqc.cc_snr( + fadata + ) + + return runtime + + +class _CCSegmentationInputSpec(_BaseInterfaceInputSpec): + in_fa = File(exists=True, mandatory=True, desc='fractional anisotropy (FA) file') + in_cfa = File(exists=True, mandatory=True, desc='color FA file') + minmax_R = traits.Tuple(traits.Float, traits.Float, value=(0.6, 1.0), usedefault=True, + desc='minimum and maximum for red') + minmax_G = traits.Tuple(traits.Float, traits.Float, value=(0.0, 0.1), usedefault=True, + desc='minimum and maximum for green') + minmax_B = traits.Tuple(traits.Float, traits.Float, value=(0.0, 0.1), usedefault=True, + desc='minimum and maximum for blue') + wm_threshold = traits.Float(0.6, usedefault=True, desc='WM segmentation threshold') + opening_wm_mask = traits.Int( + 2, + usedefault=True, + desc='iterations of binary opening of the WM mask with a ball structuring element') + + +class _CCSegmentationOutputSpec(_TraitedSpec): + out_mask = File(exists=True, desc='output mask of the corpus callosum') + + +class CCSegmentation(SimpleInterface): + """Computes :abbr:`QC (Quality Control)` measures on the input DWI EPI scan.""" + + input_spec = _CCSegmentationInputSpec + output_spec = _CCSegmentationOutputSpec + + def _run_interface(self, runtime): + from skimage.measure import label + + self._results['out_mask'] = fname_presuffix( + self.inputs.in_cfa, + suffix='ccmask', + newpath=runtime.cwd, + ) + + fa_nii = nb.load(self.inputs.in_fa) + fa_data = np.round(fa_nii.get_fdata(dtype='float32'), 4) + fa_labels = label((fa_data > self.inputs.wm_threshold).astype(np.uint8)) + wm_mask = fa_labels == np.argmax(np.bincount(fa_labels.flat)[1:]) + 1 + + if self.inputs.erode_wm_mask > 0: + struct = nd.generate_binary_structure(wm_mask.ndim, wm_mask.ndim - 1) + # Perform an opening operation on the background data. + wm_mask = nd.binary_opening( + wm_mask, + structure=struct, + iterations=self.inputs.erode_wm_mask, + ) + + cc_mask = segment_corpus_callosum( + in_cfa=nb.load(self.inputs.in_cfa).get_fdata(dtype='float32'), + mask=wm_mask, + threshold=( + self.inputs.minmax_R + + self.inputs.minmax_G + + self.inputs.minmax_B + ), + ) + cc_mask_nii = nb.Nifti1Image( + cc_mask.astype(np.uint8), + fa_nii.affine, + None, + ) + cc_mask_nii.header.set_xyzt_units('mm') + cc_mask_nii.header.set_intent('estimate', name='corpus callosum mask') + cc_mask_nii.header['cal_max'] = 1 + cc_mask_nii.header['cal_min'] = 0 + cc_mask_nii.to_filename(self._results['out_mask']) + return runtime + + def _rms(estimator, X): """ Callable to pass to GridSearchCV that will calculate a distance score. @@ -528,20 +799,143 @@ def _rms(estimator, X): return -np.sqrt(distance**2).sum() -def _extract_b0(in_file, b0_ixs, out_path=None): - """Extract the *b0* volumes from a DWI dataset.""" - if out_path is None: - out_path = fname_presuffix(in_file, suffix='_b0') +def _exp_func(t, A, K, C): + return A * np.exp(K * t) + C - img = nb.load(in_file) - bzeros = np.squeeze(np.asanyarray(img.dataobj)[..., b0_ixs]) - hdr = img.header.copy() - hdr.set_data_shape(bzeros.shape) - hdr.set_xyzt_units('mm') - nb.Nifti1Image(bzeros, img.affine, hdr).to_filename(out_path) - return out_path +def segment_corpus_callosum( + in_cfa: np.ndarray, + mask: np.ndarray, + threshold: tuple[float, float, float, float, float, float] = ( + 0.6, + 1, + 0, + 0.1, + 0, + 0.1, + ), +) -> tuple[np.ndarray, np.ndarray]: + """ + Segments the corpus callosum (CC) from a color FA map. + + Parameters + ---------- + in_cfa : :obj:`~numpy.ndarray` + The color FA (cFA) map. + mask : :obj:`~numpy.ndarray` (bool, 3D) + A white matter mask used to define the initial bounding box. + threshold : :obj:`tuple`, optional + An iterable that defines the minimum and maximum values to use for + the thresholding of the cFA. Values are specified as + (R_min, R_max, G_min, G_max, B_min, B_max). + + Returns + ------- + cc_mask: :obj:`~numpy.ndarray` + The final binary mask of the segmented CC. + + Notes + ----- + This implementation was derived from + :obj:`dipy.segment.mask.segment_from_cfa`. + The CC mask is then cleaned-up for spurious off voxels with + :obj:`dipy.segment.mask.clean_cc_mask` + """ + from dipy.segment.mask import bounding_box, clean_cc_mask + + # Prepare a bounding box of the CC + cc_box = np.zeros_like(mask, dtype=bool) + mins, maxs = bounding_box(mask) # mask needs to be volume + mins = np.array(mins) + maxs = np.array(maxs) + diff = (maxs - mins) // 4 + bounds_min = mins + diff + bounds_max = maxs - diff + cc_box[ + bounds_min[0]:bounds_max[0], + bounds_min[1]:bounds_max[1], + bounds_min[2]:bounds_max[2] + ] = True + + include = ( + (in_cfa >= threshold[0::2]) + & (in_cfa <= threshold[1::2]) + & cc_box[..., None] + ) + cc_mask = clean_cc_mask(np.all(include, axis=-1)) + return cc_mask -def _exp_func(t, A, K, C): - return A * np.exp(K * t) + C + +def get_spike_mask( + data: np.ndarray, + z_threshold: float = 3.0, + grouping_vals: np.ndarray | None = None, + bmag: int | None = None, +) -> np.ndarray: + """ + Creates a binary mask classifying voxels in the data array as spike or non-spike. + + This function identifies voxels with signal intensities exceeding a threshold based + on standard deviations above the mean. The threshold can be applied globally to + the entire data array, or it can be calculated for groups of voxels defined by + the ``grouping_vals`` parameter. + + Parameters + ---------- + data : :obj:`~numpy.ndarray` + The data array to be thresholded. + z_threshold : :obj:`float`, optional (default=3.0) + The number of standard deviations to use above the mean as the threshold + multiplier. + grouping_vals : :obj:`~numpy.ndarray`, optional + If provided, this array is used to group voxels for thresholding. Voxels + with the same value in ``grouping_vals`` are considered to belong to the same + group. The threshold will be calculated independently for each group. + - If ``grouping_vals`` has the same shape as ``data`` (4D), it is assumed to be + a mask where each voxel value indicates the group it belongs to. + - If ``grouping_vals`` has a 3D shape, it is assumed to represent b-values + corresponding to each voxel in the 4D ``data`` array. In this case, voxels + with the same b-value are grouped together. + bmag : int, optional + The order of magnitude for b-value rounding (used only if + ``grouping_vals`` is provided as b-values). Default: None (derived from max b-value). + + Returns: + ------- + spike_mask : :obj:`~numpy.ndarray` + A binary mask where ``True`` values indicate voxels classified as spikes and + ``False`` values indicate non-spikes. The mask has the same shape as the input + data array. + + """ + from dipy.core.gradients import round_bvals, unique_bvals_magnitude + + if grouping_vals is None: + threshold = np.round((z_threshold * np.std(data)) + np.mean(data), 3) + spike_mask = np.round(data, 3) > threshold + return spike_mask + + threshold_mask = np.zeros(data.shape) + + rounded_grouping_vals = round_bvals(grouping_vals, bmag) + gvals = unique_bvals_magnitude(grouping_vals, bmag) + + if grouping_vals.shape == data.shape: + for gval in gvals: + gval_data = data[rounded_grouping_vals == gval] + gval_threshold = ((z_threshold * np.std(gval_data)) + + np.mean(gval_data)) + threshold_mask[rounded_grouping_vals == gval] = ( + gval_threshold * np.ones(gval_data.shape)) + else: + for gval in gvals: + gval_data = data[..., rounded_grouping_vals == gval] + gval_threshold = ((z_threshold * np.std(gval_data)) + + np.mean(gval_data)) + threshold_mask[..., rounded_grouping_vals == gval] = ( + gval_threshold * np.ones(gval_data.shape)) + + spike_mask = data > threshold_mask + + return spike_mask diff --git a/mriqc/workflows/diffusion/base.py b/mriqc/workflows/diffusion/base.py index 74e4a67ce..226e16053 100644 --- a/mriqc/workflows/diffusion/base.py +++ b/mriqc/workflows/diffusion/base.py @@ -72,7 +72,7 @@ def dmri_qc_workflow(name='dwiMRIQC'): from mriqc.interfaces.diffusion import ( CorrectSignalDrift, DipyDTI, - ExtractB0, + ExtractOrientations, FilterShells, NumberOfShells, ReadDWIMetadata, @@ -126,7 +126,7 @@ def dmri_qc_workflow(name='dwiMRIQC'): # 1. Read metadata & bvec/bval, estimate number of shells, extract and split B0s meta = pe.Node(ReadDWIMetadata(index_db=config.execution.bids_database_dir), name='metadata') shells = pe.Node(NumberOfShells(), name='shells') - get_shells = pe.MapNode(ExtractB0(), name='get_shells', iterfield=['b0_ixs']) + get_shells = pe.MapNode(ExtractOrientations(), name='get_shells', iterfield=['b0_ixs']) hmc_shells = pe.MapNode( Volreg(args='-Fourier -twopass', zpad=4, outputtype='NIFTI_GZ'), name='hmc_shells',