From 423649783bf742050f51b1ccf91e4942e6eaaa29 Mon Sep 17 00:00:00 2001
From: Asma TANABENE <121893894+AsmaTANABEN@users.noreply.github.com>
Date: Thu, 25 Jul 2024 16:34:58 +0200
Subject: [PATCH] Smaps update and prevent saving trajectory for autodiff (#156)

* cufinufft and gpunufft smaps support update
* Add smaps update tests for cufinufft and gpunufft
* No need to save traj variable if differentiation wrt trajectory is not performed
* add the case when differentiation is wrt trajectory and not data to autodiff
* Corrections for PR comments
* Changes for ruff checks
* Get rid of saving trajectory
* Fix data
* Update tests/operators/test_update.py
* Update tests/operators/test_update.py
Co-authored-by: Pierre-Antoine Comby
* Run examples on tests on nodes
* Update test-ci.yml
* Update tests/operators/test_update.py
Co-authored-by: Pierre-Antoine Comby
* Update tests/operators/test_update.py
Co-authored-by: Pierre-Antoine Comby
* Update tests/operators/test_update.py
Co-authored-by: Pierre-Antoine Comby
* Added autodiff changes as per suggestions
* fix indent
* revert
* Fixed
* add shape check in smaps setters
* Use direct self.smaps
* Fixes in codes and
* Call super
* All fixed hopefully
* Fixed cufinufft
* Fixed source
* Fixed cufinufft
* Added new operator
* Add checks after setting smaps
* Update
* Remove unwanted print
* Skip testing with tensorflow
* Move to check shapes and stop installing torch
* Black
* Fix style
* remove
* remove
* restart
* Fix
* Remove tensorflow
* undo remove

---------

Co-authored-by: Chaithya G R
Co-authored-by: Pierre-Antoine Comby
---
 .github/workflows/test-ci.yml                 |  15 ++-
 src/mrinufft/operators/autodiff.py            |  13 +--
 src/mrinufft/operators/base.py                |  17 +--
 .../operators/interfaces/cufinufft.py         |  61 +++++++----
 src/mrinufft/operators/interfaces/gpunufft.py |  26 +++++
 tests/helpers/factories.py                    |   2 +
 tests/operators/test_interfaces.py            |   2 -
 tests/operators/test_update.py                | 100 ++++++++++++++----
 8 files changed, 174 insertions(+), 62 deletions(-)

diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml
index 3b03db49..e966b409 100644
--- a/.github/workflows/test-ci.yml
+++ b/.github/workflows/test-ci.yml
@@ -20,7 +20,7 @@ env:
 
 jobs:
   test-cpu:
-    runs-on: ubuntu-latest
+    runs-on: cpu
     if: ${{ !contains(github.event.head_commit.message, 'style')}}
     strategy:
       matrix:
@@ -53,7 +53,6 @@ jobs:
         if: ${{ matrix.backend == 'pynfft' || env.ref_backend == 'pynfft' }}
         shell: bash
         run: |
-          sudo apt install -y libnfft3-dev
           python -m pip install "pynfft2>=1.4.3"
 
       - name: Install pynufft
@@ -104,6 +103,9 @@
     strategy:
       matrix:
         backend: [cufinufft, gpunufft, torchkbnufft-gpu, tensorflow]
+        exclude:
+          # There is an issue with tensorflow and cupy. This was working. See #156
+          - backend: tensorflow
 
     steps:
       - uses: actions/checkout@v3
@@ -118,9 +120,14 @@
           pip install --upgrade pip wheel
           pip install -e mri-nufft[test]
           pip install cupy-cuda11x
-          pip install torch --index-url https://download.pytorch.org/whl/cu118
           pip install finufft "numpy<2.0"
 
+      - name: Install torch with CUDA 11.8
+        shell: bash
+        if: ${{ matrix.backend != 'tensorflow'}}
+        run: |
+          source $RUNNER_WORKSPACE/venv/bin/activate
+          pip install torch --index-url https://download.pytorch.org/whl/cu118
 
       - name: Install backend
         shell: bash
@@ -163,7 +170,7 @@
           rm -rf venv
 
   test-examples:
-    runs-on: ubuntu-latest
+    runs-on: cpu
     if: ${{ !contains(github.event.head_commit.message, 'style')}}
 
     steps:
diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py
index 7b1c5220..cf7a7b0b 100644
--- a/src/mrinufft/operators/autodiff.py
+++ b/src/mrinufft/operators/autodiff.py
@@ -2,6 +2,7 @@
 
 import torch
 import numpy as np
+from .._utils import NP2TORCH
 
 
 class _NUFFT_OP(torch.autograd.Function):
@@ -20,14 +21,14 @@ class _NUFFT_OP(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, traj, nufft_op):
         """Forward image -> k-space."""
-        ctx.save_for_backward(x, traj)
+        ctx.save_for_backward(x)
         ctx.nufft_op = nufft_op
         return nufft_op.op(x)
 
     @staticmethod
     def backward(ctx, dy):
         """Backward image -> k-space."""
-        (x, traj) = ctx.saved_tensors
+        x = ctx.saved_tensors[0]
         grad_data = None
         grad_traj = None
         if ctx.nufft_op._grad_wrt_data:
@@ -62,7 +63,7 @@ def backward(ctx, dy):
                     dim=0,
                 ),
                 dim=0,
-            ).type_as(traj)
+            ).to(NP2TORCH[ctx.nufft_op.dtype])
 
         return grad_data, grad_traj, None
 
@@ -72,14 +73,14 @@ class _NUFFT_ADJOP(torch.autograd.Function):
     @staticmethod
     def forward(ctx, y, traj, nufft_op):
         """Forward kspace -> image."""
-        ctx.save_for_backward(y, traj)
+        ctx.save_for_backward(y)
         ctx.nufft_op = nufft_op
         return nufft_op.adj_op(y)
 
     @staticmethod
     def backward(ctx, dx):
         """Backward kspace -> image."""
-        (y, traj) = ctx.saved_tensors
+        y = ctx.saved_tensors[0]
         grad_data = None
         grad_traj = None
         if ctx.nufft_op._grad_wrt_data:
@@ -111,7 +112,7 @@ def backward(ctx, dx):
                     dim=0,
                 ),
                 dim=0,
-            ).type_as(traj)
+            ).to(NP2TORCH[ctx.nufft_op.dtype])
             ctx.nufft_op.raw_op.toggle_grad_traj()
 
         return grad_data, grad_traj, None
diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py
index 986b76ae..ff4c2c46 100644
--- a/src/mrinufft/operators/base.py
+++ b/src/mrinufft/operators/base.py
@@ -337,10 +337,10 @@ def compute_smaps(self, method=None):
         """
         if isinstance(method, np.ndarray):
             self.smaps = method
-            return None
+            return
         if not method:
             self.smaps = None
-            return None
+            return
         kwargs = {}
         if isinstance(method, dict):
             kwargs = method.copy()
@@ -501,15 +501,18 @@ def smaps(self):
 
     @smaps.setter
     def smaps(self, smaps):
+        self._check_smaps_shape(smaps)
+        self._smaps = smaps
+
+    def _check_smaps_shape(self, smaps):
+        """Check the shape of the sensitivity maps."""
         if smaps is None:
             self._smaps = None
-        elif len(smaps) != self.n_coils:
+        elif smaps.shape != (self.n_coils, *self.shape):
             raise ValueError(
-                f"Number of sensitivity maps ({len(smaps)})"
-                f"should be equal to n_coils ({self.n_coils})"
+                f"smaps shape is {smaps.shape}, it should be "
+                f"(n_coils, *shape): {(self.n_coils, *self.shape)}"
             )
-        else:
-            self._smaps = smaps
 
     @property
     def density(self):
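Note on the base-class change above: the setter now validates the full smaps shape instead of only the number of coils. A minimal sketch of the resulting behaviour, assuming a CPU finufft install (the backend, matrix size and trajectory values are illustrative):

```python
import numpy as np
from mrinufft import get_operator

shape, n_coils = (32, 32), 4
samples = np.random.uniform(-0.5, 0.5, (500, 2)).astype(np.float32)
smaps = np.ones((n_coils, *shape), dtype=np.complex64)

op = get_operator("finufft")(samples, shape, n_coils=n_coils, smaps=smaps)

# The setter checks the full (n_coils, *shape) shape, not just len(smaps).
try:
    op.smaps = np.ones((n_coils, 16, 16), dtype=np.complex64)
except ValueError as err:
    print(err)  # smaps shape is (4, 16, 16), it should be (n_coils, *shape): (4, 32, 32)
```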
diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py
index c32851f6..63baa185 100644
--- a/src/mrinufft/operators/interfaces/cufinufft.py
+++ b/src/mrinufft/operators/interfaces/cufinufft.py
@@ -221,35 +221,52 @@ def __init__(
         if is_host_array(self.density):
             self.density = cp.array(self.density)
 
-        # Smaps support
+        self.smaps_cached = smaps_cached
         self.compute_smaps(smaps)
-        self.smaps_cached = False
-        if smaps is not None:
-            if not (is_host_array(smaps) or is_cuda_array(smaps)):
-                raise ValueError(
-                    "Smaps should be either a C-ordered ndarray, " "or a GPUArray."
-                )
-            self.smaps_cached = False
-            if smaps_cached:
-                warnings.warn(
-                    f"{sizeof_fmt(smaps.size * np.dtype(self.cpx_dtype).itemsize)}"
-                    "used on gpu for smaps."
-                )
-                self.smaps = cp.array(
-                    smaps, order="C", copy=False, dtype=self.cpx_dtype
-                )
-                self.smaps_cached = True
-            else:
-                self.smaps = pin_memory(smaps.astype(self.cpx_dtype, copy=False))
-                self._smap_d = cp.empty(self.shape, dtype=self.cpx_dtype)
-
+        # Smaps support
+        if self.smaps is not None and (
+            not (is_host_array(self.smaps) or is_cuda_array(self.smaps))
+        ):
+            raise ValueError(
+                "Smaps should be either a C-ordered np.ndarray, or a GPUArray."
+            )
         self.raw_op = RawCufinufftPlan(
             self.samples,
             tuple(shape),
             n_trans=n_trans,
             **kwargs,
         )
-        # Support for concurrent stream and computations.
+
+    @FourierOperatorBase.smaps.setter
+    def smaps(self, new_smaps):
+        """Update smaps.
+
+        Parameters
+        ----------
+        new_smaps: C-ordered ndarray or a GPUArray.
+
+        """
+        self._check_smaps_shape(new_smaps)
+        if new_smaps is not None and hasattr(self, "smaps_cached"):
+            if self.smaps_cached:
+                warnings.warn(
+                    f"{sizeof_fmt(new_smaps.size * np.dtype(self.cpx_dtype).itemsize)}"
+                    "used on gpu for smaps."
+                )
+                self._smaps = cp.array(
+                    new_smaps, order="C", copy=False, dtype=self.cpx_dtype
+                )
+            else:
+                if self._smaps is None:
+                    self._smaps = pin_memory(
+                        new_smaps.astype(self.cpx_dtype, copy=False)
+                    )
+                    self._smap_d = cp.empty(self.shape, dtype=self.cpx_dtype)
+                else:
+                    # copy the array to pinned memory
+                    np.copyto(self._smaps, new_smaps.astype(self.cpx_dtype, copy=False))
+        else:
+            self._smaps = new_smaps
 
     @FourierOperatorBase.samples.setter
     def samples(self, samples):
diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py
index e19a2f8c..6b5853f4 100644
--- a/src/mrinufft/operators/interfaces/gpunufft.py
+++ b/src/mrinufft/operators/interfaces/gpunufft.py
@@ -202,6 +202,17 @@ def _reshape_image(self, image, direction="op"):
             return image.squeeze().astype(xp.complex64, copy=False).T[None]
         return xp.asarray([c.T for c in image], dtype=xp.complex64).squeeze()
 
+    def set_smaps(self, smaps):
+        """Update the smaps.
+
+        Parameters
+        ----------
+        smaps: np.ndarray[np.complex64]
+            sensitivity maps
+        """
+        smaps_ = smaps.T.reshape(-1, smaps.shape[0])
+        np.copyto(self.pinned_smaps, smaps_)
+
     def set_pts(self, samples, density=None):
         """Update the kspace locations and density compensation.
 
@@ -502,6 +513,21 @@ def uses_sense(self):
         """Return True if the Fourier Operator uses the SENSE method."""
         return self.raw_op.uses_sense
 
+    @FourierOperatorBase.smaps.setter
+    def smaps(self, new_smaps):
+        """Update pinned smaps from new_smaps.
+
+        Parameters
+        ----------
+        new_smaps: np.ndarray
+            the new sensitivity maps
+
+        """
+        self._check_smaps_shape(new_smaps)
+        self._smaps = new_smaps
+        if self._smaps is not None and hasattr(self, "raw_op"):
+            self.raw_op.set_smaps(smaps=new_smaps)
+
     @FourierOperatorBase.samples.setter
     def samples(self, samples):
         """Set the samples for the Fourier Operator.
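With the cufinufft and gpunufft setters above, sensitivity maps can be swapped on an existing operator instead of rebuilding it. A rough usage sketch, assuming a CUDA machine with gpunufft installed (sizes and values are illustrative):

```python
import numpy as np
from mrinufft import get_operator

shape, n_coils = (64, 64), 8
samples = np.random.uniform(-0.5, 0.5, (4096, 2)).astype(np.float32)
smaps = (np.random.rand(n_coils, *shape)
         + 1j * np.random.rand(n_coils, *shape)).astype(np.complex64)
smaps /= np.linalg.norm(smaps, axis=0)

op = get_operator("gpunufft")(samples, shape, n_coils=n_coils, smaps=smaps)

# Assigning to op.smaps goes through the new setter: the (n_coils, *shape) shape
# is checked and the maps are copied into the existing pinned buffer via
# raw_op.set_smaps, so the gpuNUFFT plan itself is kept as is.
op.smaps = np.roll(smaps, 1, axis=0)
```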
diff --git a/tests/helpers/factories.py b/tests/helpers/factories.py
index 7d05b685..81defbae 100644
--- a/tests/helpers/factories.py
+++ b/tests/helpers/factories.py
@@ -107,6 +107,8 @@ def wrapper(operator, array_interface, *args, **kwargs):
         if array_interface in ["torch-cpu", "numpy"]:
             if op.backend in ["torchkbnufft-gpu"]:
                 pytest.skip("Uncompatible backend and array")
+        if "torch" in array_interface and op.backend in ["tensorflow"]:
+            pytest.skip("Uncompatible backend and array")
         return func(operator, array_interface, *args, **kwargs)
 
     return _param_array_interface(wrapper)
diff --git a/tests/operators/test_interfaces.py b/tests/operators/test_interfaces.py
index d48e73a5..3769063c 100644
--- a/tests/operators/test_interfaces.py
+++ b/tests/operators/test_interfaces.py
@@ -129,7 +129,6 @@ def test_interfaces_autoadjoint(operator, array_interface):
             from_interface(image, array_interface),
         )
         reldiff[i] = abs(rightadjoint - leftadjoint) / abs(leftadjoint)
-    print(reldiff)
     assert np.mean(reldiff) < 5e-5
 
 
@@ -148,5 +147,4 @@ def AHA(x):
         L[i] = np.linalg.norm(AHA(img2_data) - AHA(img_data)) / np.linalg.norm(
             img2_data - img_data
         )
-
     assert np.mean(L) < 1.1 * spec_rad
diff --git a/tests/operators/test_update.py b/tests/operators/test_update.py
index 3c7b797d..0ced6cbb 100644
--- a/tests/operators/test_update.py
+++ b/tests/operators/test_update.py
@@ -1,4 +1,4 @@
-"""Test for update in trajectory and density.
+"""Test for update in trajectory, density, and sensitivity maps.
 
 Only finufft, cufinufft and gpunufft support update.
 """
@@ -39,6 +39,7 @@
 )
 @parametrize(backend=["finufft", "cufinufft", "gpunufft"])
 @parametrize(density=[False, True])
+@parametrize(smaps_cached=[False, True])
 def operator(
     request,
     kspace_locs,
@@ -49,6 +50,7 @@ def operator(
     n_trans=1,
     density=False,
     backend="finufft",
+    smaps_cached=False,
 ):
     """Generate a batch operator."""
     if n_trans != 1 and backend == "gpunufft":
@@ -63,29 +65,35 @@ def operator(
     else:
         smaps = None
     kspace_locs = kspace_locs.astype(np.float32)
-    return get_operator(backend)(
-        kspace_locs,
-        shape,
-        n_coils=n_coils,
-        smaps=smaps,
-        n_batchs=n_batch,
-        n_trans=n_trans,
-        squeeze_dims=False,
-        density=density,
-    )
+    op_args = {
+        "samples": kspace_locs,
+        "shape": shape,
+        "n_coils": n_coils,
+        "n_batchs": n_batch,
+        "squeeze_dims": False,
+        "density": density,
+        "smaps": smaps,
+    }
+    if backend in ["cufinufft"]:
+        op_args["smaps_cached"] = smaps_cached
+    else:
+        if smaps_cached:
+            pytest.skip(f"Skip test because we don't have smaps_cached in {backend}")
+    return get_operator(backend)(**op_args)
 
 
 def update_operator(operator):
     """Generate a new operator with updated trajectory."""
-    return get_operator(operator.backend)(
-        operator.samples,
-        operator.shape,
-        density=operator.density,
-        n_coils=operator.n_coils,
-        smaps=operator.smaps,
-        squeeze_dims=False,
-        n_batchs=operator.n_batchs,
-    )
+    op_args = {
+        k: getattr(operator, k)
+        for k in ["samples", "shape", "n_coils", "n_batchs", "density", "smaps"]
+    }
+    op_args["squeeze_dims"] = False
+    if operator.backend == "cufinufft":
+        op_args["smaps_cached"] = operator.smaps_cached
+    if operator.smaps is not None and not isinstance(operator.smaps, np.ndarray):
+        op_args["smaps"] = operator.smaps.get()
+    return get_operator(operator.backend)(**op_args)
 
 
 @fixture(scope="module")
@@ -109,8 +117,22 @@ def kspace_data(operator):
     return kspace
 
 
+@fixture(scope="module")
+def new_smaps(operator):
+    """Generate a random new smaps."""
+    smaps = 1j * np.random.rand(operator.n_coils, *operator.shape)
+    smaps += np.random.rand(operator.n_coils, *operator.shape)
+    smaps = smaps.astype(np.complex64)
+    smaps /= np.linalg.norm(smaps, axis=0)
+    return smaps
+
+
 @param_array_interface
-def test_op(operator, array_interface, image_data):
+def test_op(
+    operator,
+    array_interface,
+    image_data,
+):
     """Test the batch type 2 (forward)."""
     image_data = to_interface(image_data, array_interface)
     jitter = np.random.rand(*operator.samples.shape).astype(np.float32)
@@ -138,3 +160,39 @@ def test_adj_op(
     image_true = from_interface(new_operator.adj_op(kspace_data), array_interface)
     # Reduced accuracy for the GPU cases...
     npt.assert_allclose(image_changed, image_true, atol=1e-3, rtol=1e-3)
+
+
+@param_array_interface
+def test_op_smaps_update(
+    operator,
+    array_interface,
+    image_data,
+    new_smaps,
+):
+    """Test the batch type 2 (forward) with smaps."""
+    image_data = to_interface(image_data, array_interface)
+    if operator.smaps is None:
+        pytest.skip("Skipping as we don't have smaps")
+    operator.smaps = new_smaps
+    new_operator = update_operator(operator)
+    kspace_changed = from_interface(operator.op(image_data), array_interface)
+    kspace_true = from_interface(new_operator.op(image_data), array_interface)
+    npt.assert_array_almost_equal(kspace_changed, kspace_true)
+
+
+@param_array_interface
+def test_adj_op_smaps_update(
+    operator,
+    array_interface,
+    kspace_data,
+    new_smaps,
+):
+    """Test the batch type 1 (adjoint)."""
+    kspace_data = to_interface(kspace_data, array_interface)
+    if operator.smaps is None:
+        pytest.skip("Skipping as we don't have smaps")
+    operator.smaps = new_smaps
+    new_operator = update_operator(operator)
+    image_changed = from_interface(operator.adj_op(kspace_data), array_interface)
+    image_true = from_interface(new_operator.adj_op(kspace_data), array_interface)
+    npt.assert_allclose(image_changed, image_true, atol=1e-4, rtol=1e-4)
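Finally, a minimal sketch of the autodiff path touched by the save_for_backward change, assuming the wrt_data/wrt_traj flags of get_operator and a CPU finufft install (shapes and the loss are illustrative):

```python
import numpy as np
import torch
from mrinufft import get_operator

shape = (16, 16)
samples = np.random.uniform(-0.5, 0.5, (400, 2)).astype(np.float32)

# Torch-differentiable wrapper; with this patch the forward pass only saves the
# image tensor for backward and the trajectory is read back from the operator.
nufft = get_operator("finufft", wrt_data=True, wrt_traj=False)(
    samples, shape, n_coils=1, squeeze_dims=False
)

image = torch.randn(1, 1, *shape, dtype=torch.complex64, requires_grad=True)
kspace = nufft.op(image)
loss = torch.sum(torch.abs(kspace) ** 2)
loss.backward()
print(image.grad.shape)  # gradient w.r.t. the image data, same shape as image
```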