Skip to content

Commit

Permalink
Smaps update and prevent saving trajectory for autodiff (#156)
Browse files Browse the repository at this point in the history
* cufinufft and gpunufft smaps support update

* Add smaps update tests for cufinufft and gpunuft

* No need to save traj variable if differenciation wrt trajectory is not performed

* add the case when differenciation is wrt to trajectory and not data to autodiff

* Corrections for PR comments

* Changes for ruff checks

* Get rid of saving trajectory

* Fix data

* Update tests/operators/test_update.py

* Update tests/operators/test_update.py

Co-authored-by: Pierre-Antoine Comby <[email protected]>

* Run examples on tests on nodes

* Update test-ci.yml

* Update tests/operators/test_update.py

Co-authored-by: Pierre-Antoine Comby <[email protected]>

* Update tests/operators/test_update.py

Co-authored-by: Pierre-Antoine Comby <[email protected]>

* Update tests/operators/test_update.py

Co-authored-by: Pierre-Antoine Comby <[email protected]>

* Added autodiff changes as per suggestions

* fix indent

* revert

* Fixed

* add shape check in smaps setters

* Use direct self.smaps

* Fixes in codes and

* Call super

* All fixed hopefully

* Fixed cufinufft

* Fixed source

* Fixed cufinufft

* Added new operator

* Add checks after setting smaps

* Update

* Remove unwanted print

* Skip testing with tensorflow

* Move to check shapes and stop installing torch

* Black

* Fix style

* remove

* remove

* restart

* Fix

* Remove tensorlfow

* undo remove

---------

Co-authored-by: Chaithya G R <[email protected]>
Co-authored-by: Pierre-Antoine Comby <[email protected]>
  • Loading branch information
3 people authored Jul 25, 2024
1 parent ff82504 commit 4236497
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 62 deletions.
15 changes: 11 additions & 4 deletions .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ env:

jobs:
test-cpu:
runs-on: ubuntu-latest
runs-on: cpu
if: ${{ !contains(github.event.head_commit.message, 'style')}}
strategy:
matrix:
Expand Down Expand Up @@ -53,7 +53,6 @@ jobs:
if: ${{ matrix.backend == 'pynfft' || env.ref_backend == 'pynfft' }}
shell: bash
run: |
sudo apt install -y libnfft3-dev
python -m pip install "pynfft2>=1.4.3"
- name: Install pynufft
Expand Down Expand Up @@ -104,6 +103,9 @@ jobs:
strategy:
matrix:
backend: [cufinufft, gpunufft, torchkbnufft-gpu, tensorflow]
exclude:
# There is an issue with tensorflow and cupy. This was working. See #156
- backend: tensorflow

steps:
- uses: actions/checkout@v3
Expand All @@ -118,9 +120,14 @@ jobs:
pip install --upgrade pip wheel
pip install -e mri-nufft[test]
pip install cupy-cuda11x
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install finufft "numpy<2.0"
- name: Install torch with CUDA 11.8
shell: bash
if: ${{ matrix.backend != 'tensorflow'}}
run: |
source $RUNNER_WORKSPACE/venv/bin/activate
pip install torch --index-url https://download.pytorch.org/whl/cu118
- name: Install backend
shell: bash
Expand Down Expand Up @@ -163,7 +170,7 @@ jobs:
rm -rf venv
test-examples:
runs-on: ubuntu-latest
runs-on: cpu
if: ${{ !contains(github.event.head_commit.message, 'style')}}

steps:
Expand Down
13 changes: 7 additions & 6 deletions src/mrinufft/operators/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
import numpy as np
from .._utils import NP2TORCH


class _NUFFT_OP(torch.autograd.Function):
Expand All @@ -20,14 +21,14 @@ class _NUFFT_OP(torch.autograd.Function):
@staticmethod
def forward(ctx, x, traj, nufft_op):
"""Forward image -> k-space."""
ctx.save_for_backward(x, traj)
ctx.save_for_backward(x)
ctx.nufft_op = nufft_op
return nufft_op.op(x)

@staticmethod
def backward(ctx, dy):
"""Backward image -> k-space."""
(x, traj) = ctx.saved_tensors
x = ctx.saved_tensors[0]
grad_data = None
grad_traj = None
if ctx.nufft_op._grad_wrt_data:
Expand Down Expand Up @@ -62,7 +63,7 @@ def backward(ctx, dy):
dim=0,
),
dim=0,
).type_as(traj)
).to(NP2TORCH[ctx.nufft_op.dtype])
return grad_data, grad_traj, None


