diff --git a/.github/workflows/master-cd.yml b/.github/workflows/master-cd.yml index 5e2d22c4..f8d34e32 100644 --- a/.github/workflows/master-cd.yml +++ b/.github/workflows/master-cd.yml @@ -29,42 +29,19 @@ jobs: if: success() steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Get history and tags for SCM versioning to work - run: | - git fetch --prune --unshallow - git fetch --depth=1 origin +refs/tags/*:refs/tags/* - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.10" - - name: Install dependencies - shell: bash -l {0} - run: | - python -m pip install --upgrade pip - python -m pip install .[doc] - python -m pip install finufft - - - name: Build API documentation - run: | - python -m sphinx docs docs_build - - - name: Get the badge from CI + - name: Get the docs_build artifact uses: actions/download-artifact@v4 with: - name: coverage_badge - path: docs_build/_static - github-token: ${{ secrets.GITHUB_TOKEN }} + name: docs_final + path: docs_build run-id: ${{ github.event.workflow_run.id }} + + - name: Display structure of docs + run: ls -R docs_build/ - - name: Display data - run: ls -R - working-directory: docs_build/_static - name: Upload artifact uses: actions/upload-pages-artifact@v1 with: - # Upload entire repository path: 'docs_build' - name: Deploy to GitHub Pages diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 85ca56c2..96f9a1b2 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -1,4 +1,4 @@ -name: Style checking +name: Style on: push: diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index e966b409..f05d1476 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -170,7 +170,7 @@ jobs: rm -rf venv test-examples: - runs-on: cpu + runs-on: gpu if: ${{ !contains(github.event.head_commit.message, 'style')}} steps: @@ -194,6 +194,15 @@ jobs: python -m pip install --upgrade pip python -m pip install -e .[test,dev] python -m pip install finufft pooch brainweb-dl torch + + - name: Install GPU related interfaces + run: | + export CUDA_BIN_PATH=/usr/local/cuda-11.8/ + export PATH=/usr/local/cuda-11.8/bin/:${PATH} + export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64/:${LD_LIBRARY_PATH} + pip install cupy-cuda11x + pip install torch --index-url https://download.pytorch.org/whl/cu118 + python -m pip install gpuNUFFT cufinufft - name: Run examples shell: bash @@ -268,3 +277,80 @@ jobs: with: name: coverage_badge path: coverage_badge.svg + + BuildDocs: + name: Build API Documentation + runs-on: gpu + if: ${{ contains(github.event.head_commit.message, 'docs_build')}} or ${{github.ref == 'refs/heads/master'}} + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Get history and tags for SCM versioning to work + run: | + git fetch --prune --unshallow + git fetch --depth=1 origin +refs/tags/*:refs/tags/* + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install dependencies + shell: bash -l {0} + run: | + python -m pip install --upgrade pip + python -m pip install .[doc] + python -m pip install finufft + + - name: Install GPU related interfaces + run: | + export CUDA_BIN_PATH=/usr/local/cuda-11.8/ + export PATH=/usr/local/cuda-11.8/bin/:${PATH} + export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64/:${LD_LIBRARY_PATH} + pip install cupy-cuda11x + pip install torch --index-url https://download.pytorch.org/whl/cu118 + python -m pip install gpuNUFFT cufinufft + + - name: Build API 
documentation + run: | + python -m sphinx docs docs_build + + - name: Display data + run: ls -R + working-directory: docs_build/_static + + - name: Upload artifact + id: artifact-upload-step + uses: actions/upload-artifact@v4 + with: + # Upload the docs + name: docs + path: 'docs_build' + retention-days: 5 + + CompileDocs: + name: Compile the coverage badge in docs + runs-on: ubuntu-latest + if: ${{ github.ref == 'refs/heads/master' }} + needs: [BuildDocs, coverage] + steps: + - name: Get the docs_build artifact + uses: actions/download-artifact@v4 + with: + name: docs + path: docs_build + + - name: Get the badge from CI + uses: actions/download-artifact@v4 + with: + name: coverage_badge + path: docs_build/_static + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: ReUpload artifacts + uses: actions/upload-artifact@v4 + with: + name: docs_final + retention-days: 20 + path: docs_build diff --git a/docs/conf.py b/docs/conf.py index 04d9a3c2..f163797d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -77,8 +77,9 @@ "examples_dirs": ["../examples/"], "gallery_dirs": ["generated/autoexamples"], "filename_pattern": "/example_", - "ignore_pattern": r"/(__init__|conftest|utils)\.py", + "ignore_pattern": r"(__init__|conftest|utils).py", "nested_sections": True, + "parallel": 3, } intersphinx_mapping = { diff --git a/examples/GPU/README.rst b/examples/GPU/README.rst new file mode 100644 index 00000000..1b87b9ce --- /dev/null +++ b/examples/GPU/README.rst @@ -0,0 +1,6 @@ +.. _gpu_examples: + +GPU Examples +------------ + +This is a collection of examples showing features of mri-nufft, particularly those that are GPU-accelerated. diff --git a/examples/example_density.py b/examples/GPU/example_density.py similarity index 87% rename from examples/example_density.py rename to examples/GPU/example_density.py index 20f45b3c..8c84f47b 100644 --- a/examples/example_density.py +++ b/examples/GPU/example_density.py @@ -23,7 +23,7 @@ # Create sample data # ------------------ -mri_2D = bwdl.get_mri(4, "T1")[80, ...].astype(np.float32) +mri_2D = np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.float32) print(mri_2D.shape) @@ -136,18 +136,16 @@ # The Pipe method is currently only implemented for gpuNUFFT. 
# %% -if check_backend("gpunufft"): - flat_traj = traj.reshape(-1, 2) - nufft = get_operator("gpunufft")( - traj, shape=mri_2D.shape, density={"name": "pipe", "osf": 2} - ) - adjoint_manual = nufft.adj_op(kspace) - fig, axs = plt.subplots(1, 3, figsize=(15, 5)) - axs[0].imshow(abs(mri_2D)) - axs[0].set_title("Ground Truth") - axs[1].imshow(abs(adjoint)) - axs[1].set_title("no density compensation") - axs[2].imshow(abs(adjoint_manual)) - axs[2].set_title("Pipe density compensation") - - print(nufft.density) +flat_traj = traj.reshape(-1, 2) +nufft = get_operator("gpunufft")( + traj, shape=mri_2D.shape, density={"name": "pipe", "osf": 2} +) +adjoint_manual = nufft.adj_op(kspace) +fig, axs = plt.subplots(1, 3, figsize=(15, 5)) +axs[0].imshow(abs(mri_2D)) +axs[0].set_title("Ground Truth") +axs[1].imshow(abs(adjoint)) +axs[1].set_title("no density compensation") +axs[2].imshow(abs(adjoint_manual)) +axs[2].set_title("Pipe density compensation") +print(nufft.density) diff --git a/examples/GPU/example_learn_samples.py b/examples/GPU/example_learn_samples.py new file mode 100644 index 00000000..991539c0 --- /dev/null +++ b/examples/GPU/example_learn_samples.py @@ -0,0 +1,184 @@ +# %% +""" +====================== +Learn Sampling pattern +====================== + +A small pytorch example to showcase learning k-space sampling patterns. +This example showcases the auto-diff capabilities of the NUFFT operator +wrt to k-space trajectory in mri-nufft. + +.. warning:: + This example only showcases the autodiff capabilities, the learned sampling pattern is not scanner compliant as the scanner gradients required to implement it violate the hardware constraints. In practice, a projection into the scanner constraints set is recommended. This is implemented in the proprietary SPARKLING package. Users are encouraged to contact the authors if they want to use it. 
+""" +import brainweb_dl as bwdl +import matplotlib.pyplot as plt +import numpy as np +import torch +from tqdm import tqdm +from PIL import Image, ImageSequence + +from mrinufft import get_operator +from mrinufft.trajectories import initialize_2D_radial + + +# %% +# Setup a simple class to learn trajectory +# ---------------------------------------- + + +class Model(torch.nn.Module): + def __init__(self, inital_trajectory): + super(Model, self).__init__() + self.trajectory = torch.nn.Parameter( + data=torch.Tensor(inital_trajectory), + requires_grad=True, + ) + self.operator = get_operator("gpunufft", wrt_data=True, wrt_traj=True)( + self.trajectory.detach().cpu().numpy(), + shape=(256, 256), + density=True, + squeeze_dims=False, + ) + + def forward(self, x): + self.operator.samples = self.trajectory.clone() + kspace = self.operator.op(x) + adjoint = self.operator.adj_op(kspace) + return adjoint / torch.linalg.norm(adjoint) + + +# %% +# Util function to plot the state of the model +# -------------------------------------------- + + +def plot_state(axs, mri_2D, traj, recon, loss=None, save_dir="/tmp/", save_name=None): + axs = axs.flatten() + axs[0].imshow(np.abs(mri_2D[0]), cmap="gray") + axs[0].axis("off") + axs[0].set_title("MR Image") + axs[1].scatter(*traj.T, s=1) + axs[1].set_title("Trajectory") + axs[2].imshow(np.abs(recon[0][0].detach().cpu().numpy()), cmap="gray") + axs[2].axis("off") + axs[2].set_title("Reconstruction") + if loss is not None: + axs[3].plot(loss) + axs[3].set_title("Loss") + if save_name is not None: + plt.savefig(save_dir + save_name, bbox_inches="tight") + plt.close() + else: + plt.show() + + +# %% +# Setup model and optimizer +# ------------------------- +init_traj = initialize_2D_radial(16, 512).reshape(-1, 2).astype(np.float32) +model = Model(init_traj) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) +schedulder = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=1, end_factor=0.1, total_iters=100 +) + +# %% +# Setup data +# ---------- + +mri_2D = torch.Tensor(np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.complex64))[ + None +] +mri_2D = mri_2D / torch.linalg.norm(mri_2D) +model.eval() +recon = model(mri_2D) +fig, axs = plt.subplots(1, 3, figsize=(15, 5)) +plot_state(axs, mri_2D, init_traj, recon) + +# %% +# Start training loop +# ------------------- +losses = [] +imgs = [] +model.train() +with tqdm(range(100), unit="steps") as tqdms: + for i in tqdms: + out = model(mri_2D) + loss = torch.norm(out - mri_2D[None]) + numpy_loss = loss.detach().cpu().numpy() + tqdms.set_postfix({"loss": numpy_loss}) + losses.append(numpy_loss) + optimizer.zero_grad() + loss.backward() + optimizer.step() + with torch.no_grad(): + # Clamp the value of trajectory between [-0.5, 0.5] + for param in model.parameters(): + param.clamp_(-0.5, 0.5) + schedulder.step() + # Generate images for gif + fig, axs = plt.subplots(2, 2, figsize=(10, 10)) + plot_state( + axs, + mri_2D, + model.trajectory.detach().cpu().numpy(), + out, + losses, + save_name=f"{i}.png", + ) + imgs.append(Image.open(f"/tmp/{i}.png")) + +# Make a GIF of all images. +imgs[0].save( + "mrinufft_learn_traj.gif", + save_all=True, + append_images=imgs[1:], + optimize=False, + duration=2, + loop=0, +) +# sphinx_gallery_start_ignore +# cleanup +import os +import shutil +from pathlib import Path + +for f in range(100): + f = f"/tmp/{f}.png" + try: + os.remove(f) + except OSError: + continue +# don't raise errors from pytest. 
This will only be executed for the sphinx gallery stuff +try: + final_dir = ( + Path(os.getcwd()).parent.parent + / "docs" + / "generated" + / "autoexamples" + / "GPU" + / "images" + ) + shutil.copyfile("mrinufft_learn_traj.gif", final_dir / "mrinufft_learn_traj.gif") +except FileNotFoundError: + pass + +# sphinx_gallery_end_ignore + +# sphinx_gallery_thumbnail_path = 'generated/autoexamples/GPU/images/mrinufft_learn_traj.gif' + +# %% +# .. image-sg:: /generated/autoexamples/GPU/images/mrinufft_learn_traj.gif +# :alt: example learn_samples +# :srcset: /generated/autoexamples/GPU/images/mrinufft_learn_traj.gif +# :class: sphx-glr-single-img + +# %% +# Trained trajectory +# ------------------ +model.eval() +recon = model(mri_2D) +fig, axs = plt.subplots(2, 2, figsize=(10, 10)) +plot_state(axs, mri_2D, model.trajectory.detach().cpu().numpy(), recon, losses) +plt.show() diff --git a/examples/README.rst b/examples/README.rst index 2db1ab1f..aaa7ac77 100644 --- a/examples/README.rst +++ b/examples/README.rst @@ -1,3 +1,5 @@ +.. _general_examples: + Examples ======== diff --git a/examples/example_gif_2D.py b/examples/example_gif_2D.py index da31105f..d56760d7 100644 --- a/examples/example_gif_2D.py +++ b/examples/example_gif_2D.py @@ -193,7 +193,7 @@ def draw_frame(func, index, name, arg, save_dir="/tmp/"): os.remove(f) except OSError: continue -# don't raise errors from pytest. This will only be excecuted for the sphinx gallery stuff +# don't raise errors from pytest. This will only be executed for the sphinx gallery stuff try: final_dir = ( Path(os.getcwd()).parent / "docs" / "generated" / "autoexamples" / "images" diff --git a/examples/example_gif_3D.py b/examples/example_gif_3D.py index d8f1dd70..dc365e02 100644 --- a/examples/example_gif_3D.py +++ b/examples/example_gif_3D.py @@ -198,7 +198,7 @@ def draw_frame(func, index, name, arg, save_dir="/tmp/"): os.remove(f) except OSError: continue -# don't raise errors from pytest. This will only be excecuted for the sphinx gallery stuff +# don't raise errors from pytest. This will only be executed for the sphinx gallery stuff try: final_dir = ( Path(os.getcwd()).parent / "docs" / "generated" / "autoexamples" / "images" diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index bdccc75a..f29cc520 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -86,7 +86,6 @@ def proper_trajectory(trajectory, normalize="pi"): raise ValueError( "trajectory should be array_like, with the last dimension being coordinates" ) from e - new_traj = new_traj.reshape(-1, trajectory.shape[-1]) max_abs_val = xp.max(xp.abs(new_traj)) diff --git a/src/mrinufft/io/nsp.py b/src/mrinufft/io/nsp.py index ef8518d1..55895b5e 100644 --- a/src/mrinufft/io/nsp.py +++ b/src/mrinufft/io/nsp.py @@ -445,7 +445,7 @@ def read_arbgrad_rawdat( ----- This function requires the mapVBVD module to be installed. You can install it using the following command: - `pip install pymapVBVD` + `pip install pymapVBVD` """ data, hdr, twixObj = read_siemens_rawdat( filename=filename, diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index cf7a7b0b..0975eda1 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -50,19 +50,11 @@ def backward(ctx, dy): [ctx.nufft_op.op(grid_x[:, i, :, :]) for i in range(grid_x.size(1))], dim=0, ) - grad_traj = torch.mean( - torch.cat( - [ - torch.transpose( - (-1j * torch.conj(dy[:, i, :]) * nufft_dx_dom[:, i, :]), - 0, - 1, - )[None, ...] 
- for i in range(dy.shape[1]) - ], - dim=0, - ), - dim=0, + grad_traj = -1j * torch.conj(dy) * nufft_dx_dom + grad_traj = torch.transpose( + torch.sum(grad_traj, dim=1), + 0, + 1, ).to(NP2TORCH[ctx.nufft_op.dtype]) return grad_data, grad_traj, None @@ -86,7 +78,7 @@ def backward(ctx, dx): if ctx.nufft_op._grad_wrt_data: grad_data = ctx.nufft_op.op(dx) if ctx.nufft_op._grad_wrt_traj: - ctx.nufft_op.raw_op.toggle_grad_traj() + ctx.nufft_op.toggle_grad_traj() im_size = dx.size()[2:] factor = 1 if ctx.nufft_op.backend in ["gpunufft"]: @@ -100,20 +92,13 @@ def backward(ctx, dx): grid_dx = torch.conj(dx) * grid_r inufft_dx_dom = torch.cat( [ctx.nufft_op.op(grid_dx[:, i, :, :]) for i in range(grid_dx.size(1))], - dim=1, - ).squeeze() - inufft_dx_dom = inufft_dx_dom.reshape(y.shape[0], -1, y.shape[-1]) - grad_traj = torch.mean( - torch.cat( - [ - torch.transpose((1j * y[i] * inufft_dx_dom[i]), 0, 1)[None, ...] - for i in range(y.shape[0]) - ], - dim=0, - ), dim=0, - ).to(NP2TORCH[ctx.nufft_op.dtype]) - ctx.nufft_op.raw_op.toggle_grad_traj() + ) + grad_traj = 1j * y * inufft_dx_dom + grad_traj = torch.transpose(torch.sum(grad_traj, dim=1), 0, 1).to( + NP2TORCH[ctx.nufft_op.dtype] + ) + ctx.nufft_op.toggle_grad_traj() return grad_data, grad_traj, None @@ -136,6 +121,12 @@ def __init__(self, nufft_op, wrt_data=True, wrt_traj=False): if wrt_traj and self.nufft_op.backend in ["finufft", "cufinufft"]: self.nufft_op._make_plan_grad() self.nufft_op._grad_wrt_data = wrt_data + if wrt_traj: + # We initialize the samples as a torch tensor purely for autodiff purposes. + # It can also be converted later to nn.Parameter, in which case it is + # used for update also. + self._samples_torch = torch.Tensor(self.nufft_op.samples) + self._samples_torch.requires_grad = True def op(self, x): r"""Compute the forward image -> k-space.""" @@ -145,6 +136,19 @@ def adj_op(self, kspace): r"""Compute the adjoint k-space -> image.""" return _NUFFT_ADJOP.apply(kspace, self.samples, self.nufft_op) + @property + def samples(self): + """Get the samples.""" + try: + return self._samples_torch + except AttributeError: + return self.nufft_op.samples + + @samples.setter + def samples(self, value): + self._samples_torch = value + self.nufft_op.samples = value.detach().cpu().numpy() + def __getattr__(self, name): - """Get the attribute from the root operator.""" + """Forward all other attributes to the nufft_op.""" return getattr(self.nufft_op, name) diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 63baa185..ac3f5b14 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -841,3 +841,9 @@ def get_lipschitz_cst(self, max_iter=10, **kwargs): return power_method( max_iter, tmp_op, norm_func=lambda x: cp.linalg.norm(x.flatten()), x=x ) + + def toggle_grad_traj(self): + """Toggle between the gradient trajectory and the plan for type 1 transform.""" + if self.uses_sense: + self.smaps = self.smaps.conj() + self.raw_op.toggle_grad_traj() diff --git a/src/mrinufft/operators/interfaces/finufft.py b/src/mrinufft/operators/interfaces/finufft.py index 5d27402c..fbc3b389 100644 --- a/src/mrinufft/operators/interfaces/finufft.py +++ b/src/mrinufft/operators/interfaces/finufft.py @@ -162,3 +162,9 @@ def _make_plan_grad(self, **kwargs): **kwargs, ) self.raw_op._set_pts(typ="grad", samples=self.samples) + + def toggle_grad_traj(self): + """Toggle between the gradient trajectory and the plan for type 1 transform.""" + if 
self.uses_sense: + self.smaps = self.smaps.conj() + self.raw_op.toggle_grad_traj() diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index 6b5853f4..2049da69 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -708,3 +708,9 @@ def _dc_host(self, image_data, obs_data): # data_consistency N / adj_op coil n # # This should bring some performance improvements, due to the asynchronous stuff. + + def toggle_grad_traj(self): + """Toggle the gradient trajectory of the operator.""" + if self.uses_sense: + self.smaps = self.smaps.conj() + self.raw_op.toggle_grad_traj() diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index cc4f4806..d85f4666 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -13,11 +13,11 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): Parameters ---------- - ktraj : array_like + ktraj: array_like The k-space coordinates for the Fourier transformation. - shape : tuple of int + shape: tuple of int The dimensions of the output Fourier matrix. - dtype : data-type, optional + dtype: data-type, optional The data type of the Fourier matrix, default is np.complex64. normalize : bool, optional If True, normalizes the matrix to maintain numerical stability. diff --git a/tests/helpers/asserts.py b/tests/helpers/asserts.py index 6087e4e4..ba3ed07f 100644 --- a/tests/helpers/asserts.py +++ b/tests/helpers/asserts.py @@ -40,8 +40,9 @@ def assert_almost_allclose(a, b, rtol, atol, mismatch, equal_nan=False): try: npt.assert_allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) except AssertionError as e: - e.message += "\nMismatched elements: " - e.message += f"{np.sum(~val)} > {mismatch}(={mismatch_perc*100:.2f}%)" + message = getattr(e, "message", "") + message += "\nMismatched elements: " + message += f"{np.sum(~val)} > {mismatch}(={mismatch_perc*100:.2f}%)" raise e diff --git a/tests/operators/test_autodiff.py b/tests/operators/test_autodiff.py index b3147ae0..dfeed4f3 100644 --- a/tests/operators/test_autodiff.py +++ b/tests/operators/test_autodiff.py @@ -8,11 +8,7 @@ from case_trajectories import CasesTrajectories from mrinufft.operators import get_operator -from helpers import ( - kspace_from_op, - image_from_op, - to_interface, -) +from helpers import kspace_from_op, image_from_op, to_interface, assert_almost_allclose TORCH_AVAILABLE = True @@ -75,21 +71,19 @@ def test_adjoint_and_grad(operator, interface): if operator.backend == "finufft" and "gpu" in interface: pytest.skip("GPU not supported for finufft backend") - if torch.is_tensor(operator.samples): - operator.samples = operator.samples.cpu().detach().numpy() - - operator.samples = to_interface(operator.samples, interface=interface) + if "gpu" in interface: + operator.samples = operator.samples.to("cuda") + else: + operator.samples = operator.samples.cpu() ksp_data = to_interface(kspace_from_op(operator), interface=interface) img_data = to_interface(image_from_op(operator), interface=interface) ksp_data.requires_grad = True - operator.samples.requires_grad = True - with torch.autograd.set_detect_anomaly(True): adj_data = operator.adj_op(ksp_data).reshape(img_data.shape) if operator.smaps is not None: smaps = torch.from_numpy(operator.smaps).to(img_data.device) - adj_data_ndft_smpas = torch.cat( + adj_data_ndft_smaps = torch.cat( [ 
(ndft_matrix(operator).conj().T @ ksp_data[i].flatten()).reshape( img_data.shape @@ -98,7 +92,7 @@ def test_adjoint_and_grad(operator, interface): ], dim=0, ) - adj_data_ndft = torch.mean(smaps.conj() * adj_data_ndft_smpas, dim=0) + adj_data_ndft = torch.sum(smaps.conj() * adj_data_ndft_smaps, dim=0) else: adj_data_ndft = ( ndft_matrix(operator).conj().T @ ksp_data.flatten() @@ -106,6 +100,14 @@ def test_adjoint_and_grad(operator, interface): loss_nufft = torch.mean(torch.abs(adj_data - img_data) ** 2) loss_ndft = torch.mean(torch.abs(adj_data_ndft - img_data) ** 2) + assert_almost_allclose( + adj_data.cpu().detach(), + adj_data_ndft.cpu().detach(), + atol=1e-1, + rtol=1e-1, + mismatch=20, + ) + # Check if nufft and ndft w.r.t trajectory are close in the backprop gradient_ndft_ktraj = torch.autograd.grad( loss_ndft, operator.samples, retain_graph=True @@ -113,8 +115,12 @@ def test_adjoint_and_grad(operator, interface): gradient_nufft_ktraj = torch.autograd.grad( loss_nufft, operator.samples, retain_graph=True )[0] - assert_allclose( - gradient_ndft_ktraj.cpu().numpy(), gradient_nufft_ktraj.cpu().numpy(), atol=5e-1 + assert_almost_allclose( + gradient_ndft_ktraj.cpu().numpy(), + gradient_nufft_ktraj.cpu().numpy(), + atol=1e-2, + rtol=1e-2, + mismatch=20, ) # Check if nufft and ndft are close in the backprop @@ -134,15 +140,13 @@ def test_forward_and_grad(operator, interface): if operator.backend == "finufft" and "gpu" in interface: pytest.skip("GPU not supported for finufft backend") - if torch.is_tensor(operator.samples): - operator.samples = operator.samples.cpu().detach().numpy() - - operator.samples = to_interface(operator.samples, interface=interface) + if "gpu" in interface: + operator.samples = operator.samples.to("cuda") + else: + operator.samples = operator.samples.cpu() ksp_data_ref = to_interface(kspace_from_op(operator), interface=interface) img_data = to_interface(image_from_op(operator), interface=interface) - img_data.requires_grad = True - operator.samples.requires_grad = True with torch.autograd.set_detect_anomaly(True): if operator.smaps is not None and operator.n_coils > 1: @@ -166,6 +170,15 @@ def test_forward_and_grad(operator, interface): loss_nufft = torch.mean(torch.abs(ksp_data - ksp_data_ref) ** 2) loss_ndft = torch.mean(torch.abs(ksp_data_ndft - ksp_data_ref) ** 2) + # FIXME: This check can be tighter for Nyquist cases + assert_almost_allclose( + ksp_data.cpu().detach(), + ksp_data_ndft.cpu().detach(), + atol=1e-1, + rtol=1e-1, + mismatch=20, + ) + # Check if nufft and ndft w.r.t trajectory are close in the backprop gradient_ndft_ktraj = torch.autograd.grad( loss_ndft, operator.samples, retain_graph=True diff --git a/tests/test_density.py b/tests/test_density.py index 35281859..1737ca0e 100644 --- a/tests/test_density.py +++ b/tests/test_density.py @@ -6,7 +6,7 @@ from case_trajectories import CasesTrajectories from helpers import assert_correlate -from mrinufft.density import cell_count, voronoi, pipe +from mrinufft.density import cell_count, voronoi from mrinufft.density.utils import normalize_weights from mrinufft._utils import proper_trajectory diff --git a/tests/test_ndft.py b/tests/test_ndft.py index d23591ea..cd21622e 100644 --- a/tests/test_ndft.py +++ b/tests/test_ndft.py @@ -75,4 +75,5 @@ def test_ndft_fft(kspace_grid, shape): if len(shape) >= 2: kspace = kspace.swapaxes(0, 1) kspace_fft = sp.fft.fftn(sp.fft.fftshift(img)) + assert_almost_allclose(kspace, kspace_fft, atol=1e-4, rtol=1e-4, mismatch=5)
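
A minimal sketch of the trajectory-gradient path introduced above, condensed from examples/GPU/example_learn_samples.py: the autodiff wrapper now keeps `samples` as a torch tensor and routes gradients to it, with `toggle_grad_traj()` implemented for the finufft, cufinufft, and gpunufft backends. This sketch assumes the gpunufft backend and its CUDA dependencies are installed; the image shape and trajectory size are illustrative only.

import numpy as np
import torch

from mrinufft import get_operator
from mrinufft.trajectories import initialize_2D_radial

# Operator that tracks gradients w.r.t. both the data and the trajectory.
init_traj = initialize_2D_radial(16, 512).reshape(-1, 2).astype(np.float32)
nufft = get_operator("gpunufft", wrt_data=True, wrt_traj=True)(
    init_traj, shape=(256, 256), density=True, squeeze_dims=False
)

# The new `samples` setter keeps a torch copy for autodiff and pushes a
# detached numpy copy down to the raw operator.
traj = torch.tensor(init_traj, requires_grad=True)
nufft.samples = traj

image = torch.randn(1, 256, 256, dtype=torch.complex64)
kspace = nufft.op(image)                # forward: image -> k-space
recon = nufft.adj_op(kspace)            # adjoint: k-space -> image
loss = torch.norm(recon - image[None])  # real-valued scalar
loss.backward()

print(traj.grad.shape)  # gradient w.r.t. the k-space sampling locations

As in the full example, the trajectory can then be wrapped in an nn.Parameter, optimized with Adam, and clamped to [-0.5, 0.5] after each step.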