Skip to content

Commit

Permalink
enh: port from process pool into asyncio concurrent
Browse files Browse the repository at this point in the history
Co-authored-by: Chris Markiewicz <[email protected]>
  • Loading branch information
oesteban and effigies committed Aug 6, 2024
1 parent 7c7608f commit b42b172
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 106 deletions.
147 changes: 55 additions & 92 deletions nitransforms/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Resampling utilities."""

import asyncio
from os import cpu_count
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
from pathlib import Path
from typing import Tuple
from typing import Callable, TypeVar

import numpy as np
from nibabel.loadsave import load as _nbload
Expand All @@ -27,65 +28,19 @@
_as_homogeneous,
)

R = TypeVar("R")

SERIALIZE_VOLUME_WINDOW_WIDTH: int = 8
"""Minimum number of volumes to automatically serialize 4D transforms."""


def _apply_volume(
index: int,
data: np.ndarray,
targets: np.ndarray,
order: int = 3,
mode: str = "constant",
cval: float = 0.0,
prefilter: bool = True,
) -> Tuple[int, np.ndarray]:
"""
Decorate :obj:`~scipy.ndimage.map_coordinates` to return an order index for parallelization.
async def worker(job: Callable[[], R], semaphore) -> R:
async with semaphore:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, job)

Parameters
----------
index : :obj:`int`
The index of the volume to apply the interpolation to.
data : :obj:`~numpy.ndarray`
The input data array.
targets : :obj:`~numpy.ndarray`
The target coordinates for mapping.
order : :obj:`int`, optional
The order of the spline interpolation, default is 3.
The order has to be in the range 0-5.
mode : :obj:`str`, optional
Determines how the input image is extended when the resamplings overflows
a border. One of ``'constant'``, ``'reflect'``, ``'nearest'``, ``'mirror'``,
or ``'wrap'``. Default is ``'constant'``.
cval : :obj:`float`, optional
Constant value for ``mode='constant'``. Default is 0.0.
prefilter: :obj:`bool`, optional
Determines if the image's data array is prefiltered with
a spline filter before interpolation. The default is ``True``,
which will create a temporary *float64* array of filtered values
if *order > 1*. If setting this to ``False``, the output will be
slightly blurred if *order > 1*, unless the input is prefiltered,
i.e. it is the result of calling the spline filter on the original
input.
Returns
-------
(:obj:`int`, :obj:`~numpy.ndarray`)
The index and the array resulting from the interpolation.
"""
return index, ndi.map_coordinates(
data,
targets,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)


def apply(
async def apply(
transform: TransformBase,
spatialimage: str | Path | SpatialImage,
reference: str | Path | SpatialImage = None,
Expand All @@ -94,9 +49,9 @@ def apply(
cval: float = 0.0,
prefilter: bool = True,
output_dtype: np.dtype = None,
serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH,
njobs: int = None,
dtype_width: int = 8,
serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH,
max_concurrent: int = min(cpu_count(), 12),
) -> SpatialImage | np.ndarray:
"""
Apply a transformation to an image, resampling on the reference spatial object.
Expand All @@ -118,15 +73,15 @@ def apply(
or ``'wrap'``. Default is ``'constant'``.
cval : :obj:`float`, optional
Constant value for ``mode='constant'``. Default is 0.0.
prefilter: :obj:`bool`, optional
prefilter : :obj:`bool`, optional
Determines if the image's data array is prefiltered with
a spline filter before interpolation. The default is ``True``,
which will create a temporary *float64* array of filtered values
if *order > 1*. If setting this to ``False``, the output will be
slightly blurred if *order > 1*, unless the input is prefiltered,
i.e. it is the result of calling the spline filter on the original
input.
output_dtype: :obj:`~numpy.dtype`, optional
output_dtype : :obj:`~numpy.dtype`, optional
The dtype of the returned array or image, if specified.
If ``None``, the default behavior is to use the effective dtype of
the input image. If slope and/or intercept are defined, the effective
Expand All @@ -135,10 +90,17 @@ def apply(
If ``reference`` is defined, then the return value is an image, with
a data array of the effective dtype but with the on-disk dtype set to
the input image's on-disk dtype.
dtype_width: :obj:`int`
dtype_width : :obj:`int`
Cap the width of the input data type to the given number of bytes.
This argument is intended to work as a way to implement lower memory
requirements in resampling.
serialize_nvols : :obj:`int`
Minimum number of volumes in a 3D+t (that is, a series of 3D transformations
independent in time) to resample on a one-by-one basis.
Serialized resampling can be executed concurrently (parallelized) with
the argument ``max_concurrent``.
max_concurrent : :obj:`int`
Maximum number of 3D resamplings to be executed concurrently.
Returns
-------
Expand Down Expand Up @@ -201,46 +163,47 @@ def apply(
else None
)

njobs = cpu_count() if njobs is None or njobs < 1 else njobs
# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
)

with ProcessPoolExecutor(max_workers=min(njobs, n_resamplings)) as executor:
results = []
for t in range(n_resamplings):
xfm_t = transform if n_resamplings == 1 else transform[t]
semaphore = asyncio.Semaphore(max_concurrent)

if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
)
tasks = []
for t in range(n_resamplings):
xfm_t = transform if n_resamplings == 1 else transform[t]

data_t = (
data
if data is not None
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
)

results.append(
executor.submit(
_apply_volume,
t,
data_t,
targets,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
data_t = (
data
if data is not None
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
)

tasks.append(
asyncio.create_task(
worker(
partial(
ndi.map_coordinates,
data_t,
targets,
output=resampled[..., t],
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
),
semaphore,
)
)

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
)

for future in as_completed(results):
t, resampled_t = future.result()
resampled[..., t] = resampled_t
await asyncio.gather(*tasks)
else:
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)

Expand Down
15 changes: 1 addition & 14 deletions nitransforms/tests/test_resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from nitransforms import nonlinear as nitnl
from nitransforms import manip as nitm
from nitransforms import io
from nitransforms.resampling import apply, _apply_volume
from nitransforms.resampling import apply

RMSE_TOL_LINEAR = 0.09
RMSE_TOL_NONLINEAR = 0.05
Expand Down Expand Up @@ -363,16 +363,3 @@ def test_LinearTransformsMapping_apply(
reference=testdata_path / "sbref.nii.gz",
serialize_nvols=2 if serialize_4d else np.inf,
)


@pytest.mark.parametrize("t", list(range(4)))
def test_apply_helper(monkeypatch, t):
"""Ensure the apply helper function correctly just decorates with index."""
from nitransforms.resampling import ndi

def _retval(*args, **kwargs):
return 1

monkeypatch.setattr(ndi, "map_coordinates", _retval)

assert _apply_volume(t, None, None) == (t, 1)

0 comments on commit b42b172

Please sign in to comment.