Expand All @@ -72,14 +73,14 @@ class _NUFFT_ADJOP(torch.autograd.Function):
@staticmethod
def forward(ctx, y, traj, nufft_op):
"""Forward kspace -> image."""
ctx.save_for_backward(y, traj)
ctx.save_for_backward(y)
ctx.nufft_op = nufft_op
return nufft_op.adj_op(y)

@staticmethod
def backward(ctx, dx):
"""Backward kspace -> image."""
(y, traj) = ctx.saved_tensors
y = ctx.saved_tensors[0]
grad_data = None
grad_traj = None
if ctx.nufft_op._grad_wrt_data:
Expand Down Expand Up @@ -111,7 +112,7 @@ def backward(ctx, dx):
dim=0,
),
dim=0,
).type_as(traj)
).to(NP2TORCH[ctx.nufft_op.dtype])
ctx.nufft_op.raw_op.toggle_grad_traj()
return grad_data, grad_traj, None

Expand Down
17 changes: 10 additions & 7 deletions src/mrinufft/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,10 @@ def compute_smaps(self, method=None):
"""
if isinstance(method, np.ndarray):
self.smaps = method
return None
return
if not method:
self.smaps = None
return None
return
kwargs = {}
if isinstance(method, dict):
kwargs = method.copy()
Expand Down Expand Up @@ -501,15 +501,18 @@ def smaps(self):

@smaps.setter
def smaps(self, smaps):
self._check_smaps_shape(smaps)
self._smaps = smaps

def _check_smaps_shape(self, smaps):
"""Check the shape of the sensitivity maps."""
if smaps is None:
self._smaps = None
elif len(smaps) != self.n_coils:
elif smaps.shape != (self.n_coils, *self.shape):
raise ValueError(
f"Number of sensitivity maps ({len(smaps)})"
f"should be equal to n_coils ({self.n_coils})"
f"smaps shape is {smaps.shape}, it should be"
f"(n_coils, *shape): {(self.n_coils, *self.shape)}"
)
else:
self._smaps = smaps

@property
def density(self):
Expand Down
61 changes: 39 additions & 22 deletions src/mrinufft/operators/interfaces/cufinufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,35 +221,52 @@ def __init__(
if is_host_array(self.density):
self.density = cp.array(self.density)

# Smaps support
self.smaps_cached = smaps_cached
self.compute_smaps(smaps)
self.smaps_cached = False
if smaps is not None:
if not (is_host_array(smaps) or is_cuda_array(smaps)):
raise ValueError(
"Smaps should be either a C-ordered ndarray, " "or a GPUArray."
)
self.smaps_cached = False
if smaps_cached:
warnings.warn(
f"{sizeof_fmt(smaps.size * np.dtype(self.cpx_dtype).itemsize)}"
"used on gpu for smaps."
)
self.smaps = cp.array(
smaps, order="C", copy=False, dtype=self.cpx_dtype
)
self.smaps_cached = True
else:
self.smaps = pin_memory(smaps.astype(self.cpx_dtype, copy=False))
self._smap_d = cp.empty(self.shape, dtype=self.cpx_dtype)

# Smaps support
if self.smaps is not None and (
not (is_host_array(self.smaps) or is_cuda_array(self.smaps))
):
raise ValueError(
"Smaps should be either a C-ordered np.ndarray, or a GPUArray."
)
self.raw_op = RawCufinufftPlan(
self.samples,
tuple(shape),
n_trans=n_trans,
**kwargs,
)
# Support for concurrent stream and computations.

@FourierOperatorBase.smaps.setter
def smaps(self, new_smaps):
"""Update smaps.
Parameters
----------
new_smaps: C-ordered ndarray or a GPUArray.
"""
self._check_smaps_shape(new_smaps)
if new_smaps is not None and hasattr(self, "smaps_cached"):
if self.smaps_cached:
warnings.warn(
f"{sizeof_fmt(new_smaps.size * np.dtype(self.cpx_dtype).itemsize)}"
"used on gpu for smaps."
)
self._smaps = cp.array(
new_smaps, order="C", copy=False, dtype=self.cpx_dtype
)
else:
if self._smaps is None:
self._smaps = pin_memory(
new_smaps.astype(self.cpx_dtype, copy=False)
)
self._smap_d = cp.empty(self.shape, dtype=self.cpx_dtype)
else:
# copy the array to pinned memory
np.copyto(self._smaps, new_smaps.astype(self.cpx_dtype, copy=False))
else:
self._smaps = new_smaps

@FourierOperatorBase.samples.setter
def samples(self, samples):
Expand Down
26 changes: 26 additions & 0 deletions src/mrinufft/operators/interfaces/gpunufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,17 @@ def _reshape_image(self, image, direction="op"):
return image.squeeze().astype(xp.complex64, copy=False).T[None]
return xp.asarray([c.T for c in image], dtype=xp.complex64).squeeze()

def set_smaps(self, smaps):
"""Update the smaps.
Parameters
----------
smaps: np.ndarray[np.complex64])
sensittivity maps
"""
smaps_ = smaps.T.reshape(-1, smaps.shape[0])
np.copyto(self.pinned_smaps, smaps_)

def set_pts(self, samples, density=None):
"""Update the kspace locations and density compensation.
Expand Down Expand Up @@ -502,6 +513,21 @@ def uses_sense(self):
"""Return True if the Fourier Operator uses the SENSE method."""
return self.raw_op.uses_sense

