diff --git a/mriqc/interfaces/diffusion.py b/mriqc/interfaces/diffusion.py index 91aae38bc..33147d9c7 100644 --- a/mriqc/interfaces/diffusion.py +++ b/mriqc/interfaces/diffusion.py @@ -21,15 +21,20 @@ # https://www.nipreps.org/community/licensing/ # """Interfaces for manipulating DWI data.""" +from __future__ import annotations + import nibabel as nb import numpy as np +import scipy.ndimage as nd from nipype.interfaces.base import ( BaseInterfaceInputSpec as _BaseInterfaceInputSpec, ) from nipype.interfaces.base import ( File, + InputMultiObject, OutputMultiObject, SimpleInterface, + TraitedSpec, isdefined, traits, ) @@ -41,6 +46,203 @@ from sklearn.cluster import KMeans from sklearn.model_selection import GridSearchCV +from mriqc.utils.misc import _flatten_dict + + +__all__ = ( + 'CCSegmentation', + 'CorrectSignalDrift', + 'DiffusionQC', + 'DipyDTI', + 'ExtractOrientations', + 'FilterShells', + 'NumberOfShells', + 'ReadDWIMetadata', + 'SplitShells', + 'WeightedStat', +) + + +FD_THRESHOLD = 0.2 + + +class _DiffusionQCInputSpec(_BaseInterfaceInputSpec): + in_file = File(exists=True, mandatory=True, desc='original EPI 4D file') + in_b0 = File(exists=True, mandatory=True, desc='input b=0 average') + in_shells = InputMultiObject( + File(exists=True), + mandatory=True, + desc='DWI data after HMC and split by shells (indexed by in_bval)' + ) + in_bval = traits.List( + traits.Float, + minlen=1, + mandatory=True, + desc='list of unique b-values (one per shell), ordered by growing intensity', + ) + in_bvec = traits.List( + traits.List( + traits.Tuple(traits.Float, traits.Float, traits.Float), + minlen=1, + ), + mandatory=True, + minlen=1, + desc='a list of shell-wise splits of b-vectors lists -- first list are b=0', + ) + in_fa = File(exists=True, mandatory=True, desc='input FA map') + in_fa_nans = File(exists=True, mandatory=True, + desc='binary mask of NaN values in the "raw" FA map') + in_fa_degenerate = File( + exists=True, + mandatory=True, + desc='binary mask of values outside [0, 1] in the "raw" FA map', + ) + in_cfa = File(exists=True, mandatory=True, desc='output color FA file') + in_md = File(exists=True, mandatory=True, desc='input MD map') + brain_mask = File(exists=True, mandatory=True, desc='input probabilistic brain mask') + wm_mask = File(exists=True, mandatory=True, desc='input probabilistic white-matter mask') + cc_mask = File(exists=True, mandatory=True, desc='input probabilistic white-matter mask') + direction = traits.Enum( + 'all', + 'x', + 'y', + '-x', + '-y', + usedefault=True, + desc='direction for GSR computation', + ) + in_fd = File( + exists=True, + mandatory=True, + desc='motion parameters for FD computation', + ) + fd_thres = traits.Float( + FD_THRESHOLD, + usedefault=True, + desc='FD threshold for orientation exclusion based on head motion' + ) + in_fwhm = traits.List(traits.Float, desc='smoothness estimated with AFNI') + + +class _DiffusionQCOutputSpec(TraitedSpec): + cc_snr = traits.Dict + efc = traits.Dict + fber = traits.Dict + fd = traits.Dict + # snr = traits.Float + # gsr = traits.Dict + # tsnr = traits.Float + # fwhm = traits.Dict(desc='full width half-maximum measure') + # size = traits.Dict + summary = traits.Dict + + out_qc = traits.Dict(desc='output flattened dictionary with all measures') + + +class DiffusionQC(SimpleInterface): + """Computes :abbr:`QC (Quality Control)` measures on the input DWI EPI scan.""" + + input_spec = _DiffusionQCInputSpec + output_spec = _DiffusionQCOutputSpec + + def _run_interface(self, runtime): + from mriqc.qc import 
anatomical as aqc + from mriqc.qc import diffusion as dqc + # from mriqc.qc import functional as fqc + + # Get the mean EPI data and get it ready + b0nii = nb.load(self.inputs.in_b0) + b0data = np.round( + np.nan_to_num(np.asanyarray(b0nii.dataobj)), + 3, + ) + b0data[b0data < 0] = 0 + + # Get the FA data and get it ready + fanii = nb.load(self.inputs.in_fa) + fadata = np.round( + np.nan_to_num(np.asanyarray(fanii.dataobj)), + 3, + ) + + # Get brain mask data + msknii = nb.load(self.inputs.brain_mask) + mskdata = np.round( # Protect the thresholding with a rounding for stability + np.asanyarray(msknii.dataobj), + 3, + ) + if np.sum(mskdata) < 100: + raise RuntimeError( + 'Detected less than 100 voxels belonging to the brain mask. ' + 'MRIQC failed to process this dataset.' + ) + + # Get wm mask data + wmnii = nb.load(self.inputs.wm_mask) + wmdata = np.round( # Protect the thresholding with a rounding for stability + np.asanyarray(wmnii.dataobj), + 3, + ) + + # Get cc mask data + ccnii = nb.load(self.inputs.cc_mask) + ccdata = np.round( # Protect the thresholding with a rounding for stability + np.asanyarray(ccnii.dataobj), + 3, + ) + + # Get DWI data after splitting them by shell (DSI's data is clustered) + shelldata = [ + np.round( + np.asanyarray(nb.load(s).dataobj), + 4, + ) + for s in self.inputs.in_shells + ] + + # Summary stats + rois = { + 'fg': mskdata, + 'bg': ~mskdata, + 'cc': ccdata, + 'wm': wmdata, + } + stats = aqc.summary_stats(b0data, rois) + self._results['summary'] = stats + + self._results['cc_snr'] = dqc.cc_snr( + in_b0=b0data, + dwi_shells=shelldata, + cc_mask=ccdata, + b_values=self.inputs.in_bval, + b_vectors=self.inputs.in_bvec, + ) + + # FBER + self._results['fber'] = { + f'b{int(bval):d}': aqc.fber(bdata, mskdata.astype(np.uint8)) + for bval, bdata in zip(self.inputs.in_bval, shelldata) + } + + # EFC + self._results['efc'] = { + f'b{int(bval):d}': aqc.efc(bdata) + for bval, bdata in zip(self.inputs.in_bval, shelldata) + } + + # FD + fd_data = np.loadtxt(self.inputs.in_fd, skiprows=1) + num_fd = (fd_data > self.inputs.fd_thres).sum() + self._results['fd'] = { + 'mean': float(fd_data.mean()), + 'num': int(num_fd), + 'perc': float(num_fd * 100 / (len(fd_data) + 1)), + } + + self._results['out_qc'] = _flatten_dict(self._results) + + return runtime + class _ReadDWIMetadataOutputSpec(_ReadSidecarJSONOutputSpec): out_bvec_file = File(desc='corresponding bvec file') @@ -122,23 +324,26 @@ class _NumberOfShellsInputSpec(_BaseInterfaceInputSpec): class _NumberOfShellsOutputSpec(_TraitedSpec): models = traits.List(traits.Int, minlen=1, desc='number of shells ordered by model fit') - n_shells = traits.Int(desc='number of shels') + n_shells = traits.Int(desc='number of shells') out_data = traits.List( traits.Float, minlen=1, - desc="new b-values table (after 'shell-fying' DSI)", + desc="list of new b-values (e.g., after 'shell-ifying' DSI)", ) - b_values = traits.List(traits.Float, minlen=1, desc='estimated values of b') + b_values = traits.List( + traits.Float, + minlen=1, + desc='list of ``n_shells`` b-values associated with each shell (only nonzero)') b_masks = traits.List( traits.List(traits.Bool, minlen=1), minlen=1, - desc='b-value-wise masks') + desc='list of ``n_shells`` b-value-wise masks') b_indices = traits.List( traits.List(traits.Int, minlen=1), minlen=1, - desc='b-value-wise masks') + desc='list of ``n_shells`` b-value-wise indices lists') b_dict = traits.Dict( - traits.Int, traits.List(traits.Int), desc='b-values dictionary' + traits.Int, traits.List(traits.Int), 
desc='a map of b-values (including b=0) and masks' ) @@ -214,38 +419,59 @@ def _run_interface(self, runtime): return runtime -class _ExtractB0InputSpec(_BaseInterfaceInputSpec): +class _ExtractOrientationsInputSpec(_BaseInterfaceInputSpec): in_file = File(exists=True, mandatory=True, desc='dwi file') - b0_ixs = traits.List(traits.Int, mandatory=True, desc='Index of b0s') + indices = traits.List(traits.Int, mandatory=True, desc='indices to be extracted') + in_bvec_file = File(exists=True, desc='b-vectors file') -class _ExtractB0OutputSpec(_TraitedSpec): +class _ExtractOrientationsOutputSpec(_TraitedSpec): out_file = File(exists=True, desc='output b0 file') + out_bvec = traits.List( + traits.Tuple(traits.Float, traits.Float, traits.Float), + minlen=1, + desc='b-vectors', + ) -class ExtractB0(SimpleInterface): +class ExtractOrientations(SimpleInterface): """Extract all b=0 volumes from a dwi series.""" - input_spec = _ExtractB0InputSpec - output_spec = _ExtractB0OutputSpec + input_spec = _ExtractOrientationsInputSpec + output_spec = _ExtractOrientationsOutputSpec def _run_interface(self, runtime): from nipype.utils.filemanip import fname_presuffix out_file = fname_presuffix( self.inputs.in_file, - suffix='_b0', + suffix='_subset', newpath=runtime.cwd, ) - self._results['out_file'] = _extract_b0( - self.inputs.in_file, self.inputs.b0_ixs, out_path=out_file - ) + self._results['out_file'] = out_file + + img = nb.load(self.inputs.in_file) + bzeros = np.squeeze(np.asanyarray(img.dataobj)[..., self.inputs.indices]) + + hdr = img.header.copy() + hdr.set_data_shape(bzeros.shape) + hdr.set_xyzt_units('mm') + nb.Nifti1Image(bzeros, img.affine, hdr).to_filename(out_file) + + if isdefined(self.inputs.in_bvec_file): + bvecs = np.loadtxt(self.inputs.in_bvec_file)[:, self.inputs.indices].T + self._results['out_bvec'] = [tuple(row) for row in bvecs] + return runtime class _CorrectSignalDriftInputSpec(_BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='a 4D file with all low-b volumes') + in_file = File( + exists=True, + mandatory=True, + desc='a 4D file with (exclusively) realigned low-b volumes', + ) bias_file = File(exists=True, desc='a B1 bias field') brainmask_file = File(exists=True, desc='a 3D file of the brain mask') b0_ixs = traits.List(traits.Int, mandatory=True, desc='Index of b0s') @@ -254,10 +480,10 @@ class _CorrectSignalDriftInputSpec(_BaseInterfaceInputSpec): class _CorrectSignalDriftOutputSpec(_TraitedSpec): - out_file = File(desc='input file after drift correction') + out_file = File(desc='a 4D file with (exclusively) realigned, drift-corrected low-b volumes') out_full_file = File(desc='full DWI input after drift correction') - b0_drift = traits.List(traits.Float) - signal_drift = traits.List(traits.Float) + b0_drift = traits.List(traits.Float, desc='global signal evolution') + signal_drift = traits.List(traits.Float, desc='signal drift after fiting exp decay') class CorrectSignalDrift(SimpleInterface): @@ -338,7 +564,7 @@ def _run_interface(self, runtime): bval_list = np.rint(self.inputs.bvals).astype(int) bvals = np.unique(bval_list) img = nb.load(self.inputs.in_file) - data = np.array(img.dataobj, dtype=img.header.get_data_dtype()) + data = np.asanyarray(img.dataobj) self._results['out_file'] = [] @@ -424,10 +650,17 @@ class _DipyDTIInputSpec(_BaseInterfaceInputSpec): brainmask = File(exists=True, desc='brain mask file') free_water_model = traits.Bool(False, usedefault=True, desc='use free water model') b_threshold = traits.Float(1100, usedefault=True, desc='use 
only inner shells of the data') + decimals = traits.Int(3, usedefault=True, desc='round output maps for reliability') class _DipyDTIOutputSpec(_TraitedSpec): out_fa = File(exists=True, desc='output FA file') + out_fa_nans = File(exists=True, desc='binary mask of NaN values in the "raw" FA map') + out_fa_degenerate = File( + exists=True, + desc='binary mask of values outside [0, 1] in the "raw" FA map', + ) + out_cfa = File(exists=True, desc='output color FA file') out_md = File(exists=True, desc='output MD file') @@ -439,7 +672,7 @@ class DipyDTI(SimpleInterface): def _run_interface(self, runtime): from dipy.core.gradients import gradient_table_from_bvals_bvecs - from dipy.reconst.dti import TensorModel + from dipy.reconst.dti import TensorModel, color_fa, fractional_anisotropy from dipy.reconst.fwdti import FreeWaterTensorModel from nipype.utils.filemanip import fname_presuffix @@ -468,8 +701,14 @@ def _run_interface(self, runtime): ) # Extract the FA - fa_data = np.array(fwdtifit.fa, dtype='float32') - fa_data[np.isnan(fa_data)] = 0 + fa_data = fractional_anisotropy(fwdtifit.evals) + fa_nan_msk = np.isnan(fa_data) + fa_data[fa_nan_msk] = 0 + + # Round for stability + fa_data = np.round(fa_data, self.inputs.decimals) + degenerate_msk = (fa_data < 0) | (fa_data > 1.0) + # Clamp the FA to remove degenerate fa_data = np.clip(fa_data, 0, 1) fa_nii = nb.Nifti1Image( @@ -480,8 +719,6 @@ def _run_interface(self, runtime): fa_nii.header.set_xyzt_units('mm') fa_nii.header.set_intent('estimate', name='Fractional Anisotropy (FA)') - fa_nii.header['cal_max'] = 1.0 - fa_nii.header['cal_min'] = 0.0 self._results['out_fa'] = fname_presuffix( self.inputs.in_file, @@ -491,6 +728,66 @@ def _run_interface(self, runtime): fa_nii.to_filename(self._results['out_fa']) + # Write out degenerate and nans masks + fa_nan_nii = nb.Nifti1Image( + fa_nan_msk.astype(np.uint8), + img.affine, + None, + ) + + fa_nan_nii.header.set_xyzt_units('mm') + fa_nan_nii.header.set_intent('estimate', name='NaNs in the FA map mask') + fa_nan_nii.header['cal_max'] = 1 + fa_nan_nii.header['cal_min'] = 0 + + self._results['out_fa_nans'] = fname_presuffix( + self.inputs.in_file, + suffix='desc-fanans_mask', + newpath=runtime.cwd, + ) + fa_nan_nii.to_filename(self._results['out_fa_nans']) + + fa_degenerate_nii = nb.Nifti1Image( + degenerate_msk.astype(np.uint8), + img.affine, + None, + ) + + fa_degenerate_nii.header.set_xyzt_units('mm') + fa_degenerate_nii.header.set_intent( + 'estimate', + name='degenerate vectors in the FA map mask' + ) + fa_degenerate_nii.header['cal_max'] = 1 + fa_degenerate_nii.header['cal_min'] = 0 + + self._results['out_fa_degenerate'] = fname_presuffix( + self.inputs.in_file, + suffix='desc-fadegenerate_mask', + newpath=runtime.cwd, + ) + fa_degenerate_nii.to_filename(self._results['out_fa_degenerate']) + + # Extract the color FA + cfa_data = color_fa(fa_data, fwdtifit.evecs) + cfa_nii = nb.Nifti1Image( + np.clip(cfa_data, a_min=0.0, a_max=1.0), + img.affine, + None, + ) + + cfa_nii.header.set_xyzt_units('mm') + cfa_nii.header.set_intent('estimate', name='Fractional Anisotropy (FA)') + cfa_nii.header['cal_max'] = 1.0 + cfa_nii.header['cal_min'] = 0.0 + + self._results['out_cfa'] = fname_presuffix( + self.inputs.in_file, + suffix='cfa', + newpath=runtime.cwd, + ) + cfa_nii.to_filename(self._results['out_cfa']) + # Extract the AD self._results['out_md'] = fname_presuffix( self.inputs.in_file, @@ -511,6 +808,124 @@ def _run_interface(self, runtime): return runtime +class 
_CCSegmentationInputSpec(_BaseInterfaceInputSpec): + in_fa = File(exists=True, mandatory=True, desc='fractional anisotropy (FA) file') + in_cfa = File(exists=True, mandatory=True, desc='color FA file') + min_rgb = traits.Tuple((0.4, 0.008, 0.008), types=(traits.Float,) * 3, + usedefault=True, desc='minimum RGB within the CC') + max_rgb = traits.Tuple((1.1, 0.25, 0.25), types=(traits.Float,) * 3, + usedefault=True, desc='maximum RGB within the CC') + wm_threshold = traits.Float(0.35, usedefault=True, desc='WM segmentation threshold') + clean_mask = traits.Bool(False, usedefault=True, desc='run a final cleanup step on mask') + + +class _CCSegmentationOutputSpec(_TraitedSpec): + out_mask = File(exists=True, desc='output mask of the corpus callosum') + wm_mask = File(exists=True, desc='output mask of the white-matter (thresholded)') + wm_finalmask = File(exists=True, desc='output mask of the white-matter after binary opening') + + +class CCSegmentation(SimpleInterface): + """Computes :abbr:`QC (Quality Control)` measures on the input DWI EPI scan.""" + + input_spec = _CCSegmentationInputSpec + output_spec = _CCSegmentationOutputSpec + + def _run_interface(self, runtime): + from skimage.measure import label + + self._results['out_mask'] = fname_presuffix( + self.inputs.in_cfa, + suffix='ccmask', + newpath=runtime.cwd, + ) + self._results['wm_mask'] = fname_presuffix( + self.inputs.in_cfa, + suffix='wmmask', + newpath=runtime.cwd, + ) + self._results['wm_finalmask'] = fname_presuffix( + self.inputs.in_cfa, + suffix='wmfinalmask', + newpath=runtime.cwd, + ) + + fa_nii = nb.load(self.inputs.in_fa) + fa_data = np.round(fa_nii.get_fdata(dtype='float32'), 4) + fa_labels = label((fa_data > self.inputs.wm_threshold).astype(np.uint8)) + wm_mask = fa_labels == np.argmax(np.bincount(fa_labels.flat)[1:]) + 1 + + # Write out binary WM mask + wm_mask_nii = nb.Nifti1Image( + wm_mask.astype(np.uint8), + fa_nii.affine, + None, + ) + wm_mask_nii.header.set_xyzt_units('mm') + wm_mask_nii.header.set_intent('estimate', name='white-matter mask (FA thresholded)') + wm_mask_nii.header['cal_max'] = 1 + wm_mask_nii.header['cal_min'] = 0 + wm_mask_nii.to_filename(self._results['wm_mask']) + + # Massage FA with greyscale mathematical morphology + struct = nd.generate_binary_structure(wm_mask.ndim, wm_mask.ndim - 1) + # Perform a closing followed by opening operations on the FA. 
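# Greyscale closing fills dark gaps narrower than the 18-connected structuring
# element built above, and the subsequent opening removes small bright specks,
# so re-thresholding the smoothed FA yields a cleaner, better-connected WM mask.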
+ wm_mask = nd.grey_closing( + fa_data, + structure=struct, + ) + wm_mask = nd.grey_opening( + wm_mask, + structure=struct, + ) + + fa_labels = label(( + np.round(wm_mask, 4) > self.inputs.wm_threshold + ).astype(np.uint8)) + wm_mask = fa_labels == np.argmax(np.bincount(fa_labels.flat)[1:]) + 1 + + # Write out binary WM mask after binary opening + wm_mask_nii = nb.Nifti1Image( + wm_mask.astype(np.uint8), + fa_nii.affine, + wm_mask_nii.header, + ) + wm_mask_nii.header.set_intent('estimate', name='white-matter mask after binary opening') + wm_mask_nii.to_filename(self._results['wm_finalmask']) + + cfa_data = np.round( + nb.load(self.inputs.in_cfa).get_fdata(dtype='float32'), 4 + ) + for i in range(cfa_data.shape[-1]): + cfa_data[..., i] = nd.grey_closing( + cfa_data[..., i], + structure=struct, + ) + cfa_data[..., i] = nd.grey_opening( + cfa_data[..., i], + structure=struct, + ) + + cc_mask = segment_corpus_callosum( + in_cfa=cfa_data, + mask=wm_mask, + min_rgb=self.inputs.min_rgb, + max_rgb=self.inputs.max_rgb, + clean_mask=self.inputs.clean_mask, + ) + cc_mask_nii = nb.Nifti1Image( + cc_mask.astype(np.uint8), + fa_nii.affine, + None, + ) + cc_mask_nii.header.set_xyzt_units('mm') + cc_mask_nii.header.set_intent('estimate', name='corpus callosum mask') + cc_mask_nii.header['cal_max'] = 1 + cc_mask_nii.header['cal_min'] = 0 + cc_mask_nii.to_filename(self._results['out_mask']) + return runtime + + def _rms(estimator, X): """ Callable to pass to GridSearchCV that will calculate a distance score. @@ -528,20 +943,160 @@ def _rms(estimator, X): return -np.sqrt(distance**2).sum() -def _extract_b0(in_file, b0_ixs, out_path=None): - """Extract the *b0* volumes from a DWI dataset.""" - if out_path is None: - out_path = fname_presuffix(in_file, suffix='_b0') +def _exp_func(t, A, K, C): + return A * np.exp(K * t) + C + - img = nb.load(in_file) - bzeros = np.squeeze(np.asanyarray(img.dataobj)[..., b0_ixs]) +def segment_corpus_callosum( + in_cfa: np.ndarray, + mask: np.ndarray, + min_rgb: tuple[float, float, float] = (0.6, 0.0, 0.0), + max_rgb: tuple[float, float, float] = (1.0, 0.1, 0.1), + clean_mask: bool = False, +) -> tuple[np.ndarray, np.ndarray]: + """ + Segments the corpus callosum (CC) from a color FA map. + + Parameters + ---------- + in_cfa : :obj:`~numpy.ndarray` + The color FA (cFA) map. + mask : :obj:`~numpy.ndarray` (bool, 3D) + A white matter mask used to define the initial bounding box. + min_rgb : :obj:`tuple`, optional + Minimum RGB values. + max_rgb : :obj:`tuple`, optional + Maximum RGB values. + clean_mask : :obj:`bool`, optional + Whether the CC mask is finally cleaned-up for spurious off voxels with + :obj:`dipy.segment.mask.clean_cc_mask` + + Returns + ------- + cc_mask: :obj:`~numpy.ndarray` + The final binary mask of the segmented CC. + + Notes + ----- + This implementation was derived from + :obj:`dipy.segment.mask.segment_from_cfa`. 
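    A minimal call sketch (illustrative only; the file names are hypothetical and
    ``nb``/``np`` stand for nibabel/numpy)::

        cfa = nb.load('dwi_cfa.nii.gz').get_fdata(dtype='float32')
        wm_mask = np.asanyarray(nb.load('wm_mask.nii.gz').dataobj) > 0.5
        cc_mask = segment_corpus_callosum(in_cfa=cfa, mask=wm_mask, clean_mask=True)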
- hdr = img.header.copy() - hdr.set_data_shape(bzeros.shape) - hdr.set_xyzt_units('mm') - nb.Nifti1Image(bzeros, img.affine, hdr).to_filename(out_path) - return out_path + """ + from dipy.segment.mask import bounding_box + + # Prepare a bounding box of the CC + cc_box = np.zeros_like(mask, dtype=bool) + mins, maxs = bounding_box(mask) # mask needs to be volume + mins = np.array(mins) + maxs = np.array(maxs) + diff = (maxs - mins) // 5 + bounds_min = mins + diff + bounds_max = maxs - diff + cc_box[ + bounds_min[0]:bounds_max[0], + bounds_min[1]:bounds_max[1], + bounds_min[2]:bounds_max[2] + ] = True + + min_rgb = np.array(min_rgb) + max_rgb = np.array(max_rgb) + + # Threshold color FA + cc_mask = np.all( + (in_cfa >= min_rgb[None, :]) & (in_cfa <= max_rgb[None, :]), + axis=-1, + ) + # Apply bounding box and WM mask + cc_mask *= (cc_box & mask) -def _exp_func(t, A, K, C): - return A * np.exp(K * t) + C + struct = nd.generate_binary_structure(cc_mask.ndim, cc_mask.ndim - 1) + # Perform a closing followed by opening operations on the FA. + cc_mask = nd.binary_closing( + cc_mask, + structure=struct, + ) + cc_mask = nd.binary_opening( + cc_mask, + structure=struct, + ) + + if clean_mask: + from dipy.segment.mask import clean_cc_mask + + cc_mask = clean_cc_mask(cc_mask) + return cc_mask + + +def get_spike_mask( + data: np.ndarray, + z_threshold: float = 3.0, + grouping_vals: np.ndarray | None = None, + bmag: int | None = None, +) -> np.ndarray: + """ + Creates a binary mask classifying voxels in the data array as spike or non-spike. + + This function identifies voxels with signal intensities exceeding a threshold based + on standard deviations above the mean. The threshold can be applied globally to + the entire data array, or it can be calculated for groups of voxels defined by + the ``grouping_vals`` parameter. + + Parameters + ---------- + data : :obj:`~numpy.ndarray` + The data array to be thresholded. + z_threshold : :obj:`float`, optional (default=3.0) + The number of standard deviations to use above the mean as the threshold + multiplier. + grouping_vals : :obj:`~numpy.ndarray`, optional + If provided, this array is used to group voxels for thresholding. Voxels + with the same value in ``grouping_vals`` are considered to belong to the same + group. The threshold will be calculated independently for each group. + - If ``grouping_vals`` has the same shape as ``data`` (4D), it is assumed to be + a mask where each voxel value indicates the group it belongs to. + - If ``grouping_vals`` has a 3D shape, it is assumed to represent b-values + corresponding to each voxel in the 4D ``data`` array. In this case, voxels + with the same b-value are grouped together. + bmag : int, optional + The order of magnitude for b-value rounding (used only if + ``grouping_vals`` is provided as b-values). Default: None (derived from max b-value). + + Returns: + ------- + spike_mask : :obj:`~numpy.ndarray` + A binary mask where ``True`` values indicate voxels classified as spikes and + ``False`` values indicate non-spikes. The mask has the same shape as the input + data array. 
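    Examples
    --------
    A rough sketch of the global branch (no ``grouping_vals``); the implementation
    below additionally rounds the data and the threshold to three decimals::

        threshold = data.mean() + z_threshold * data.std()
        spike_mask = data > threshold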
+ + """ + from dipy.core.gradients import round_bvals, unique_bvals_magnitude + + if grouping_vals is None: + threshold = np.round((z_threshold * np.std(data)) + np.mean(data), 3) + spike_mask = np.round(data, 3) > threshold + return spike_mask + + threshold_mask = np.zeros(data.shape) + + rounded_grouping_vals = round_bvals(grouping_vals, bmag) + gvals = unique_bvals_magnitude(grouping_vals, bmag) + + if grouping_vals.shape == data.shape: + for gval in gvals: + gval_data = data[rounded_grouping_vals == gval] + gval_threshold = ((z_threshold * np.std(gval_data)) + + np.mean(gval_data)) + threshold_mask[rounded_grouping_vals == gval] = ( + gval_threshold * np.ones(gval_data.shape)) + else: + for gval in gvals: + gval_data = data[..., rounded_grouping_vals == gval] + gval_threshold = ((z_threshold * np.std(gval_data)) + + np.mean(gval_data)) + threshold_mask[..., rounded_grouping_vals == gval] = ( + gval_threshold * np.ones(gval_data.shape)) + + spike_mask = data > threshold_mask + + return spike_mask diff --git a/mriqc/qc/anatomical.py b/mriqc/qc/anatomical.py index 06afa742b..f613d2432 100644 --- a/mriqc/qc/anatomical.py +++ b/mriqc/qc/anatomical.py @@ -612,7 +612,7 @@ def summary_stats( for label, probmap in pvms.items(): wstats = DescrStatsW( data=np.round(data.reshape(-1), rprec_data), - weights=np.round(probmap.reshape(-1), rprec_prob), + weights=np.round(probmap.astype(np.float32).reshape(-1), rprec_prob), ) nvox = probmap.sum() p05, median, p95 = wstats.quantile(np.array([0.05, 0.50, 0.95]), return_pandas=False) diff --git a/mriqc/qc/diffusion.py b/mriqc/qc/diffusion.py index c787e6cd2..a1b4e7ee1 100644 --- a/mriqc/qc/diffusion.py +++ b/mriqc/qc/diffusion.py @@ -215,35 +215,41 @@ def cc_snr( * The second element is the best-case SNR (float). """ + cc_mask = cc_mask > 0 # Ensure it's a boolean mask std_signal = in_b0[cc_mask].std() cc_snr_estimates = {} + xyz = np.eye(3) + + b_values = np.rint(b_values).astype(np.uint16) + # Shell-wise calculation for bval, bvecs, shell_data in zip(b_values, b_vectors, dwi_shells): if bval == 0: - cc_snr_estimates[int(bval)] = in_b0[cc_mask].mean() / std_signal + cc_snr_estimates[f'b{bval:d}'] = in_b0[cc_mask].mean() / std_signal continue + shell_data = shell_data[cc_mask] + # Find main directions of diffusion axis_X = np.argmin(np.sum( - (bvecs - np.array([1, 0, 0])) ** 2, axis=-1)) + (bvecs - xyz[0, :]) ** 2, axis=-1)) axis_Y = np.argmin(np.sum( - (bvecs - np.array([0, 1, 0])) ** 2, axis=-1)) + (bvecs - xyz[1, :]) ** 2, axis=-1)) axis_Z = np.argmin(np.sum( - (bvecs - np.array([0, 0, 1])) ** 2, axis=-1)) + (bvecs - xyz[2, :]) ** 2, axis=-1)) data_X = shell_data[..., axis_X] data_Y = shell_data[..., axis_Y] data_Z = shell_data[..., axis_Z] - mean_signal_X = np.mean(data_X[cc_mask]) - mean_signal_Y = np.mean(data_Y[cc_mask]) - mean_signal_Z = np.mean(data_Z[cc_mask]) + mean_signal_worst = np.mean(data_X) + mean_signal_best = 0.5 * (np.mean(data_Y) + np.mean(data_Z)) - cc_snr_estimates[int(bval)] = ( - np.mean(mean_signal_X / std_signal), # worst - np.mean(np.mean(mean_signal_Y, mean_signal_Z) / std_signal), # best + cc_snr_estimates[f'b{bval:d}'] = ( + np.mean(mean_signal_worst / std_signal), + np.mean(mean_signal_best / std_signal), ) return cc_snr_estimates diff --git a/mriqc/workflows/diffusion/base.py b/mriqc/workflows/diffusion/base.py index 74e4a67ce..8ae81d11b 100644 --- a/mriqc/workflows/diffusion/base.py +++ b/mriqc/workflows/diffusion/base.py @@ -68,11 +68,13 @@ def dmri_qc_workflow(name='dwiMRIQC'): from nipype.interfaces.afni import Volreg 
from nipype.interfaces.mrtrix3.preprocess import DWIDenoise from niworkflows.interfaces.header import SanitizeImage + from niworkflows.interfaces.images import RobustAverage from mriqc.interfaces.diffusion import ( + CCSegmentation, CorrectSignalDrift, DipyDTI, - ExtractB0, + ExtractOrientations, FilterShells, NumberOfShells, ReadDWIMetadata, @@ -123,16 +125,14 @@ def dmri_qc_workflow(name='dwiMRIQC'): # Workflow -------------------------------------------------------- - # 1. Read metadata & bvec/bval, estimate number of shells, extract and split B0s + # Read metadata & bvec/bval, estimate number of shells, extract and split B0s meta = pe.Node(ReadDWIMetadata(index_db=config.execution.bids_database_dir), name='metadata') shells = pe.Node(NumberOfShells(), name='shells') - get_shells = pe.MapNode(ExtractB0(), name='get_shells', iterfield=['b0_ixs']) - hmc_shells = pe.MapNode( - Volreg(args='-Fourier -twopass', zpad=4, outputtype='NIFTI_GZ'), - name='hmc_shells', - mem_gb=mem_gb * 2.5, - iterfield=['in_file'], - ) + drift = pe.Node(CorrectSignalDrift(), name='drift') + get_lowb = pe.Node(ExtractOrientations(), name='get_lowb') + + # Generate B0 reference + dwi_ref = pe.Node(RobustAverage(mc_method=None), name='dwi_ref') hmc_b0 = pe.Node( Volreg(args='-Fourier -twopass', zpad=4, outputtype='NIFTI_GZ'), @@ -140,18 +140,27 @@ def dmri_qc_workflow(name='dwiMRIQC'): mem_gb=mem_gb * 2.5, ) - drift = pe.Node(CorrectSignalDrift(), name='drift') - - # 2. Generate B0 reference - dwi_reference_wf = init_dmriref_wf(name='dwi_reference_wf') + # Shell-wise hmc not functional (yet?) + # hmc_shells = pe.MapNode( + # Volreg(args='-Fourier -twopass', zpad=4, outputtype='NIFTI_GZ'), + # name='hmc_shells', + # mem_gb=mem_gb * 2.5, + # iterfield=['in_file'], + # ) - # 3. Calculate brainmask + # Calculate brainmask dmri_bmsk = dmri_bmsk_workflow(omp_nthreads=config.nipype.omp_nthreads) - # 4. HMC: head motion correct + # HMC: head motion correct hmcwf = hmc_workflow() - # 5. Split shells and compute some stats + get_hmc_shells = pe.MapNode( + ExtractOrientations(), + name='get_hmc_shells', + iterfield=['indices'], + ) + + # Split shells and compute some stats averages = pe.MapNode( WeightedStat(), name='averages', @@ -165,7 +174,7 @@ def dmri_qc_workflow(name='dwiMRIQC'): iterfield=['in_weights'], ) - # 6. Fit DTI model + # Fit DTI model dti_filter = pe.Node(FilterShells(), name='dti_filter') dwidenoise = pe.Node( DWIDenoise( @@ -180,13 +189,16 @@ def dmri_qc_workflow(name='dwiMRIQC'): name='dti', ) - # 7. EPI to MNI registration - ema = epi_mni_align() + # Calculate CC mask + cc_mask = pe.Node(CCSegmentation(), name="cc_mask") + + # EPI to MNI registration + spatial_norm = epi_mni_align() - # 8. Compute IQMs - iqmswf = compute_iqms() + # Compute IQMs + iqms_wf = compute_iqms() - # 9. 
Generate outputs + # Generate outputs dwi_report_wf = init_dwi_report_wf() # fmt: off @@ -196,24 +208,21 @@ def dmri_qc_workflow(name='dwiMRIQC'): (inputnode, dwi_report_wf, [ ('in_file', 'inputnode.name_source'), ]), - (datalad_get, iqmswf, [('in_file', 'inputnode.in_file')]), + (datalad_get, iqms_wf, [('in_file', 'inputnode.in_file')]), (datalad_get, sanitize, [('in_file', 'in_file')]), - (sanitize, dwi_reference_wf, [('out_file', 'inputnode.in_file')]), - (shells, dwi_reference_wf, [(('b_masks', _first), 'inputnode.t_mask')]), + (sanitize, dwi_ref, [('out_file', 'in_file')]), + (shells, dwi_ref, [(('b_masks', _first), 't_mask')]), (meta, shells, [('out_bval_file', 'in_bvals')]), (sanitize, drift, [('out_file', 'full_epi')]), - (shells, get_shells, [('b_indices', 'b0_ixs')]), - (sanitize, get_shells, [('out_file', 'in_file')]), + (shells, get_lowb, [(('b_indices', _first), 'indices')]), + (sanitize, get_lowb, [('out_file', 'in_file')]), (meta, drift, [('out_bval_file', 'bval_file')]), - (get_shells, hmc_shells, [(('out_file', _all_but_first), 'in_file')]), - (get_shells, hmc_b0, [(('out_file', _first), 'in_file')]), - (dwi_reference_wf, hmc_b0, [('outputnode.ref_file', 'basefile')]), + (get_lowb, hmc_b0, [('out_file', 'in_file')]), + (dwi_ref, hmc_b0, [('out_file', 'basefile')]), (hmc_b0, drift, [('out_file', 'in_file')]), (shells, drift, [(('b_indices', _first), 'b0_ixs')]), - (dwi_reference_wf, dmri_bmsk, [('outputnode.ref_file', 'inputnode.in_files')]), - (dwi_reference_wf, ema, [('outputnode.ref_file', 'inputnode.epi_mean')]), + (dwi_ref, dmri_bmsk, [('out_file', 'inputnode.in_files')]), (dmri_bmsk, drift, [('outputnode.out_mask', 'brainmask_file')]), - (dmri_bmsk, ema, [('outputnode.out_mask', 'inputnode.epi_mask')]), (drift, hmcwf, [('out_full_file', 'inputnode.in_file')]), (drift, averages, [('out_full_file', 'in_file')]), (drift, stddev, [('out_full_file', 'in_file')]), @@ -229,19 +238,37 @@ def dmri_qc_workflow(name='dwiMRIQC'): (dmri_bmsk, dwidenoise, [('outputnode.out_mask', 'mask')]), (dwidenoise, dti, [('out_file', 'in_file')]), (dmri_bmsk, dti, [('outputnode.out_mask', 'brainmask')]), - (hmcwf, outputnode, [('outputnode.out_fd', 'out_fd')]), - (shells, iqmswf, [('n_shells', 'inputnode.n_shells'), - ('b_values', 'inputnode.b_values')]), + (meta, get_hmc_shells, [('out_bvec_file', 'in_bvec_file')]), + (shells, get_hmc_shells, [('b_indices', 'indices')]), + (hmcwf, get_hmc_shells, [('outputnode.out_file', 'in_file')]), + (dti, cc_mask, [('out_fa', 'in_fa'), + ('out_cfa', 'in_cfa')]), + (averages, iqms_wf, [(('out_file', _first), 'inputnode.in_b0')]), + (hmcwf, iqms_wf, [('outputnode.out_fd', 'inputnode.framewise_displacement')]), + (dti, iqms_wf, [('out_fa', 'inputnode.in_fa'), + ('out_cfa', 'inputnode.in_cfa'), + ('out_fa_nans', 'inputnode.in_fa_nans'), + ('out_fa_degenerate', 'inputnode.in_fa_degenerate'), + ('out_md', 'inputnode.in_md')]), + (dmri_bmsk, iqms_wf, [('outputnode.out_mask', 'inputnode.brain_mask')]), + (cc_mask, iqms_wf, [('out_mask', 'inputnode.cc_mask'), + ('wm_finalmask', 'inputnode.wm_mask')]), + (shells, iqms_wf, [('n_shells', 'inputnode.n_shells'), + ('b_values', 'inputnode.b_values')]), + (get_hmc_shells, iqms_wf, [('out_file', 'inputnode.in_shells'), + ('out_bvec', 'inputnode.in_bvec')]), + (dwi_ref, spatial_norm, [('out_file', 'inputnode.epi_mean')]), + (dmri_bmsk, spatial_norm, [('outputnode.out_mask', 'inputnode.epi_mask')]), (dwidenoise, dwi_report_wf, [('noise', 'inputnode.in_noise')]), (shells, dwi_report_wf, [('b_dict', 'inputnode.in_bdict')]), - 
(dmri_bmsk, dwi_report_wf, [('outputnode.out_mask', 'inputnode.brainmask')]), + (dmri_bmsk, dwi_report_wf, [('outputnode.out_mask', 'inputnode.brain_mask')]), (shells, dwi_report_wf, [('b_values', 'inputnode.in_shells')]), (averages, dwi_report_wf, [('out_file', 'inputnode.in_avgmap')]), (stddev, dwi_report_wf, [('out_file', 'inputnode.in_stdmap')]), (drift, dwi_report_wf, [('out_full_file', 'inputnode.in_epi')]), (dti, dwi_report_wf, [('out_fa', 'inputnode.in_fa'), ('out_md', 'inputnode.in_md')]), - (ema, dwi_report_wf, [('outputnode.epi_parc', 'inputnode.in_parcellation')]), + (spatial_norm, dwi_report_wf, [('outputnode.epi_parc', 'inputnode.in_parcellation')]), ]) # fmt: on return workflow @@ -263,6 +290,7 @@ def compute_iqms(name='ComputeIQMs'): from mriqc.interfaces import IQMFileSink from mriqc.interfaces.reports import AddProvenance + from mriqc.interfaces.diffusion import DiffusionQC # from mriqc.workflows.utils import _tofloat, get_fwhmx # mem_gb = config.workflow.biggest_file_gb @@ -271,9 +299,21 @@ def compute_iqms(name='ComputeIQMs'): inputnode = pe.Node( niu.IdentityInterface( fields=[ - 'in_file', 'n_shells', 'b_values', + 'in_file', + 'in_shells', + 'in_bvec', + 'in_b0', + 'in_fa', + 'in_cfa', + 'in_fa_nans', + 'in_fa_degenerate', + 'in_md', + 'brain_mask', + 'wm_mask', + 'cc_mask', + 'framewise_displacement', ] ), name='inputnode', @@ -290,6 +330,8 @@ def compute_iqms(name='ComputeIQMs'): meta = pe.Node(ReadSidecarJSON(index_db=config.execution.bids_database_dir), name='metadata') + measures = pe.Node(DiffusionQC(), name='measures') + addprov = pe.Node( AddProvenance(modality='dwi'), name='provenance', @@ -313,6 +355,20 @@ def compute_iqms(name='ComputeIQMs'): ('n_shells', 'NumberOfShells'), ('b_values', 'b-values')]), (inputnode, meta, [('in_file', 'in_file')]), + (inputnode, measures, [('in_file', 'in_file'), + ('b_values', 'in_bval'), + ('in_shells', 'in_shells'), + ('in_bvec', 'in_bvec'), + ('in_b0', 'in_b0'), + ('brain_mask', 'brain_mask'), + ('wm_mask', 'wm_mask'), + ('cc_mask', 'cc_mask'), + ('in_fa', 'in_fa'), + ('in_md', 'in_md'), + ('in_cfa', 'in_cfa'), + ('in_fa_nans', 'in_fa_nans'), + ('in_fa_degenerate', 'in_fa_degenerate'), + ('framewise_displacement', 'in_fd')]), (inputnode, addprov, [('in_file', 'in_file')]), (addprov, datasink, [('out_prov', 'provenance')]), (meta, datasink, [('subject', 'subject_id'), @@ -324,101 +380,9 @@ def compute_iqms(name='ComputeIQMs'): ('out_dict', 'metadata')]), (datasink, outputnode, [('out_file', 'out_file')]), (meta, outputnode, [('out_dict', 'meta_sidecar')]), + (measures, datasink, [("out_qc", "root")]), ]) # fmt: on - - # Set FD threshold - # inputnode.inputs.fd_thres = config.workflow.fd_thres - - # # AFNI quality measures - # fwhm_interface = get_fwhmx() - # fwhm = pe.Node(fwhm_interface, name="smoothness") - # # fwhm.inputs.acf = True # add when AFNI >= 16 - # measures = pe.Node(FunctionalQC(), name="measures", mem_gb=mem_gb * 3) - - # # fmt: off - # workflow.connect([ - # (inputnode, measures, [("epi_mean", "in_epi"), - # ("brainmask", "in_mask"), - # ("hmc_epi", "in_hmc"), - # ("hmc_fd", "in_fd"), - # ("fd_thres", "fd_thres"), - # ("in_tsnr", "in_tsnr")]), - # (inputnode, fwhm, [("epi_mean", "in_file"), - # ("brainmask", "mask")]), - # (fwhm, measures, [(("fwhm", _tofloat), "in_fwhm")]), - # (measures, datasink, [("out_qc", "root")]), - # ]) - # # fmt: on - return workflow - - -def init_dmriref_wf( - in_file=None, - name='init_dmriref_wf', -): - """ - Build a workflow that generates reference images for a dMRI series. 
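# Illustrative sketch (not part of the patch): DiffusionQC's ``out_qc`` is a nested
# dictionary that is flattened before reaching IQMFileSink's ``root`` input in
# compute_iqms above. Assuming mriqc.utils.misc._flatten_dict joins nested keys
# with underscores, a result such as
#     {'fber': {'b0': 1234.5, 'b1000': 987.6}, 'fd': {'mean': 0.12}}
# would be serialized as
#     {'fber_b0': 1234.5, 'fber_b1000': 987.6, 'fd_mean': 0.12}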
- - The raw reference image is the target of :abbr:`HMC (head motion correction)`, and a - contrast-enhanced reference is the subject of distortion correction, as well as - boundary-based registration to T1w and template spaces. - - This workflow assumes only one dMRI file has been passed. - - Workflow Graph - .. workflow:: - :graph2use: orig - :simple_form: yes - - from mriqc.workflows.diffusion.base import init_dmriref_wf - wf = init_dmriref_wf() - - Parameters - ---------- - in_file : :obj:`str` - dMRI series NIfTI file - ------ - in_file : str - series NIfTI file - - Outputs - ------- - in_file : str - Validated DWI series NIfTI file - ref_file : str - Reference image to which DWI series is motion corrected - """ - from niworkflows.interfaces.header import ValidateImage - from niworkflows.interfaces.images import RobustAverage - - workflow = pe.Workflow(name=name) - inputnode = pe.Node(niu.IdentityInterface(fields=['in_file', 't_mask']), name='inputnode') - outputnode = pe.Node( - niu.IdentityInterface(fields=['in_file', 'ref_file', 'validation_report']), - name='outputnode', - ) - - # Simplify manually setting input image - if in_file is not None: - inputnode.inputs.in_file = in_file - - val_bold = pe.Node( - ValidateImage(), - name='val_bold', - mem_gb=DEFAULT_MEMORY_MIN_GB, - ) - - gen_avg = pe.Node(RobustAverage(mc_method=None), name='gen_avg', mem_gb=1) - # fmt: off - workflow.connect([ - (inputnode, val_bold, [('in_file', 'in_file')]), - (inputnode, gen_avg, [('t_mask', 't_mask')]), - (val_bold, gen_avg, [('out_file', 'in_file')]), - (gen_avg, outputnode, [('out_file', 'ref_file')]), - ]) - # fmt: on - return workflow diff --git a/mriqc/workflows/diffusion/output.py b/mriqc/workflows/diffusion/output.py index 078494a72..de4062c46 100644 --- a/mriqc/workflows/diffusion/output.py +++ b/mriqc/workflows/diffusion/output.py @@ -58,7 +58,8 @@ def init_dwi_report_wf(name='dwi_report_wf'): niu.IdentityInterface( fields=[ 'in_epi', - 'brainmask', + 'brain_mask', + 'cc_mask', 'in_avgmap', 'in_stdmap', 'in_shells', @@ -167,12 +168,12 @@ def _gen_entity(inlist): workflow.connect([ (inputnode, mosaic_snr, [('in_avgmap', 'before'), ('in_stdmap', 'after'), - ('brainmask', 'wm_seg')]), + ('brain_mask', 'wm_seg')]), (inputnode, mosaic_noise, [('in_avgmap', 'in_file')]), (inputnode, mosaic_fa, [('in_fa', 'in_file'), - ('brainmask', 'bbox_mask_file')]), + ('brain_mask', 'bbox_mask_file')]), (inputnode, mosaic_md, [('in_md', 'in_file'), - ('brainmask', 'bbox_mask_file')]), + ('brain_mask', 'bbox_mask_file')]), (inputnode, ds_report_snr, [('name_source', 'source_file'), (('in_shells', _gen_entity), 'bval')]), (inputnode, ds_report_noise, [('name_source', 'source_file'), @@ -209,7 +210,7 @@ def _gen_entity(inlist): ('in_bdict', 'b_indices')]), (inputnode, ds_report_hm, [('name_source', 'source_file')]), (inputnode, estimate_sigma, [('in_noise', 'in_file'), - ('brainmask', 'mask')]), + ('brain_mask', 'mask')]), (estimate_sigma, plot_heatmap, [('out', 'sigma')]), (get_wm, plot_heatmap, [('out', 'mask_file')]), (plot_heatmap, ds_report_hm, [('out_file', 'in_file')]), @@ -248,8 +249,8 @@ def _gen_entity(inlist): ('outliers', 'outliers'), (('meta_sidecar', _get_tr), 'tr')]), (inputnode, parcels, [('epi_parc', 'segmentation')]), - (inputnode, dilated_mask, [('brainmask', 'in_mask')]), - (inputnode, subtract_mask, [('brainmask', 'in_subtract')]), + (inputnode, dilated_mask, [('brain_mask', 'in_mask')]), + (inputnode, subtract_mask, [('brain_mask', 'in_subtract')]), (dilated_mask, subtract_mask, [('out_mask', 
'in_base')]), (subtract_mask, parcels, [('out_mask', 'crown_mask')]), (parcels, bigplot, [('out', 'in_segm')]), @@ -357,9 +358,9 @@ def _gen_entity(inlist): (inputnode, ds_report_norm, [('mni_report', 'in_file'), ('name_source', 'source_file')]), (inputnode, plot_bmask, [('epi_mean', 'in_file'), - ('brainmask', 'in_contours')]), + ('brain_mask', 'in_contours')]), (inputnode, mosaic_zoom, [('epi_mean', 'in_file'), - ('brainmask', 'bbox_mask_file')]), + ('brain_mask', 'bbox_mask_file')]), (inputnode, mosaic_noise, [('epi_mean', 'in_file')]), (inputnode, ds_report_zoomed, [('name_source', 'source_file')]), (inputnode, ds_report_background, [('name_source', 'source_file')]),
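# Illustrative sketch mirroring the FD summary computed by DiffusionQC above: the
# head-motion IQMs are the mean framewise displacement plus the count and
# percentage of orientations exceeding ``fd_thres``. The FD file name below is
# hypothetical; 0.2 is the module-level FD_THRESHOLD default defined in the patch.
import numpy as np

fd_data = np.loadtxt('dwi_fd.txt', skiprows=1)   # one header row, as in the interface
fd_thres = 0.2                                   # FD_THRESHOLD
num_fd = int((fd_data > fd_thres).sum())
fd_summary = {
    'mean': float(fd_data.mean()),
    'num': num_fd,
    'perc': float(num_fd * 100 / (len(fd_data) + 1)),
}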