Skip to content

Commit

Permalink
Merge branch 'start_on_diffusion' of github.com:arokem/mriqc into start_on_diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
arokem committed Feb 7, 2024
2 parents e1c97dc + 25f0962 commit 68cb9b0
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 29 deletions.
175 changes: 159 additions & 16 deletions mriqc/qc/diffusion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
Expand All @@ -21,9 +22,6 @@
# https://www.nipreps.org/community/licensing/
#

import numpy as np


"""
Image quality metrics for diffusion MRI data
============================================
Expand All @@ -34,16 +32,14 @@
from dipy.core.gradients import GradientTable
from dipy.reconst.dti import TensorModel
from dipy.denoise.noise_estimate import piesno

def get_spike_mask(data, z_threshold=3):
"""
Return binary mask of spike/no spike
from dipy.core.gradients import unique_bvals_magnitude
from dipy.core.gradients import round_bvals
from dipy.segment.mask import segment_from_cfa
from dipy.segment.mask import bounding_box

def noise_func(img, gtab):
    """Placeholder for a generic dMRI noise estimate; currently a no-op stub."""
    pass

def noise_b0(data, gtab, mask=None):
"""
Estimate noise in raw dMRI based on b0 variance.
Expand Down Expand Up @@ -80,7 +76,7 @@ def noise_piesno(data, n_channels=4):
return sigma, mask


def cc_snr(data, gtab):
def cc_snr(data, gtab, bmag=None, mask=None):
"""
Calculate worse-/best-case signal-to-noise ratio in the corpus callosum
Expand All @@ -90,21 +86,31 @@ def cc_snr(data, gtab):
gtab : GradientTable class instance or tuple
bmag : int
From dipy.core.gradients:
The order of magnitude that the bvalues have to differ to be
considered an unique b-value. B-values are also rounded up to
this order of magnitude. Default: derive this value from the
maximal b-value provided: $bmag=log_{10}(max(bvals)) - 1$.
mask : numpy array
Boolean brain mask
"""
if isinstance(gtab, GradientTable):
pass

# XXX Per-shell calculation
if mask is None:
mask = np.ones(data.shape[:3])

tenmodel = TensorModel(gtab)
tensorfit = tenmodel.fit(data, mask=mask)

from dipy.segment.mask import segment_from_cfa
from dipy.segment.mask import bounding_box

threshold = (0.6, 1, 0, 0.1, 0, 0.1)
CC_box = np.zeros_like(data[..., 0])

mins, maxs = bounding_box(mask)
mins, maxs = bounding_box(mask) #mask needs to be volume
mins = np.array(mins)
maxs = np.array(maxs)
diff = (maxs - mins) // 4
Expand All @@ -118,4 +124,141 @@ def cc_snr(data, gtab):
mask_cc_part, cfa = segment_from_cfa(tensorfit, CC_box, threshold,
return_cfa=True)

mean_signal = np.mean(data[mask_cc_part], axis=0)
b0_data = data[..., gtab.b0s_mask]
std_signal = np.std(b0_data[mask_cc_part], axis=-1)

# Per-shell calculation
rounded_bvals = round_bvals(gtab.bvals, bmag)
bvals = unique_bvals_magnitude(gtab.bvals, bmag)

cc_snr_best = np.zeros(gtab.bvals.shape)
cc_snr_worst = np.zeros(gtab.bvals.shape)

for ind, bval in enumerate(bvals):
if bval == 0:
mean_signal = np.mean(data[..., rounded_bvals == 0], axis=-1)
cc_snr_worst[ind] = np.mean(mean_signal/std_signal)
cc_snr_best[ind] = np.mean(mean_signal/std_signal)
continue

bval_data = data[..., rounded_bvals == bval]
bval_bvecs = gtab.bvecs[rounded_bvals == bval]

axis_X = np.argmin(np.sum((bval_bvecs-np.array([1, 0, 0]))**2, axis=-1))
axis_Y = np.argmin(np.sum((bval_bvecs-np.array([0, 1, 0]))**2, axis=-1))
axis_Z = np.argmin(np.sum((bval_bvecs-np.array([0, 0, 1]))**2, axis=-1))

data_X = bval_data[..., axis_X]
data_Y = bval_data[..., axis_Y]
data_Z = bval_data[..., axis_Z]

mean_signal_X = np.mean(data_X[mask_cc_part])
mean_signal_Y = np.mean(data_Y[mask_cc_part])
mean_signal_Z = np.mean(data_Z[mask_cc_part])

cc_snr_worst[ind] = np.mean(mean_signal_X/std_signal)
cc_snr_best[ind] = np.mean(np.mean(mean_signal_Y, mean_signal_Z)/std_signal)

return cc_snr_worst, cc_snr_best


def get_spike_mask(data, z_threshold=3, grouping_vals=None, bmag=None):
    """
    Return binary mask of spike/no spike

    Parameters
    ----------
    data : numpy array
        Data to be thresholded
    z_threshold : :obj:`float`
        Number of standard deviations above the mean to use as spike threshold
    grouping_vals : numpy array
        Values by which to group data for thresholding (bvals or full mask)
    bmag : int
        From dipy.core.gradients:
        The order of magnitude that the bvalues have to differ to be
        considered an unique b-value. B-values are also rounded up to
        this order of magnitude. Default: derive this value from the
        maximal b-value provided: $bmag=log_{10}(max(bvals)) - 1$.

    Returns
    -------
    numpy array
        Boolean array with the same shape as ``data``; True where a value
        exceeds its (global or per-group) spike threshold.
    """
    # No grouping: one global threshold over the whole array.
    if grouping_vals is None:
        threshold = (z_threshold * np.std(data)) + np.mean(data)
        return data > threshold

    rounded_grouping_vals = round_bvals(grouping_vals, bmag)
    gvals = unique_bvals_magnitude(grouping_vals, bmag)

    threshold_mask = np.zeros(data.shape)

    if grouping_vals.shape == data.shape:
        # grouping_vals is a voxel-wise label volume: group voxels directly.
        for gval in gvals:
            gval_data = data[rounded_grouping_vals == gval]
            gval_threshold = (z_threshold * np.std(gval_data)) + np.mean(gval_data)
            # Scalar broadcast; no need to allocate a ones-array per group.
            threshold_mask[rounded_grouping_vals == gval] = gval_threshold
    else:
        # grouping_vals are per-volume b-values: group along the last axis.
        for gval in gvals:
            gval_data = data[..., rounded_grouping_vals == gval]
            gval_threshold = (z_threshold * np.std(gval_data)) + np.mean(gval_data)
            threshold_mask[..., rounded_grouping_vals == gval] = gval_threshold

    return data > threshold_mask


def get_slice_spike_percentage(data, z_threshold=3, slice_threshold=.05):
    """
    Return percentage of slices spiking along each dimension

    Parameters
    ----------
    data : numpy array
        Data to be thresholded
    z_threshold : :obj:`float`
        Number of standard deviations above the mean to use as spike threshold
    slice_threshold : :obj:`float`
        Percentage of slice elements that need to be above spike threshold for slice to be considered spiking

    Returns
    ---------
    array
    """
    spike_mask = get_spike_mask(data, z_threshold)

    # For each axis, a slice "spikes" when its fraction of flagged voxels
    # exceeds slice_threshold; report the fraction of such slices per axis.
    return np.array([
        np.mean(np.mean(spike_mask, axis) > slice_threshold)
        for axis in range(data.ndim)
    ])


def get_global_spike_percentage(data, z_threshold=3):
    """
    Return percentage of array elements spiking

    Parameters
    ----------
    data : numpy array
        Data to be thresholded
    z_threshold : :obj:`float`
        Number of standard deviations above the mean to use as spike threshold

    Returns
    ---------
    float
    """
    # The mean of a flattened boolean mask is the fraction of True entries.
    spikes = get_spike_mask(data, z_threshold)
    return np.mean(spikes.ravel())

def noise_func_for_shelled_data(shelled_data, gtab):
    """Placeholder for a per-shell noise estimate; currently a no-op stub."""
    pass
21 changes: 8 additions & 13 deletions mriqc/qc/tests/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
# https://www.nipreps.org/community/licensing/
#
import pytest
import os.path as op
import numpy as np
import nibabel as nib
from dipy.core.gradients import gradient_table
from dipy.data.fetcher import fetch_sherbrooke_3shell
from dipy.core.gradients import unique_bvals_magnitude, round_bvals
import os.path as op
from ..diffusion import noise_func, get_spike_mask, get_slice_spike_percentage, get_global_spike_percentage
from ..diffusion import noise_b0, noise_piesno


class DiffusionData(object):
def get_data(self):
Expand Down Expand Up @@ -61,6 +60,12 @@ def shelled_data(self):
def ddata():
return DiffusionData()


def test_noise_function(ddata):
    # Smoke test: noise_func is currently a stub, so this only checks it runs.
    # NOTE(review): other tests call ddata.get_data(); confirm the fixture
    # class also provides get_fdata(), otherwise this raises AttributeError.
    img, gtab = ddata.get_fdata()
    noise_func(img, gtab)


def test_get_spike_mask(ddata):
img, gtab = ddata.get_fdata()
spike_mask = get_spike_mask(img, 2)
Expand Down Expand Up @@ -89,14 +94,4 @@ def test_get_global_spike_percentage(ddata):

def test_with_shelled_data(ddata):
    # Smoke test: the shelled-data noise function is a stub; checks it runs.
    shelled_data, gtab = ddata.shelled_data()
    noise_func_for_shelled_data(shelled_data, gtab)


def test_noise_b0(ddata):
    # Smoke test: b0-variance noise estimate should run without raising.
    data, gtab = ddata.get_data()
    noise_b0(data, gtab)


def test_noise_piesno(ddata):
    # Smoke test: gtab is unpacked but unused — noise_piesno needs only data.
    data, gtab = ddata.get_data()
    noise_piesno(data)
noise_func_for_shelled_data(shelled_data, gtab)

0 comments on commit 68cb9b0

Please sign in to comment.