@FourierOperatorBase.smaps.setter
def smaps(self, new_smaps):
"""Update pinned smaps from new_smaps.
Parameters
----------
new_smaps: np.ndarray
the new sensitivity maps
"""
self._check_smaps_shape(new_smaps)
self._smaps = new_smaps
if self._smaps is not None and hasattr(self, "raw_op"):
self.raw_op.set_smaps(smaps=new_smaps)

@FourierOperatorBase.samples.setter
def samples(self, samples):
"""Set the samples for the Fourier Operator.
Expand Down
2 changes: 2 additions & 0 deletions tests/helpers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def wrapper(operator, array_interface, *args, **kwargs):
if array_interface in ["torch-cpu", "numpy"]:
if op.backend in ["torchkbnufft-gpu"]:
pytest.skip("Uncompatible backend and array")
if "torch" in array_interface and op.backend in ["tensorflow"]:
pytest.skip("Uncompatible backend and array")
return func(operator, array_interface, *args, **kwargs)

return _param_array_interface(wrapper)
2 changes: 0 additions & 2 deletions tests/operators/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def test_interfaces_autoadjoint(operator, array_interface):
from_interface(image, array_interface),
)
reldiff[i] = abs(rightadjoint - leftadjoint) / abs(leftadjoint)
print(reldiff)
assert np.mean(reldiff) < 5e-5


Expand All @@ -148,5 +147,4 @@ def AHA(x):
L[i] = np.linalg.norm(AHA(img2_data) - AHA(img_data)) / np.linalg.norm(
img2_data - img_data
)

assert np.mean(L) < 1.1 * spec_rad
Loading

0 comments on commit 4236497

Please sign in to comment.