Skip to content

Commit

Permalink
fix: dwi spikes iqm test
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Apr 1, 2024
1 parent dddaea0 commit c61831b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 87 deletions.
2 changes: 1 addition & 1 deletion mriqc/qc/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,4 +302,4 @@ def spike_percentage(
for axis in range(data.ndim)
]

return {'spike_perc_global': spike_perc_global, 'spike_perc_slice': spike_perc_slice}
return {'global': spike_perc_global, 'slice': spike_perc_slice}
95 changes: 9 additions & 86 deletions mriqc/qc/tests/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,95 +21,18 @@
# https://www.nipreps.org/community/licensing/
#

import os.path as op

import nibabel as nib
import numpy as np
import pytest
from dipy.core.gradients import gradient_table, round_bvals
from dipy.data.fetcher import fetch_sherbrooke_3shell

from ..diffusion import (
cc_snr,
get_global_spike_percentage,
get_slice_spike_percentage,
get_spike_mask,
noise_func,
noise_func_for_shelled_data,
)


class DiffusionData:
def get_data(self):
"""
Generate test data
"""
_, path = fetch_sherbrooke_3shell()
fnifti = op.join(path, 'HARDI193.nii.gz')
fnifti, bval, bvec = (op.join(path, f'HARDI193.{ext}') for
ext in ['nii.gz', 'bval', 'bvec'])
img = nib.load(fnifti)
data = img.get_fdata()
gtab = gradient_table(bval, bvec)
return data, gtab

def shelled_data(self):
data, gtab = self.get_data()
rounded_bvals = round_bvals(gtab.bvals)
unique_rounded_bvals = np.unique(rounded_bvals)

out_data = []
for u_bv in unique_rounded_bvals:
this = data[..., np.where(rounded_bvals == u_bv)]
out_data.append(this)
return out_data, gtab


@pytest.fixture()
def ddata():
return DiffusionData()


def test_noise_function(ddata):
img, gtab = ddata.get_fdata()
noise_func(img, gtab)


def test_get_spike_mask(ddata):
img, gtab = ddata.get_fdata()
spike_mask = get_spike_mask(img, 2)

assert np.min(np.ravel(spike_mask)) == 0
assert np.max(np.ravel(spike_mask)) == 1
assert spike_mask.shape == img.shape


def test_get_slice_spike_percentage(ddata):
img, gtab = ddata.get_fdata()
slice_spike_percentage = get_slice_spike_percentage(img, 2, .2)

assert np.min(slice_spike_percentage) >= 0
assert np.max(slice_spike_percentage) <= 1
assert len(slice_spike_percentage) == img.ndim


def test_get_global_spike_percentage(ddata):
img, gtab = ddata.get_fdata()
global_spike_percentage = get_global_spike_percentage(img, 2)

assert global_spike_percentage >= 0
assert global_spike_percentage <= 1

from mriqc.qc.diffusion import spike_percentage

def test_with_shelled_data(ddata):
shelled_data, gtab = ddata.shelled_data()
noise_func_for_shelled_data(shelled_data, gtab)

def test_spike_percentage():
img = np.random.normal(loc=10, scale=1.0, size=(76, 76, 64, 124))
msk = np.random.randint(0, high=2, size=(76, 76, 64, 124), dtype=bool)
val = spike_percentage(img, msk, .5)

def test_cc_snr(ddata):
img, gtab = ddata.get_fdata()
cc_snr_worst, cc_snr_best = cc_snr(img, gtab)
assert np.isclose(val['global'], 0.5, rtol=1, atol=1)

assert cc_snr_best.shape == gtab.bvals.shape
assert cc_snr_worst.shape == gtab.bvals.shape
assert np.min(cc_snr_best - cc_snr_worst) >= 0
assert np.min(val['slice']) >= 0
assert np.max(val['slice']) == 1
assert len(val['slice']) == img.ndim

0 comments on commit c61831b

Please sign in to comment.