From 9b2431b2226163b9f876ec5f2d7e0614efdca49a Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Wed, 6 Dec 2023 11:19:49 +0100
Subject: [PATCH] fix: use shared memory for the smaps.

---
 src/simfmri/handlers/acquisition/workers.py | 135 ++++++++++++--------
 1 file changed, 81 insertions(+), 54 deletions(-)

diff --git a/src/simfmri/handlers/acquisition/workers.py b/src/simfmri/handlers/acquisition/workers.py
index 961ca29a..bda34f21 100644
--- a/src/simfmri/handlers/acquisition/workers.py
+++ b/src/simfmri/handlers/acquisition/workers.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 from multiprocessing import shared_memory
+from contextlib import contextmanager
 import logging
 import warnings
 from typing import Any, Callable, Mapping, Generator
@@ -12,7 +13,7 @@
 from tqdm.auto import tqdm
 
 from simfmri.simulation import SimData
-
+from numpy.typing import DTypeLike
 from ._tools import TrajectoryGeneratorType
 
 # from mrinufft import get_operator
@@ -198,6 +199,42 @@ def acq_cartesian(
     return kdata, kmask
 
 
+def _init_sm(
+    name: str, shape: tuple[int, ...], dtype: DTypeLike
+) -> shared_memory.SharedMemory:
+    """Create a named shared memory buffer sized for the given array."""
+    return shared_memory.SharedMemory(
+        name=name,
+        create=True,
+        size=int(np.prod(shape)) * np.dtype(dtype).itemsize,
+    )
+
+
+@contextmanager
+def shm_manager(
+    name: str, shape: tuple[int, ...], dtype: DTypeLike, unlink: bool = False
+) -> Generator[np.ndarray | None, None, None]:
+    """Attach to a named shared memory buffer and expose it as a numpy array."""
+    try:
+        shm = shared_memory.SharedMemory(
+            name=name,
+            create=False,
+        )
+    except FileNotFoundError:
+        # The buffer was never created (e.g. no smaps): nothing to expose or clean up.
+        yield None
+        return
+    arr = np.ndarray(shape, buffer=shm.buf, dtype=dtype)
+    try:
+        yield arr
+    finally:
+        # Drop the view before closing the handle, and unlink the segment if requested.
+        del arr
+        shm.close()
+        if unlink:
+            shm.unlink()
+
+
 def acq_noncartesian(
     sim: SimData,
     trajectory_gen: TrajectoryGeneratorType,
@@ -211,18 +248,20 @@ def acq_noncartesian(
     n_samples = np.prod(test_traj.shape[:-1])
     dim = test_traj.shape[-1]
 
-    kdata_infos = ((n_kspace_frame, sim.n_coils, n_samples), np.complex64)
-    shm_kdata = shared_memory.SharedMemory(
-        name="kdata",
-        create=True,
-        size=np.prod(kdata_infos[0]) * np.dtype(kdata_infos[1]).itemsize,
-    )
-    kmask_infos = ((n_kspace_frame, n_samples, dim), np.float32)
-    shm_kmask = shared_memory.SharedMemory(
-        name="kmask",
-        create=True,
-        size=np.prod(kmask_infos[0]) * np.dtype(kmask_infos[1]).itemsize,
-    )
+    # Allocate kspace data, kspace mask and smaps in shared memory.
+    shm_infos = {
+        "kdata": ((n_kspace_frame, sim.n_coils, n_samples), np.complex64),
+        "kmask": ((n_kspace_frame, n_samples, dim), np.float32),
+        "smaps": ((sim.n_coils, *sim.shape), np.complex64),
+    }
+    _init_sm("kdata", *shm_infos["kdata"])
+    _init_sm("kmask", *shm_infos["kmask"])
+    if sim.smaps is not None:
+        shm_smaps = _init_sm("smaps", *shm_infos["smaps"])
+        smaps = np.ndarray(
+            shm_infos["smaps"][0], buffer=shm_smaps.buf, dtype=shm_infos["smaps"][1]
+        )
+        smaps[:] = sim.smaps
 
     nufft_backend = kwargs.pop("backend")
     logger.debug("Using backend %s", nufft_backend)
@@ -236,14 +275,7 @@ def acq_noncartesian(
         density=False,
         backend_name=nufft_backend,
     )
-    if "gpunufft" in nufft_backend:
-        logger.debug("Using gpunufft, pinning smaps")
-        from mrinufft.operators.interfaces.gpunufft import make_pinned_smaps
 
-        op_kwargs["pinned_smaps"] = make_pinned_smaps(sim.smaps)
-        op_kwargs["smaps"] = None
-    else:
-        op_kwargs["smaps"] = sim.smaps
     scheduler = kspace_bulk_shot(trajectory_gen, sim.n_frames, n_shot_sim_frame)
     with Parallel(n_jobs=n_jobs, backend="loky", mmap_mode="r") as par:
         par(
             delayed(_single_worker)(
                 sim_frame,
                 shot_batch,
                 shot_pos,
                 op_kwargs,
-                kdata_infos,
-                kmask_infos,
+                shm_infos,
             )
             for sim_frame, shot_batch, shot_pos in tqdm(work_generator(sim, scheduler))
         )
@@ -263,19 +294,12 @@ def acq_noncartesian(
 
     get_reusable_executor().shutdown(wait=True)
 
-    kdata_ = np.ndarray(kdata_infos[0], buffer=shm_kdata.buf, dtype=kdata_infos[1])
-    kmask_ = np.ndarray(kmask_infos[0], buffer=shm_kmask.buf, dtype=kmask_infos[1])
-
-    kdata = np.copy(kdata_)
-    kmask = np.copy(kmask_)
-    del kdata_
-    del kmask_
-
-    shm_kdata.close()
-    shm_kmask.close()
-    shm_kdata.unlink()
-    shm_kmask.unlink()
-
+    # Copy the results out of shared memory, then close and unlink the buffers.
+    with (
+        shm_manager("kdata", *shm_infos["kdata"], unlink=True) as kdata_,
+        shm_manager("kmask", *shm_infos["kmask"], unlink=True) as kmask_,
+    ):
+        kdata = np.copy(kdata_)
+        kmask = np.copy(kmask_)
+    if sim.smaps is not None:
+        # Workers only close their handle on the smaps buffer: unlink it here.
+        shm_smaps.close()
+        shm_smaps.unlink()
     return kdata, kmask
 
 
@@ -291,32 +315,35 @@ def _single_worker(
     shot_batch: np.ndarray,
     shot_pos: tuple[int, int],
     op_kwargs: Mapping[str, Any],
-    kdata_infos: tuple[tuple[int], np.Dtype],
-    kmask_infos: tuple[tuple[int], np.Dtype],
+    shm_infos: Mapping[str, tuple[tuple[int, ...], DTypeLike]],
 ) -> None:
     """Perform a shot acquisition."""
-    with warnings.catch_warnings():
+    with (
+        warnings.catch_warnings(),
+        shm_manager("smaps", *shm_infos["smaps"]) as smaps_,
+    ):
         warnings.filterwarnings(
             "ignore",
-            "Samples will be rescaled to .*",
             category=UserWarning,
             module="mrinufft",
         )
-        fourier_op = get_operator(samples=shot_batch, **op_kwargs)
+        smaps = smaps_
+        if "gpunufft" in op_kwargs["backend_name"]:
+            # gpunufft takes the coil maps as pinned_smaps rather than smaps.
+            op_kwargs["pinned_smaps"] = smaps
+            smaps = None
+        fourier_op = get_operator(samples=shot_batch, smaps=smaps, **op_kwargs)
         kspace = fourier_op.op(sim_frame)
 
         L = shot_batch.shape[1]
-        shm_kdata = shared_memory.SharedMemory(name="kdata", create=False)
-        shm_kmask = shared_memory.SharedMemory(name="kmask", create=False)
-
-        kdata_ = np.ndarray(kdata_infos[0], buffer=shm_kdata.buf, dtype=kdata_infos[1])
-        kmask_ = np.ndarray(kmask_infos[0], buffer=shm_kmask.buf, dtype=kmask_infos[1])
-
-        for i, (k, s) in enumerate(shot_pos):
-            kdata_[k, :, s * L : (s + 1) * L] = kspace[..., i * L : (i + 1) * L]
-            kmask_[k, s * L : (s + 1) * L] = shot_batch[i]
-
-        del kdata_
-        del kmask_
-        shm_kdata.close()
-        shm_kmask.close()
+        # Write each shot's location and values into the shared kspace buffers.
+ with ( + shm_manager("kdata", *shm_infos["kdata"]) as kdata_, + shm_manager("kmask", *shm_infos["kmask"]) as kmask_, + ): + for i, (k, s) in enumerate(shot_pos): + kdata_[k, :, s * L : (s + 1) * L] = kspace[..., i * L : (i + 1) * L] + kmask_[k, s * L : (s + 1) * L] = shot_batch[i]
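
Note (illustration only, not part of the patch): _init_sm and shm_manager wrap the standard multiprocessing.shared_memory lifecycle used here. The parent creates a named segment sized for the array, each worker attaches to it by name and writes through a numpy view, and the parent copies the result out before releasing the segment. A minimal standalone sketch of that lifecycle, with made-up names, shapes and a single worker process (assuming Python >= 3.8):

# Sketch of the create / attach / copy-out / unlink pattern the patch relies on.
# The segment name "demo_kdata" and the shape/dtype below are illustrative only.
from multiprocessing import Process, shared_memory

import numpy as np

SHAPE, DTYPE = (4, 8), np.complex64


def worker() -> None:
    # Attach to the existing segment by name and view it as an array.
    shm = shared_memory.SharedMemory(name="demo_kdata", create=False)
    arr = np.ndarray(SHAPE, dtype=DTYPE, buffer=shm.buf)
    arr[:] = 1 + 1j  # writes are visible to the parent process
    del arr          # drop the view before closing, to avoid BufferError
    shm.close()      # close this process's handle; the segment stays alive


if __name__ == "__main__":
    # Parent allocates the named segment (what _init_sm does).
    size = int(np.prod(SHAPE)) * np.dtype(DTYPE).itemsize
    shm = shared_memory.SharedMemory(name="demo_kdata", create=True, size=size)
    try:
        p = Process(target=worker)
        p.start()
        p.join()
        # Copy the data out before releasing the buffer (what the final
        # shm_manager(..., unlink=True) block does in acq_noncartesian).
        result = np.copy(np.ndarray(SHAPE, dtype=DTYPE, buffer=shm.buf))
        print(result[0, 0])
    finally:
        shm.close()
        shm.unlink()  # remove the segment once nobody needs it

close() only drops the calling process's handle on the segment; unlink() is what actually removes it, which is why in the patch only the final reader passes unlink=True and the parent unlinks the smaps buffer itself.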