Skip to content

Commit

Permalink
device placement fix, perf test
Browse files Browse the repository at this point in the history
  • Loading branch information
CatEek committed Feb 2, 2025
1 parent 52e946d commit f9f4a56
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 290 deletions.
14 changes: 10 additions & 4 deletions src/careamics/transforms/n2v_manipulate_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from careamics.config.transformations import N2VManipulateModel

from .pixel_manipulation_torch import (
median_manipulate_torch_vect,
median_manipulate_torch,
uniform_manipulate_torch,
)
from .struct_mask_parameters import StructMaskParameters
Expand Down Expand Up @@ -76,10 +76,16 @@ def __init__(
)

# PyTorch random generator
# TODO check
device = (
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)
self.rng = (
torch.Generator().manual_seed(seed)
torch.Generator(device=device).manual_seed(seed)
if seed is not None
else torch.default_generator
else torch.Generator(device=device)
)

def __call__(
Expand Down Expand Up @@ -118,7 +124,7 @@ def __call__(
elif self.strategy == SupportedPixelManipulation.MEDIAN:
# Iterate over the channels to apply manipulation separately
for c in range(batch.shape[1]):
masked[:, c, ...], mask[:, c, ...] = median_manipulate_torch_vect(
masked[:, c, ...], mask[:, c, ...] = median_manipulate_torch(
batch=batch[:, c, ...],
mask_pixel_percentage=self.masked_pixel_percentage,
subpatch_size=self.roi_size,
Expand Down
214 changes: 34 additions & 180 deletions src/careamics/transforms/pixel_manipulation_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _apply_struct_mask_torch(
Patch with the structN2V mask applied.
"""
if rng is None:
rng = torch.default_generator
rng = torch.Generator(device=patch.device)

# Relative axis
moving_axis = -1 - struct_params.axis
Expand Down Expand Up @@ -130,7 +130,7 @@ def _get_stratified_coords_torch(
Array of coordinates of the masked pixels.
"""
if rng is None:
rng = torch.default_generator
rng = torch.Generator()

# Calculate the maximum distance between masked pixels. Inversely proportional to
# the percentage of masked pixels.
Expand Down Expand Up @@ -164,90 +164,27 @@ def _get_stratified_coords_torch(

# create a 2D meshgrid of coordinates
coordinate_grid_list = torch.meshgrid(*pixel_coords, indexing="ij")
coordinate_grid = torch.stack([g.flatten() for g in coordinate_grid_list], dim=-1)
coordinate_grid = torch.stack(
[g.flatten() for g in coordinate_grid_list], dim=-1
).to(rng.device)

# add a random jitter increment so that the coordinates do not lie on the grid
random_increment = torch.randint(
high=int(_odd_jitter_func_torch(float(max(steps)), rng)),
size=torch.tensor(coordinate_grid.shape).tolist(),
size=torch.tensor(coordinate_grid.shape).to(rng.device).tolist(),
generator=rng,
device=rng.device,
)
coordinate_grid += random_increment

# make sure no coordinate lie outside the range
return torch.clamp(
coordinate_grid,
torch.zeros_like(torch.tensor(shape)),
torch.tensor([v - 1 for v in shape]),
torch.zeros_like(torch.tensor(shape)).to(device=rng.device),
torch.tensor([v - 1 for v in shape]).to(device=rng.device),
)


def _create_subpatch_center_mask(
subpatch: torch.Tensor, center_coords: torch.Tensor
) -> torch.Tensor:
"""Create a mask with the center of the subpatch masked.
Parameters
----------
subpatch : torch.Tensor
Subpatch to be manipulated.
center_coords : torch.Tensor
Coordinates of the original center before possible crop.
Returns
-------
torch.Tensor
Mask with the center of the subpatch masked.
"""
mask = torch.ones(torch.tensor(subpatch.shape).tolist())
mask[tuple(center_coords)] = 0
return (mask != 0).to(torch.bool)


def _create_subpatch_struct_mask(
subpatch: torch.Tensor,
center_coords: torch.Tensor,
struct_params: StructMaskParameters,
) -> torch.Tensor:
"""Create a structN2V mask for the subpatch.
Parameters
----------
subpatch : torch.Tensor
Subpatch to be manipulated.
center_coords : torch.Tensor
Coordinates of the original center before possible crop.
struct_params : StructMaskParameters
Parameters for the structN2V mask (axis and span).
Returns
-------
torch.Tensor
StructN2V mask for the subpatch.
"""
# TODO no test for this function!
# Create a mask with the center of the subpatch masked
mask_placeholder = torch.ones(subpatch.shape)

# reshape to move the struct axis to the first position
mask_reshaped = torch.permute(mask_placeholder, struct_params.axis, 0)

# create the mask index for the struct axis
mask_index = slice(
max(0, center_coords.take(struct_params.axis) - (struct_params.span - 1) // 2),
min(
1 + center_coords.take(struct_params.axis) + (struct_params.span - 1) // 2,
subpatch.shape[struct_params.axis],
),
)
mask_reshaped[struct_params.axis][mask_index] = 0

# reshape back to the original shape
mask = torch.permute(mask_reshaped, 0, struct_params.axis)

return (mask != 0).to(torch.bool) # type: ignore


def uniform_manipulate_torch(
patch: torch.Tensor,
mask_pixel_percentage: float,
Expand Down Expand Up @@ -288,7 +225,7 @@ def uniform_manipulate_torch(
tuple containing the manipulated patch and the corresponding mask.
"""
if rng is None:
rng = torch.default_generator
rng = torch.Generator(device=patch.device)
# TODO do we need seed ?

# create a copy of the patch
Expand Down Expand Up @@ -317,9 +254,10 @@ def uniform_manipulate_torch(
random_increment = roi_span[
torch.randint(
low=min(roi_span),
high=max(roi_span), # TODO check this, it may exclude one value
high=max(roi_span) + 1, # TODO check this, it may exclude one value
size=subpatch_centers.shape,
generator=rng,
device=patch.device,
)
]

Expand Down Expand Up @@ -348,106 +286,6 @@ def uniform_manipulate_torch(


def median_manipulate_torch(
patch: torch.Tensor,
mask_pixel_percentage: float,
subpatch_size: int = 11,
struct_params: Optional[StructMaskParameters] = None,
rng: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Manipulate pixels by replacing them with the median of their surrounding subpatch.
N2V2 version, manipulated pixels are selected randomly away from a grid with an
approximate uniform probability to be selected across the whole patch.
If `struct_params` is not None, an additional structN2V mask is applied to the data,
replacing the pixels in the mask with random values (excluding the pixel already
manipulated).
Parameters
----------
patch : torch.Tensor
Batch of input patches, 2D or 3D, shape (b, y, x) or (b, z, y, x).
mask_pixel_percentage : float
Approximate percentage of pixels to be masked.
subpatch_size : int
Size of the subpatch the new pixel value is sampled from, by default 11.
struct_params : StructMaskParameters or None, optional
Parameters for the structN2V mask (axis and span).
rng : torch.default_generator or None, optional
Random number generato, by default None.
Returns
-------
tuple[torch.Tensor, torch.Tensor]
tuple containing the manipulated batch and the mask.
"""
if rng is None:
rng = torch.default_generator

# create a copy of the patch to apply the manipulation to
transformed_patch = patch.clone()

# get the coordinates of the future ROI centers
subpatch_centers = _get_stratified_coords_torch(
mask_pixel_percentage, patch.shape, rng
)
subpatch_centers = subpatch_centers.to(device=patch.device)

# arange the list of indices defining the side of ROI square
roi_span = torch.tensor(
[-(subpatch_size // 2), (subpatch_size // 2) + 1], device=patch.device
)
# define a range of coordinates for the subpatch
subpatch_crops_span_full = subpatch_centers[None, ...].T + roi_span
# subpatch_centers[..., None] + roi_span[None, None, :]
# # TODO refactor, improve

# clip the coordinates to the patch size
subpatch_crops_span_clipped = torch.clamp(
subpatch_crops_span_full,
torch.zeros_like(torch.tensor(patch.shape))[:, None, None].to(
device=patch.device
),
torch.tensor(patch.shape)[:, None, None].to(device=patch.device),
)
# TODO test and write better comments
for idx in range(subpatch_crops_span_clipped.shape[1]):
subpatch_coords = subpatch_crops_span_clipped[:, idx, ...]
idxs = [
slice(x[0], x[1]) if x[1] - x[0] > 0 else slice(0, 1)
for x in subpatch_coords
]
subpatch = patch[tuple(idxs)]
subpatch_center_adjusted = subpatch_centers[idx] - subpatch_coords[:, 0]

if struct_params is None:
subpatch_mask = _create_subpatch_center_mask(
subpatch, subpatch_center_adjusted
)
else:
subpatch_mask = _create_subpatch_struct_mask(
subpatch, subpatch_center_adjusted, struct_params
)
# TODO do median on the whole transformed_patch
transformed_patch[tuple(subpatch_centers[idx])] = torch.median(
subpatch[subpatch_mask]
)

mask = torch.where(transformed_patch != patch, 1, 0).to(torch.uint8)

if struct_params is not None:
transformed_patch = _apply_struct_mask_torch(
transformed_patch, subpatch_centers, struct_params
)

return (
transformed_patch,
mask,
)


def median_manipulate_torch_vect(
batch: torch.Tensor,
mask_pixel_percentage: float,
subpatch_size: int = 11,
Expand Down Expand Up @@ -485,22 +323,24 @@ def median_manipulate_torch_vect(
# get the coordinates of the future ROI centers
subpatch_center_coordinates = _get_stratified_coords_torch(
mask_pixel_percentage, batch.shape, rng
).to(device=batch.device)
).to(
device=batch.device
) # (num_coordinates, batch + num_spatial_dims)

# Calculate the padding value for the input tensor
pad_value = subpatch_size // 2

# Generate all offsets for the ROIs
# Generate all offsets for the ROIs. Iteration starting from 1 to skip the batch
offsets = torch.meshgrid(
[
torch.arange(-pad_value, pad_value + 1, device=batch.device)
for i in range(1, subpatch_center_coordinates.shape[1])
for _ in range(1, subpatch_center_coordinates.shape[1])
],
indexing="ij",
)
offsets = torch.stack(
[axis_offset.flatten() for axis_offset in offsets], dim=1
) # (subpatch_size**2, num_spacial_dims)
) # (subpatch_size**2, num_spatial_dims)

# Create the list to assemble coordinates of the ROIs centers for each axis
coords_axes = []
Expand Down Expand Up @@ -536,9 +376,23 @@ def median_manipulate_torch_vect(
torch.arange(subpatch_size), torch.arange(subpatch_size), indexing="ij"
)
center_idx = subpatch_size // 2
span = (struct_params.span - 1) // 2
halfspan = (struct_params.span - 1) // 2

# Determine the axis along which to apply the mask
if struct_params.axis == 0:
center_axis = h
span_axis = w
else:
center_axis = w
span_axis = h

# Create the mask
struct_mask = (
~((w == center_idx) & (h >= center_idx - span) & (h <= center_idx + span))
~(
(center_axis == center_idx)
& (span_axis >= center_idx - halfspan)
& (span_axis <= center_idx + halfspan)
)
).flatten()
rois_filtered = rois[:, struct_mask]
else:
Expand Down
Loading

0 comments on commit f9f4a56

Please sign in to comment.