diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py
index aea217dd8..220dde228 100644
--- a/src/careamics/careamist.py
+++ b/src/careamics/careamist.py
@@ -700,7 +700,7 @@ def _create_data_for_bmz(
             denormalize = Denormalize(
                 mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
             )
-            input_patch, _ = denormalize(input_patch)
+            input_patch = denormalize(input_patch)
 
         elif self.train_datamodule is not None:
             input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
@@ -710,7 +710,7 @@ def _create_data_for_bmz(
             denormalize = Denormalize(
                 mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
             )
-            input_patch, _ = denormalize(input_patch)
+            input_patch = denormalize(input_patch)
         else:
             # create a random input array
             input_patch = np.random.normal(
diff --git a/src/careamics/config/tile_information.py b/src/careamics/config/tile_information.py
index 3fc6a3468..4d955339a 100644
--- a/src/careamics/config/tile_information.py
+++ b/src/careamics/config/tile_information.py
@@ -2,9 +2,9 @@
 
 from __future__ import annotations
 
-from typing import Optional, Tuple
+from typing import Tuple
 
-from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
+from pydantic import BaseModel, ConfigDict, field_validator
 
 
 class TileInformation(BaseModel):
@@ -13,15 +13,17 @@ class TileInformation(BaseModel):
 
     This model is used to represent the information required to stitch back a tile into
     a larger image. It is used throughout the prediction pipeline of CAREamics.
+
+    Array shape should be (C)(Z)YX, where C and Z are optional dimensions, and must not
+    contain singleton dimensions.
     """
 
     model_config = ConfigDict(validate_default=True)
 
     array_shape: Tuple[int, ...]
-    tiled: bool = False
     last_tile: bool = False
-    overlap_crop_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None)
-    stitch_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None)
+    overlap_crop_coords: Tuple[Tuple[int, ...], ...]
+    stitch_coords: Tuple[Tuple[int, ...], ...]
 
     @field_validator("array_shape")
     @classmethod
@@ -48,59 +50,25 @@ def no_singleton_dimensions(cls, v: Tuple[int, ...]):
             raise ValueError("Array shape must not contain singleton dimensions.")
         return v
 
-    @field_validator("last_tile")
-    @classmethod
-    def only_if_tiled(cls, v: bool, values: ValidationInfo):
-        """
-        Check that the last tile flag is only set if tiling is enabled.
+    def __eq__(self, other_tile: object):
+        """Check if two tile information objects are equal.
 
         Parameters
         ----------
-        v : bool
-            Last tile flag.
-        values : ValidationInfo
-            Validation information.
+        other_tile : object
+            Tile information object to compare with.
 
         Returns
         -------
         bool
-            The last tile flag.
-        """
-        if not values.data["tiled"]:
-            return False
-        return v
-
-    @field_validator("overlap_crop_coords", "stitch_coords")
-    @classmethod
-    def mandatory_if_tiled(
-        cls, v: Optional[Tuple[int, ...]], values: ValidationInfo
-    ) -> Optional[Tuple[int, ...]]:
-        """
-        Check that the coordinates are not `None` if tiling is enabled.
-
-        The method also return `None` if tiling is not enabled.
-
-        Parameters
-        ----------
-        v : Optional[Tuple[int, ...]]
-            Coordinates to check.
-        values : ValidationInfo
-            Validation information.
-
-        Returns
-        -------
-        Optional[Tuple[int, ...]]
-            The coordinates if tiling is enabled, otherwise `None`.
-
-        Raises
-        ------
-        ValueError
-            If the coordinates are `None` and tiling is enabled.
+            Whether the two tile information objects are equal.
""" - if values.data["tiled"]: - if v is None: - raise ValueError("Value must be specified if tiling is enabled.") - - return v - else: - return None + if not isinstance(other_tile, TileInformation): + return NotImplemented + + return ( + self.array_shape == other_tile.array_shape + and self.last_tile == other_tile.last_tile + and self.overlap_crop_coords == other_tile.overlap_crop_coords + and self.stitch_coords == other_tile.stitch_coords + ) diff --git a/src/careamics/config/validators/validator_utils.py b/src/careamics/config/validators/validator_utils.py index da5eb0ae7..a8d88e782 100644 --- a/src/careamics/config/validators/validator_utils.py +++ b/src/careamics/config/validators/validator_utils.py @@ -72,7 +72,7 @@ def value_ge_than_8_power_of_2( If the value is not a power of 2. """ if value < 8: - raise ValueError(f"Value must be non-zero positive (got {value}).") + raise ValueError(f"Value must be greater than 8 (got {value}).") if (value & (value - 1)) != 0: raise ValueError(f"Value must be a power of 2 (got {value}).") diff --git a/src/careamics/dataset/__init__.py b/src/careamics/dataset/__init__.py index b3c9cdbaf..43c39a3ef 100644 --- a/src/careamics/dataset/__init__.py +++ b/src/careamics/dataset/__init__.py @@ -1,6 +1,17 @@ """Dataset module.""" -__all__ = ["InMemoryDataset", "PathIterableDataset"] +__all__ = [ + "InMemoryDataset", + "InMemoryPredDataset", + "InMemoryTiledPredDataset", + "PathIterableDataset", + "IterableTiledPredDataset", + "IterablePredDataset", +] from .in_memory_dataset import InMemoryDataset +from .in_memory_pred_dataset import InMemoryPredDataset +from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset from .iterable_dataset import PathIterableDataset +from .iterable_pred_dataset import IterablePredDataset +from .iterable_tiled_pred_dataset import IterableTiledPredDataset diff --git a/src/careamics/dataset/dataset_utils/__init__.py b/src/careamics/dataset/dataset_utils/__init__.py index 242405769..e8e93692e 100644 --- a/src/careamics/dataset/dataset_utils/__init__.py +++ b/src/careamics/dataset/dataset_utils/__init__.py @@ -8,11 +8,13 @@ "read_tiff", "get_read_func", "read_zarr", + "iterate_over_files", ] from .dataset_utils import reshape_array from .file_utils import get_files_size, list_files, validate_source_target_files +from .iterate_over_files import iterate_over_files from .read_tiff import read_tiff from .read_utils import get_read_func from .read_zarr import read_zarr diff --git a/src/careamics/dataset/dataset_utils/file_utils.py b/src/careamics/dataset/dataset_utils/file_utils.py index 949b588e3..a37905a06 100644 --- a/src/careamics/dataset/dataset_utils/file_utils.py +++ b/src/careamics/dataset/dataset_utils/file_utils.py @@ -33,7 +33,7 @@ def list_files( data_type: Union[str, SupportedData], extension_filter: str = "", ) -> List[Path]: - """Create a recursive list of files in `data_path`. + """List recursively files in `data_path` and return a sorted list. If `data_path` is a file, its name is validated against the `data_type` using `fnmatch`, and the method returns `data_path` itself. 
diff --git a/src/careamics/dataset/dataset_utils/iterate_over_files.py b/src/careamics/dataset/dataset_utils/iterate_over_files.py
new file mode 100644
index 000000000..b3e413f9b
--- /dev/null
+++ b/src/careamics/dataset/dataset_utils/iterate_over_files.py
@@ -0,0 +1,83 @@
+"""Function to iterate over files."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable, Generator, List, Optional, Tuple, Union
+
+import numpy as np
+from torch.utils.data import get_worker_info
+
+from careamics.config import DataConfig, InferenceConfig
+from careamics.utils.logging import get_logger
+
+from .dataset_utils import reshape_array
+from .read_tiff import read_tiff
+
+logger = get_logger(__name__)
+
+
+def iterate_over_files(
+    data_config: Union[DataConfig, InferenceConfig],
+    data_files: List[Path],
+    target_files: Optional[List[Path]] = None,
+    read_source_func: Callable = read_tiff,
+) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+    """Iterate over data source and yield whole reshaped images.
+
+    Parameters
+    ----------
+    data_config : Union[DataConfig, InferenceConfig]
+        Data configuration.
+    data_files : List[Path]
+        List of data files.
+    target_files : Optional[List[Path]]
+        List of target files, by default None.
+    read_source_func : Optional[Callable]
+        Function to read the source, by default read_tiff.
+
+    Yields
+    ------
+    Tuple[np.ndarray, Optional[np.ndarray]]
+        Reshaped image and, if available, reshaped target.
+    """
+    # When num_workers > 0, each worker process will have a different copy of the
+    # dataset object
+    # Configuring each copy independently to avoid having duplicate data returned
+    # from the workers
+    worker_info = get_worker_info()
+    worker_id = worker_info.id if worker_info is not None else 0
+    num_workers = worker_info.num_workers if worker_info is not None else 1
+
+    # iterate over the files
+    for i, filename in enumerate(data_files):
+        # retrieve file corresponding to the worker id
+        if i % num_workers == worker_id:
+            try:
+                # read data
+                sample = read_source_func(filename, data_config.axes)
+
+                # reshape array
+                reshaped_sample = reshape_array(sample, data_config.axes)
+
+                # read target, if available
+                if target_files is not None:
+                    if filename.name != target_files[i].name:
+                        raise ValueError(
+                            f"File {filename} does not match target file "
+                            f"{target_files[i]}. Have you passed sorted "
+                            f"arrays?"
+                        )
+
+                    # read target
+                    target = read_source_func(target_files[i], data_config.axes)
+
+                    # reshape target
+                    reshaped_target = reshape_array(target, data_config.axes)
+
+                    yield reshaped_sample, reshaped_target
+                else:
+                    yield reshaped_sample, None
+
+            except Exception as e:
+                logger.error(f"Error reading file {filename}: {e}")
diff --git a/src/careamics/dataset/dataset_utils/read_tiff.py b/src/careamics/dataset/dataset_utils/read_tiff.py
index ab557f2f9..0cea0f695 100644
--- a/src/careamics/dataset/dataset_utils/read_tiff.py
+++ b/src/careamics/dataset/dataset_utils/read_tiff.py
@@ -53,13 +53,4 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     else:
         raise ValueError(f"File {file_path} is not a valid tiff.")
 
-    # check dimensions
-    # TODO or should this really be done here? probably in the LightningDataModule
-    # TODO this should also be centralized somewhere else (validate_dimensions)
-    if len(array.shape) < 2 or len(array.shape) > 6:
-        raise ValueError(
-            f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for"
-            f"file {file_path})."
-        )
-
     return array
diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py
index 24ceb1e84..97c6bd04f 100644
--- a/src/careamics/dataset/in_memory_dataset.py
+++ b/src/careamics/dataset/in_memory_dataset.py
@@ -11,18 +11,15 @@
 
 from careamics.transforms import Compose
 
-from ..config import DataConfig, InferenceConfig
-from ..config.tile_information import TileInformation
-from ..config.transformations import NormalizeModel
+from ..config import DataConfig
 from ..utils.logging import get_logger
-from .dataset_utils import read_tiff, reshape_array
+from .dataset_utils import read_tiff
 from .patching.patching import (
     prepare_patches_supervised,
     prepare_patches_supervised_array,
     prepare_patches_unsupervised,
     prepare_patches_unsupervised_array,
 )
-from .patching.tiled_patching import extract_tiles
 
 logger = get_logger(__name__)
 
@@ -270,130 +267,3 @@ def split_dataset(
             dataset.patch_targets = val_targets
 
         return dataset
-
-
-class InMemoryPredictionDataset(Dataset):
-    """
-    Dataset storing data in memory and allowing generating patches from it.
-
-    Parameters
-    ----------
-    prediction_config : InferenceConfig
-        Prediction configuration.
-    inputs : np.ndarray
-        Input data.
-    data_target : Optional[np.ndarray], optional
-        Target data, by default None.
-    read_source_func : Optional[Callable], optional
-        Read source function for custom types, by default read_tiff.
-    """
-
-    def __init__(
-        self,
-        prediction_config: InferenceConfig,
-        inputs: np.ndarray,
-        data_target: Optional[np.ndarray] = None,
-        read_source_func: Optional[Callable] = read_tiff,
-    ) -> None:
-        """Constructor.
-
-        Parameters
-        ----------
-        prediction_config : InferenceConfig
-            Prediction configuration.
-        inputs : np.ndarray
-            Input data.
-        data_target : Optional[np.ndarray], optional
-            Target data, by default None.
-        read_source_func : Optional[Callable], optional
-            Read source function for custom types, by default read_tiff.
-
-        Raises
-        ------
-        ValueError
-            If data_path is not a directory.
-        """
-        self.pred_config = prediction_config
-        self.input_array = inputs
-        self.axes = self.pred_config.axes
-        self.tile_size = self.pred_config.tile_size
-        self.tile_overlap = self.pred_config.tile_overlap
-        self.mean = self.pred_config.mean
-        self.std = self.pred_config.std
-        self.data_target = data_target
-
-        # tiling only if both tile size and overlap are provided
-        self.tiling = self.tile_size is not None and self.tile_overlap is not None
-
-        # read function
-        self.read_source_func = read_source_func
-
-        # Generate patches
-        self.data = self._prepare_tiles()
-        self.mean, self.std = self.pred_config.mean, self.pred_config.std
-
-        # get transforms
-        self.patch_transform = Compose(
-            transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
-        )
-
-    def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
-        """
-        Iterate over data source and create an array of patches.
-
-        Returns
-        -------
-        List[XArrayTile]
-            List of tiles.
- """ - # reshape array - reshaped_sample = reshape_array(self.input_array, self.axes) - - if self.tiling and self.tile_size is not None and self.tile_overlap is not None: - # generate patches, which returns a generator - patch_generator = extract_tiles( - arr=reshaped_sample, - tile_size=self.tile_size, - overlaps=self.tile_overlap, - ) - patches_list = list(patch_generator) - - if len(patches_list) == 0: - raise ValueError("No tiles generated, ") - - return patches_list - else: - array_shape = reshaped_sample.squeeze().shape - return [(reshaped_sample, TileInformation(array_shape=array_shape))] - - def __len__(self) -> int: - """ - Return the length of the dataset. - - Returns - ------- - int - Length of the dataset. - """ - return len(self.data) - - def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]: - """ - Return the patch corresponding to the provided index. - - Parameters - ---------- - index : int - Index of the patch to return. - - Returns - ------- - Tuple[np.ndarray, TileInformation] - Transformed patch. - """ - tile_array, tile_info = self.data[index] - - # Apply transforms - transformed_tile, _ = self.patch_transform(patch=tile_array) - - return transformed_tile, tile_info diff --git a/src/careamics/dataset/in_memory_pred_dataset.py b/src/careamics/dataset/in_memory_pred_dataset.py new file mode 100644 index 000000000..2a0dc1ffd --- /dev/null +++ b/src/careamics/dataset/in_memory_pred_dataset.py @@ -0,0 +1,85 @@ +"""In-memory prediction dataset.""" + +from __future__ import annotations + +import numpy as np +from torch.utils.data import Dataset + +from careamics.transforms import Compose + +from ..config import InferenceConfig +from ..config.transformations import NormalizeModel +from .dataset_utils import reshape_array + + +class InMemoryPredDataset(Dataset): + """Simple prediction dataset returning images along the sample axis. + + Parameters + ---------- + prediction_config : InferenceConfig + Prediction configuration. + inputs : np.ndarray + Input data. + """ + + def __init__( + self, + prediction_config: InferenceConfig, + inputs: np.ndarray, + ) -> None: + """Constructor. + + Parameters + ---------- + prediction_config : InferenceConfig + Prediction configuration. + inputs : np.ndarray + Input data. + + Raises + ------ + ValueError + If data_path is not a directory. + """ + self.pred_config = prediction_config + self.input_array = inputs + self.axes = self.pred_config.axes + self.mean, self.std = self.pred_config.mean, self.pred_config.std + + # Reshape data + self.data = reshape_array(self.input_array, self.axes) + + # get transforms + self.patch_transform = Compose( + transform_list=[NormalizeModel(mean=self.mean, std=self.std)], + ) + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + Length of the dataset. + """ + return len(self.data) + + def __getitem__(self, index: int) -> np.ndarray: + """ + Return the patch corresponding to the provided index. + + Parameters + ---------- + index : int + Index of the patch to return. + + Returns + ------- + np.ndarray + Transformed patch. 
+ """ + transformed_patch, _ = self.patch_transform(patch=self.data[[index]]) + + return transformed_patch diff --git a/src/careamics/dataset/in_memory_tiled_pred_dataset.py b/src/careamics/dataset/in_memory_tiled_pred_dataset.py new file mode 100644 index 000000000..14096c07c --- /dev/null +++ b/src/careamics/dataset/in_memory_tiled_pred_dataset.py @@ -0,0 +1,130 @@ +"""In-memory tiled prediction dataset.""" + +from __future__ import annotations + +from typing import List, Tuple + +import numpy as np +from torch.utils.data import Dataset + +from careamics.transforms import Compose + +from ..config import InferenceConfig +from ..config.tile_information import TileInformation +from ..config.transformations import NormalizeModel +from .dataset_utils import reshape_array +from .tiling import extract_tiles + + +class InMemoryTiledPredDataset(Dataset): + """Prediction dataset storing data in memory and returning tiles of each image. + + Parameters + ---------- + prediction_config : InferenceConfig + Prediction configuration. + inputs : np.ndarray + Input data. + """ + + def __init__( + self, + prediction_config: InferenceConfig, + inputs: np.ndarray, + ) -> None: + """Constructor. + + Parameters + ---------- + prediction_config : InferenceConfig + Prediction configuration. + inputs : np.ndarray + Input data. + + Raises + ------ + ValueError + If data_path is not a directory. + """ + if ( + prediction_config.tile_size is None + or prediction_config.tile_overlap is None + ): + raise ValueError( + "Tile size and overlap must be provided to use the tiled prediction " + "dataset." + ) + + self.pred_config = prediction_config + self.input_array = inputs + self.axes = self.pred_config.axes + self.tile_size = prediction_config.tile_size + self.tile_overlap = prediction_config.tile_overlap + self.mean = self.pred_config.mean + self.std = self.pred_config.std + + # Generate patches + self.data = self._prepare_tiles() + self.mean, self.std = self.pred_config.mean, self.pred_config.std + + # get transforms + self.patch_transform = Compose( + transform_list=[NormalizeModel(mean=self.mean, std=self.std)], + ) + + def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]: + """ + Iterate over data source and create an array of patches. + + Returns + ------- + List[XArrayTile] + List of tiles. + """ + # reshape array + reshaped_sample = reshape_array(self.input_array, self.axes) + + # generate patches, which returns a generator + patch_generator = extract_tiles( + arr=reshaped_sample, + tile_size=self.tile_size, + overlaps=self.tile_overlap, + ) + patches_list = list(patch_generator) + + if len(patches_list) == 0: + raise ValueError("No tiles generated, ") + + return patches_list + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + Length of the dataset. + """ + return len(self.data) + + def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]: + """ + Return the patch corresponding to the provided index. + + Parameters + ---------- + index : int + Index of the patch to return. + + Returns + ------- + Tuple[np.ndarray, TileInformation] + Transformed patch. 
+ """ + tile_array, tile_info = self.data[index] + + # Apply transforms + transformed_tile, _ = self.patch_transform(patch=tile_array) + + return transformed_tile, tile_info diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 192e73864..6babe7568 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -4,85 +4,21 @@ import copy from pathlib import Path -from typing import Any, Callable, Generator, List, Optional, Tuple, Union +from typing import Callable, Generator, List, Optional, Tuple import numpy as np -from torch.utils.data import IterableDataset, get_worker_info +from torch.utils.data import IterableDataset +from careamics.config import DataConfig from careamics.transforms import Compose -from ..config import DataConfig, InferenceConfig -from ..config.tile_information import TileInformation -from ..config.transformations import NormalizeModel from ..utils.logging import get_logger -from .dataset_utils import read_tiff, reshape_array +from .dataset_utils import iterate_over_files, read_tiff from .patching.random_patching import extract_patches_random -from .patching.tiled_patching import extract_tiles logger = get_logger(__name__) -def _iterate_over_files( - data_config: Union[DataConfig, InferenceConfig], - data_files: List[Path], - target_files: Optional[List[Path]] = None, - read_source_func: Callable = read_tiff, -) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]: - """ - Iterate over data source and yield whole image. - - Parameters - ---------- - data_config : Union[DataConfig, InferenceConfig] - Data configuration. - data_files : List[Path] - List of data files. - target_files : Optional[List[Path]] - List of target files, by default None. - read_source_func : Optional[Callable] - Function to read the source, by default read_tiff. - - Yields - ------ - np.ndarray - Image. - """ - # When num_workers > 0, each worker process will have a different copy of the - # dataset object - # Configuring each copy independently to avoid having duplicate data returned - # from the workers - worker_info = get_worker_info() - worker_id = worker_info.id if worker_info is not None else 0 - num_workers = worker_info.num_workers if worker_info is not None else 1 - - # iterate over the files - for i, filename in enumerate(data_files): - # retrieve file corresponding to the worker id - if i % num_workers == worker_id: - try: - # read data - sample = read_source_func(filename, data_config.axes) - - # read target, if available - if target_files is not None: - if filename.name != target_files[i].name: - raise ValueError( - f"File {filename} does not match target file " - f"{target_files[i]}. Have you passed sorted " - f"arrays?" - ) - - # read target - target = read_source_func(target_files[i], data_config.axes) - - yield sample, target - else: - yield sample, None - - except Exception as e: - logger.error(f"Error reading file {filename}: {e}") - - class PathIterableDataset(IterableDataset): """ Dataset allowing extracting patches w/o loading whole data into memory. 
@@ -170,7 +106,7 @@ def _calculate_mean_and_std(self) -> Tuple[float, float]:
         means, stds = 0, 0
         num_samples = 0
 
-        for sample, _ in _iterate_over_files(
+        for sample, _ in iterate_over_files(
             self.data_config, self.data_files, self.target_files, self.read_source_func
         ):
             means += sample.mean()
@@ -203,20 +139,13 @@ def __iter__(
         ), "Mean and std must be provided"
 
         # iterate over files
-        for sample_input, sample_target in _iterate_over_files(
+        for sample_input, sample_target in iterate_over_files(
             self.data_config, self.data_files, self.target_files, self.read_source_func
         ):
-            reshaped_sample = reshape_array(sample_input, self.data_config.axes)
-            reshaped_target = (
-                None
-                if sample_target is None
-                else reshape_array(sample_target, self.data_config.axes)
-            )
-
             patches = extract_patches_random(
-                arr=reshaped_sample,
+                arr=sample_input,
                 patch_size=self.data_config.patch_size,
-                target=reshaped_target,
+                target=sample_target,
             )
 
             # iterate over patches
@@ -317,132 +246,3 @@ def split_dataset(
             dataset.target_files = val_target_files
 
         return dataset
-
-
-class IterablePredictionDataset(IterableDataset):
-    """
-    Prediction dataset.
-
-    Parameters
-    ----------
-    prediction_config : InferenceConfig
-        Inference configuration.
-    src_files : List[Path]
-        List of data files.
-    read_source_func : Callable, optional
-        Read source function for custom types, by default read_tiff.
-    **kwargs : Any
-        Additional keyword arguments, unused.
-
-    Attributes
-    ----------
-    data_path : Union[str, Path]
-        Path to the data, must be a directory.
-    axes : str
-        Description of axes in format STCZYX.
-    mean : Optional[float], optional
-        Expected mean of the dataset, by default None.
-    std : Optional[float], optional
-        Expected standard deviation of the dataset, by default None.
-    patch_transform : Optional[Callable], optional
-        Patch transform callable, by default None.
-    """
-
-    def __init__(
-        self,
-        prediction_config: InferenceConfig,
-        src_files: List[Path],
-        read_source_func: Callable = read_tiff,
-        **kwargs: Any,
-    ) -> None:
-        """Constructor.
-
-        Parameters
-        ----------
-        prediction_config : InferenceConfig
-            Inference configuration.
-        src_files : List[Path]
-            List of data files.
-        read_source_func : Callable, optional
-            Read source function for custom types, by default read_tiff.
-        **kwargs : Any
-            Additional keyword arguments, unused.
-
-        Raises
-        ------
-        ValueError
-            If mean and std are not provided in the inference configuration.
-        """
-        self.prediction_config = prediction_config
-        self.data_files = src_files
-        self.axes = prediction_config.axes
-        self.tile_size = self.prediction_config.tile_size
-        self.tile_overlap = self.prediction_config.tile_overlap
-        self.read_source_func = read_source_func
-
-        # tile only if both tile size and overlaps are provided
-        self.tile = self.tile_size is not None and self.tile_overlap is not None
-
-        # check mean and std and create normalize transform
-        if self.prediction_config.mean is None or self.prediction_config.std is None:
-            raise ValueError("Mean and std must be provided for prediction.")
-        else:
-            self.mean = self.prediction_config.mean
-            self.std = self.prediction_config.std
-
-            # instantiate normalize transform
-            self.patch_transform = Compose(
-                transform_list=[
-                    NormalizeModel(
-                        mean=prediction_config.mean, std=prediction_config.std
-                    )
-                ],
-            )
-
-    def __iter__(
-        self,
-    ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
-        """
-        Iterate over data source and yield single patch.
-
-        Yields
-        ------
-        np.ndarray
-            Single patch.
- """ - assert ( - self.mean is not None and self.std is not None - ), "Mean and std must be provided" - - for sample, _ in _iterate_over_files( - self.prediction_config, - self.data_files, - read_source_func=self.read_source_func, - ): - # reshape array - reshaped_sample = reshape_array(sample, self.axes) - - if ( - self.tile - and self.tile_size is not None - and self.tile_overlap is not None - ): - # generate patches, return a generator - patch_gen = extract_tiles( - arr=reshaped_sample, - tile_size=self.tile_size, - overlaps=self.tile_overlap, - ) - else: - # just wrap the sample in a generator with default tiling info - array_shape = reshaped_sample.squeeze().shape - patch_gen = ( - (reshaped_sample, TileInformation(array_shape=array_shape)) - for _ in range(1) - ) - - # apply transform to patches - for patch_array, tile_info in patch_gen: - transformed_patch, _ = self.patch_transform(patch=patch_array) - - yield transformed_patch, tile_info diff --git a/src/careamics/dataset/iterable_pred_dataset.py b/src/careamics/dataset/iterable_pred_dataset.py new file mode 100644 index 000000000..792dde1d2 --- /dev/null +++ b/src/careamics/dataset/iterable_pred_dataset.py @@ -0,0 +1,117 @@ +"""Iterable prediction dataset used to load data file by file.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable, Generator, List + +import numpy as np +from torch.utils.data import IterableDataset + +from careamics.transforms import Compose + +from ..config import InferenceConfig +from ..config.transformations import NormalizeModel +from .dataset_utils import iterate_over_files, read_tiff + + +class IterablePredDataset(IterableDataset): + """Simple iterable prediction dataset. + + Parameters + ---------- + prediction_config : InferenceConfig + Inference configuration. + src_files : List[Path] + List of data files. + read_source_func : Callable, optional + Read source function for custom types, by default read_tiff. + **kwargs : Any + Additional keyword arguments, unused. + + Attributes + ---------- + data_path : Union[str, Path] + Path to the data, must be a directory. + axes : str + Description of axes in format STCZYX. + mean : Optional[float], optional + Expected mean of the dataset, by default None. + std : Optional[float], optional + Expected standard deviation of the dataset, by default None. + patch_transform : Optional[Callable], optional + Patch transform callable, by default None. + """ + + def __init__( + self, + prediction_config: InferenceConfig, + src_files: List[Path], + read_source_func: Callable = read_tiff, + **kwargs: Any, + ) -> None: + """Constructor. + + Parameters + ---------- + prediction_config : InferenceConfig + Inference configuration. + src_files : List[Path] + List of data files. + read_source_func : Callable, optional + Read source function for custom types, by default read_tiff. + **kwargs : Any + Additional keyword arguments, unused. + + Raises + ------ + ValueError + If mean and std are not provided in the inference configuration. 
+ """ + self.prediction_config = prediction_config + self.data_files = src_files + self.axes = prediction_config.axes + self.read_source_func = read_source_func + + # check mean and std and create normalize transform + if self.prediction_config.mean is None or self.prediction_config.std is None: + raise ValueError("Mean and std must be provided for prediction.") + else: + self.mean = self.prediction_config.mean + self.std = self.prediction_config.std + + # instantiate normalize transform + self.patch_transform = Compose( + transform_list=[ + NormalizeModel( + mean=prediction_config.mean, std=prediction_config.std + ) + ], + ) + + def __iter__( + self, + ) -> Generator[np.ndarray, None, None]: + """ + Iterate over data source and yield single patch. + + Yields + ------ + np.ndarray + Single patch. + """ + assert ( + self.mean is not None and self.std is not None + ), "Mean and std must be provided" + + for sample, _ in iterate_over_files( + self.prediction_config, + self.data_files, + read_source_func=self.read_source_func, + ): + # sample has S dimension + for i in range(sample.shape[0]): + + transformed_sample, _ = self.patch_transform(patch=sample[[i]]) + + yield transformed_sample diff --git a/src/careamics/dataset/iterable_tiled_pred_dataset.py b/src/careamics/dataset/iterable_tiled_pred_dataset.py new file mode 100644 index 000000000..fa2783f49 --- /dev/null +++ b/src/careamics/dataset/iterable_tiled_pred_dataset.py @@ -0,0 +1,135 @@ +"""Iterable tiled prediction dataset used to load data file by file.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable, Generator, List, Tuple + +import numpy as np +from torch.utils.data import IterableDataset + +from careamics.transforms import Compose + +from ..config import InferenceConfig +from ..config.tile_information import TileInformation +from ..config.transformations import NormalizeModel +from .dataset_utils import iterate_over_files, read_tiff +from .tiling import extract_tiles + + +class IterableTiledPredDataset(IterableDataset): + """Tiled prediction dataset. + + Parameters + ---------- + prediction_config : InferenceConfig + Inference configuration. + src_files : List[Path] + List of data files. + read_source_func : Callable, optional + Read source function for custom types, by default read_tiff. + **kwargs : Any + Additional keyword arguments, unused. + + Attributes + ---------- + data_path : Union[str, Path] + Path to the data, must be a directory. + axes : str + Description of axes in format STCZYX. + mean : Optional[float], optional + Expected mean of the dataset, by default None. + std : Optional[float], optional + Expected standard deviation of the dataset, by default None. + patch_transform : Optional[Callable], optional + Patch transform callable, by default None. + """ + + def __init__( + self, + prediction_config: InferenceConfig, + src_files: List[Path], + read_source_func: Callable = read_tiff, + **kwargs: Any, + ) -> None: + """Constructor. + + Parameters + ---------- + prediction_config : InferenceConfig + Inference configuration. + src_files : List[Path] + List of data files. + read_source_func : Callable, optional + Read source function for custom types, by default read_tiff. + **kwargs : Any + Additional keyword arguments, unused. + + Raises + ------ + ValueError + If mean and std are not provided in the inference configuration. 
+ """ + if ( + prediction_config.tile_size is None + or prediction_config.tile_overlap is None + ): + raise ValueError( + "Tile size and overlap must be provided for tiled prediction." + ) + + self.prediction_config = prediction_config + self.data_files = src_files + self.axes = prediction_config.axes + self.tile_size = prediction_config.tile_size + self.tile_overlap = prediction_config.tile_overlap + self.read_source_func = read_source_func + + # check mean and std and create normalize transform + if self.prediction_config.mean is None or self.prediction_config.std is None: + raise ValueError("Mean and std must be provided for prediction.") + else: + self.mean = self.prediction_config.mean + self.std = self.prediction_config.std + + # instantiate normalize transform + self.patch_transform = Compose( + transform_list=[ + NormalizeModel( + mean=prediction_config.mean, std=prediction_config.std + ) + ], + ) + + def __iter__( + self, + ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: + """ + Iterate over data source and yield single patch. + + Yields + ------ + Tuple[pnp.ndarray, TileInformation] + Single tile. + """ + assert ( + self.mean is not None and self.std is not None + ), "Mean and std must be provided" + + for sample, _ in iterate_over_files( + self.prediction_config, + self.data_files, + read_source_func=self.read_source_func, + ): + # generate patches, return a generator of single tiles + patch_gen = extract_tiles( + arr=sample, + tile_size=self.tile_size, + overlaps=self.tile_overlap, + ) + + # apply transform to patches + for patch_array, tile_info in patch_gen: + transformed_patch, _ = self.patch_transform(patch=patch_array) + + yield transformed_patch, tile_info diff --git a/src/careamics/dataset/patching/patching.py b/src/careamics/dataset/patching/patching.py index d445c0ec3..d4f391cb8 100644 --- a/src/careamics/dataset/patching/patching.py +++ b/src/careamics/dataset/patching/patching.py @@ -23,6 +23,8 @@ def prepare_patches_supervised( """ Iterate over data source and create an array of patches and corresponding targets. + The lists of Paths should be pre-sorted. + Parameters ---------- train_files : List[Path] @@ -41,9 +43,6 @@ def prepare_patches_supervised( np.ndarray Array of patches. """ - train_files.sort() - target_files.sort() - means, stds, num_samples = 0, 0, 0 all_patches, all_targets = [], [] for train_filename, target_filename in zip(train_files, target_files): diff --git a/src/careamics/dataset/patching/validate_patch_dimension.py b/src/careamics/dataset/patching/validate_patch_dimension.py index 8174493a4..56fd6d698 100644 --- a/src/careamics/dataset/patching/validate_patch_dimension.py +++ b/src/careamics/dataset/patching/validate_patch_dimension.py @@ -45,18 +45,20 @@ def validate_patch_dimensions( if len(patch_size) != len(arr.shape[2:]): raise ValueError( f"There must be a patch size for each spatial dimensions " - f"(got {patch_size} patches for dims {arr.shape})." + f"(got {patch_size} patches for dims {arr.shape}). Check the axes order." ) # Sanity checks on patch sizes versus array dimension if is_3d_patch and patch_size[0] > arr.shape[-3]: raise ValueError( f"Z patch size is inconsistent with image shape " - f"(got {patch_size[0]} patches for dim {arr.shape[1]})." + f"(got {patch_size[0]} patches for dim {arr.shape[1]}). Check the axes " + f"order." 
         )
 
     if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
         raise ValueError(
             f"At least one of YX patch dimensions is larger than the corresponding "
-            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})."
+            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
+            f"Check the axes order."
         )
diff --git a/src/careamics/dataset/tiling/__init__.py b/src/careamics/dataset/tiling/__init__.py
new file mode 100644
index 000000000..f7b9643a1
--- /dev/null
+++ b/src/careamics/dataset/tiling/__init__.py
@@ -0,0 +1,11 @@
+"""Tiling functions."""
+
+__all__ = [
+    "stitch_prediction",
+    "extract_tiles",
+    "collate_tiles",
+]
+
+from .collate_tiles import collate_tiles
+from .stitch_prediction import stitch_prediction
+from .tiled_patching import extract_tiles
diff --git a/src/careamics/dataset/tiling/collate_tiles.py b/src/careamics/dataset/tiling/collate_tiles.py
new file mode 100644
index 000000000..ceefc601f
--- /dev/null
+++ b/src/careamics/dataset/tiling/collate_tiles.py
@@ -0,0 +1,33 @@
+"""Collate function for tiling."""
+
+from typing import Any, List, Tuple
+
+import numpy as np
+from torch.utils.data.dataloader import default_collate
+
+from careamics.config.tile_information import TileInformation
+
+
+def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
+    """
+    Collate tiles received from CAREamics prediction dataloader.
+
+    The CAREamics tiled prediction dataloaders return tuples of arrays and
+    TileInformation. This function collates the arrays using the default collate
+    function, while the TileInformation objects are returned as a list, since
+    they cannot be collated into tensors.
+
+    Parameters
+    ----------
+    batch : List[Tuple[np.ndarray, TileInformation], ...]
+        Batch of tiles.
+
+    Returns
+    -------
+    Any
+        Tuple of the collated arrays and the list of TileInformation objects.
+    """
+    new_batch = [tile for tile, _ in batch]
+    tiles_batch = [tile_info for _, tile_info in batch]
+
+    return default_collate(new_batch), tiles_batch
diff --git a/src/careamics/dataset/tiling/stitch_prediction.py b/src/careamics/dataset/tiling/stitch_prediction.py
new file mode 100644
index 000000000..54f946042
--- /dev/null
+++ b/src/careamics/dataset/tiling/stitch_prediction.py
@@ -0,0 +1,55 @@
+"""Prediction utility functions."""
+
+from typing import List
+
+import numpy as np
+
+from careamics.config.tile_information import TileInformation
+
+
+def stitch_prediction(
+    tiles: List[np.ndarray],
+    tile_infos: List[TileInformation],
+) -> np.ndarray:
+    """Stitch tiles back together to form a full image.
+
+    Tiles are expected to be of dimensions C(Z)YX, where C is the number of channels
+    and can be a singleton dimension.
+
+    Parameters
+    ----------
+    tiles : List[np.ndarray]
+        Cropped tiles.
+    tile_infos : List[TileInformation]
+        List of information and coordinates obtained from
+        `dataset.tiling.tiled_patching.extract_tiles`.
+
+    Returns
+    -------
+    np.ndarray
+        Full image.
+ """ + # retrieve whole array size + input_shape = tile_infos[0].array_shape + predicted_image = np.zeros(input_shape, dtype=np.float32) + + for tile, tile_info in zip(tiles, tile_infos): + n_channels = tile.shape[0] + + # Compute coordinates for cropping predicted tile + slices = (slice(0, n_channels),) + tuple( + [slice(c[0], c[1]) for c in tile_info.overlap_crop_coords] + ) + + # Crop predited tile according to overlap coordinates + cropped_tile = tile[slices] + + # Insert cropped tile into predicted image using stitch coordinates + predicted_image[ + ( + ..., + *[slice(c[0], c[1]) for c in tile_info.stitch_coords], + ) + ] = cropped_tile.astype(np.float32) + + return predicted_image diff --git a/src/careamics/dataset/patching/tiled_patching.py b/src/careamics/dataset/tiling/tiled_patching.py similarity index 96% rename from src/careamics/dataset/patching/tiled_patching.py rename to src/careamics/dataset/tiling/tiled_patching.py index 890c7f616..10fe695cd 100644 --- a/src/careamics/dataset/patching/tiled_patching.py +++ b/src/careamics/dataset/tiling/tiled_patching.py @@ -84,15 +84,15 @@ def extract_tiles( tile_size: Union[List[int], Tuple[int, ...]], overlaps: Union[List[int], Tuple[int, ...]], ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: - """ - Generate tiles from the input array with specified overlap. + """Generate tiles from the input array with specified overlap. The tiles cover the whole array. The method returns a generator that yields tuples of array and tile information, the latter includes whether the tile is the last one, the coordinates of the overlap crop, and the coordinates of the stitched tile. - The array has shape C(Z)YX, where C can be a singleton. + Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX, + where C can be a singleton. 
 
     Parameters
     ----------
@@ -155,7 +155,6 @@ def extract_tiles(
             # create tile information
             tile_info = TileInformation(
                 array_shape=sample.squeeze().shape,
-                tiled=True,
                 last_tile=last_tile,
                 overlap_crop_coords=overlap_crop_coords,
                 stitch_coords=stitch_coords,
diff --git a/src/careamics/lightning_module.py b/src/careamics/lightning_module.py
index e16ea39e4..57848c582 100644
--- a/src/careamics/lightning_module.py
+++ b/src/careamics/lightning_module.py
@@ -168,9 +168,9 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
             mean=self._trainer.datamodule.predict_dataset.mean,
             std=self._trainer.datamodule.predict_dataset.std,
         )
-        denormalized_output, _ = denorm(patch=output)
+        denormalized_output = denorm(patch=output.cpu())
 
-        if len(aux) > 0:
+        if len(aux) > 0:  # aux can be tiling information
            return denormalized_output, aux
         else:
            return denormalized_output
diff --git a/src/careamics/lightning_prediction_datamodule.py b/src/careamics/lightning_prediction_datamodule.py
index b7a730faa..af61d7456 100644
--- a/src/careamics/lightning_prediction_datamodule.py
+++ b/src/careamics/lightning_prediction_datamodule.py
@@ -1,68 +1,37 @@
 """Prediction Lightning data modules."""
 
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
 
 import numpy as np
 import pytorch_lightning as L
 from torch.utils.data import DataLoader
-from torch.utils.data.dataloader import default_collate
 
 from careamics.config import InferenceConfig
 from careamics.config.support import SupportedData
-from careamics.config.tile_information import TileInformation
+from careamics.dataset import (
+    InMemoryPredDataset,
+    InMemoryTiledPredDataset,
+    IterablePredDataset,
+    IterableTiledPredDataset,
+)
 from careamics.dataset.dataset_utils import (
     get_read_func,
     list_files,
 )
-from careamics.dataset.in_memory_dataset import (
-    InMemoryPredictionDataset,
-)
-from careamics.dataset.iterable_dataset import (
-    IterablePredictionDataset,
-)
+from careamics.dataset.tiling.collate_tiles import collate_tiles
 from careamics.utils import get_logger
 
-PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset]
+PredictDatasetType = Union[
+    InMemoryPredDataset,
+    InMemoryTiledPredDataset,
+    IterablePredDataset,
+    IterableTiledPredDataset,
+]
 
 logger = get_logger(__name__)
 
 
-def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
-    """
-    Collate tiles received from CAREamics prediction dataloader.
-
-    CAREamics prediction dataloader returns tuples of arrays and TileInformation. In
-    case of non-tiled data, this function will return the arrays. In case of tiled data,
-    it will return the arrays, the last tile flag, the overlap crop coordinates and the
-    stitch coordinates.
-
-    Parameters
-    ----------
-    batch : List[Tuple[np.ndarray, TileInformation], ...]
-        Batch of tiles.
-
-    Returns
-    -------
-    Any
-        Collated batch.
-    """
-    first_tile_info: TileInformation = batch[0][1]
-    # if not tiled, then return arrays
-    if not first_tile_info.tiled:
-        arrays, _ = zip(*batch)
-
-        return default_collate(arrays)
-    # else we explicit the last_tile flag and coordinates
-    else:
-        new_batch = [
-            (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
-            for tile, t in batch
-        ]
-
-        return default_collate(new_batch)
-
-
 class CAREamicsPredictData(L.LightningDataModule):
     """
     CAREamics Lightning prediction data module.
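
The datamodule hunks below wire the tiled datasets to the relocated `collate_tiles`. Since the tiled datasets yield `(array, TileInformation)` pairs, the default collate function alone cannot batch them; a sketch of the structure `collate_tiles` produces, assuming the classes introduced in this patch (shapes and coordinate values are hypothetical):

    import numpy as np

    from careamics.config.tile_information import TileInformation
    from careamics.dataset.tiling import collate_tiles

    def make_item(last_tile: bool):
        # one dataset item: a C(Z)YX tile plus its metadata (values hypothetical)
        tile_info = TileInformation(
            array_shape=(8, 8),
            last_tile=last_tile,
            overlap_crop_coords=((0, 4), (0, 4)),
            stitch_coords=((0, 4), (0, 4)),
        )
        return np.zeros((1, 4, 4), dtype=np.float32), tile_info

    arrays, tile_infos = collate_tiles([make_item(False), make_item(True)])

    # the arrays are stacked by default_collate into a (2, 1, 4, 4) tensor,
    # while the TileInformation objects are passed through as a plain list
    assert tuple(arrays.shape) == (2, 1, 4, 4)
    assert [t.last_tile for t in tile_infos] == [False, True]
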
@@ -182,6 +151,9 @@ def __init__(
         self.tile_size = pred_config.tile_size
         self.tile_overlap = pred_config.tile_overlap
 
+        # check if it is tiled
+        self.tiled = self.tile_size is not None and self.tile_overlap is not None
+
         # read source function
         if pred_config.data_type == SupportedData.CUSTOM:
             # mypy check
@@ -212,17 +184,29 @@ def setup(self, stage: Optional[str] = None) -> None:
         """
         # if numpy array
         if self.data_type == SupportedData.ARRAY:
-            # prediction dataset
-            self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset(
-                prediction_config=self.prediction_config,
-                inputs=self.pred_data,
-            )
+            if self.tiled:
+                self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset(
+                    prediction_config=self.prediction_config,
+                    inputs=self.pred_data,
+                )
+            else:
+                self.predict_dataset = InMemoryPredDataset(
+                    prediction_config=self.prediction_config,
+                    inputs=self.pred_data,
+                )
         else:
-            self.predict_dataset = IterablePredictionDataset(
-                prediction_config=self.prediction_config,
-                src_files=self.pred_files,
-                read_source_func=self.read_source_func,
-            )
+            if self.tiled:
+                self.predict_dataset = IterableTiledPredDataset(
+                    prediction_config=self.prediction_config,
+                    src_files=self.pred_files,
+                    read_source_func=self.read_source_func,
+                )
+            else:
+                self.predict_dataset = IterablePredDataset(
+                    prediction_config=self.prediction_config,
+                    src_files=self.pred_files,
+                    read_source_func=self.read_source_func,
+                )
 
     def predict_dataloader(self) -> DataLoader:
         """
@@ -236,7 +220,7 @@ def predict_dataloader(self) -> DataLoader:
         return DataLoader(
             self.predict_dataset,
             batch_size=self.batch_size,
-            collate_fn=_collate_tiles,
+            collate_fn=collate_tiles if self.tiled else None,
             **self.dataloader_params,
         )  # TODO check workers are used
diff --git a/src/careamics/lightning_prediction_loop.py b/src/careamics/lightning_prediction_loop.py
index ab44a17a8..9298a780e 100644
--- a/src/careamics/lightning_prediction_loop.py
+++ b/src/careamics/lightning_prediction_loop.py
@@ -1,14 +1,16 @@
 """Lithning prediction loop allowing tiling."""
 
-from typing import Optional
+from typing import List, Optional
 
+import numpy as np
 import pytorch_lightning as L
 from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher
 from pytorch_lightning.loops.utilities import _no_grad_context
 from pytorch_lightning.trainer import call
 from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
 
-from careamics.prediction import stitch_prediction
+from careamics.config.tile_information import TileInformation
+from careamics.dataset.tiling import stitch_prediction
 
 
 class CAREamicsPredictionLoop(L.loops._PredictionLoop):
@@ -37,11 +39,10 @@ def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
         ########################################################
         ################ CAREamics specific code ###############
         if len(self.predicted_array) == 1:
-            # TODO does this make sense to here? (force numpy array)
-            return self.predicted_array[0].numpy()
+            # single array, already a numpy array
+            return self.predicted_array[0]
 
         # todo why not return the list here?
         else:
-            # TODO revisit logic
-            return [element.numpy() for element in self.predicted_array]
+            return self.predicted_array
         ########################################################
 
         return None
@@ -65,8 +66,8 @@ def run(self) -> Optional[_PREDICT_OUTPUT]:
         assert data_fetcher is not None
 
         self.predicted_array = []
-        self.tiles = []
-        self.stitching_data = []
+        self.tiles: List[np.ndarray] = []
+        self.tile_information: List[TileInformation] = []
 
         while True:
             try:
@@ -87,27 +88,32 @@ def run(self) -> Optional[_PREDICT_OUTPUT]:
 
                 ########################################################
                 ################ CAREamics specific code ###############
-                # TODO: next line is not compatible with muSplit
                 is_tiled = len(self.predictions[batch_idx]) == 2
                 if is_tiled:
-                    # extract the last tile flag and the coordinates (crop and stitch)
-                    last_tile, *stitch_data = self.predictions[batch_idx][1]
+                    # a numpy array of shape BC(Z)YX
+                    tile_batch = self.predictions[batch_idx][0]
 
-                    # append the tile and the coordinates to the lists
-                    self.tiles.append(self.predictions[batch_idx][0])
-                    self.stitching_data.append(stitch_data)
+                    # split the batch into single tiles of shape C(Z)YX and
+                    # add them to the tiles list
+                    self.tiles.extend(list(tile_batch.numpy()))
+
+                    # the auxiliary element of the prediction is a list whose single
+                    # entry is the list of TileInformation, one per tile in the batch
+                    tile_info = self.predictions[batch_idx][1][0]
+                    self.tile_information.extend(tile_info)
 
                     # if last tile, stitch the tiles and add array to the prediction
-                    if any(last_tile):
+                    last_tiles = [t.last_tile for t in self.tile_information]
+                    if any(last_tiles):
                         predicted_batches = stitch_prediction(
-                            self.tiles, self.stitching_data
+                            self.tiles, self.tile_information
                         )
                         self.predicted_array.append(predicted_batches)
                         self.tiles.clear()
-                        self.stitching_data.clear()
+                        self.tile_information.clear()
                 else:
                     # simply add the prediction to the list
-                    self.predicted_array.append(self.predictions[batch_idx])
+                    self.predicted_array.append(self.predictions[batch_idx].numpy())
                 ########################################################
             except StopIteration:
                 break
diff --git a/src/careamics/prediction/__init__.py b/src/careamics/prediction/__init__.py
deleted file mode 100644
index 852e65de1..000000000
--- a/src/careamics/prediction/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-"""Prediction functions."""
-
-__all__ = [
-    "stitch_prediction",
-]
-
-from .stitch_prediction import stitch_prediction
diff --git a/src/careamics/prediction/stitch_prediction.py b/src/careamics/prediction/stitch_prediction.py
deleted file mode 100644
index 1b4fe9690..000000000
--- a/src/careamics/prediction/stitch_prediction.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""Prediction utility functions."""
-
-from typing import List
-
-import numpy as np
-import torch
-
-
-def stitch_prediction(
-    tiles: List[torch.Tensor],
-    stitching_data: List[List[torch.Tensor]],
-) -> torch.Tensor:
-    """
-    Stitch tiles back together to form a full image.
-
-    Parameters
-    ----------
-    tiles : List[torch.Tensor]
-        Cropped tiles and their respective stitching coordinates.
-    stitching_data : List
-        List of information and coordinates obtained from
-        `dataset.tiled_patching.extract_tiles`.
-
-    Returns
-    -------
-    np.ndarray
-        Full image.
-    """
-    # retrieve whole array size, there is two cases to consider:
-    # 1. the tiles are stored in a list
-    # 2. the tiles are stored in a list with batches along the first dim
-    if tiles[0].shape[0] > 1:
-        input_shape = np.array(
-            [el.numpy() for el in stitching_data[0][0][0]], dtype=int
-        ).squeeze()
-    else:
-        input_shape = np.array(
-            [el.numpy() for el in stitching_data[0][0]], dtype=int
-        ).squeeze()
-
-    # TODO should use torch.zeros instead of np.zeros
-    predicted_image = torch.Tensor(np.zeros(input_shape, dtype=np.float32))
-
-    for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip(
-        tiles, stitching_data
-    ):
-        for batch_idx in range(tile_batch.shape[0]):
-            # Compute coordinates for cropping predicted tile
-            slices = tuple(
-                [
-                    slice(c[0][batch_idx], c[1][batch_idx])
-                    for c in overlap_crop_coords_batch
-                ]
-            )
-
-            # Crop predited tile according to overlap coordinates
-            cropped_tile = tile_batch[batch_idx].squeeze()[slices]
-
-            # Insert cropped tile into predicted image using stitch coordinates
-            predicted_image[
-                (
-                    ...,
-                    *[
-                        slice(c[0][batch_idx], c[1][batch_idx])
-                        for c in stitch_coords_batch
-                    ],
-                )
-            ] = cropped_tile.to(torch.float32)
-
-    return predicted_image
diff --git a/src/careamics/transforms/normalize.py b/src/careamics/transforms/normalize.py
index 1e24afd5b..10a1d7423 100644
--- a/src/careamics/transforms/normalize.py
+++ b/src/careamics/transforms/normalize.py
@@ -91,12 +91,12 @@ def _apply(self, patch: np.ndarray) -> np.ndarray:
 
 class Denormalize:
     """
-    Denormalize an image or image patch.
+    Denormalize an image.
 
     Denormalization is performed expecting a zero mean and unit variance input. This
     transform expects C(Z)YX dimensions.
 
-    Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
+    Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
     division by zero during the normalization step, which is taken into account during
     denormalization.
 
@@ -133,27 +133,22 @@ def __init__(
         self.std = std
         self.eps = 1e-6
 
-    def __call__(
-        self, patch: np.ndarray, target: Optional[np.ndarray] = None
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+    def __call__(self, patch: np.ndarray) -> np.ndarray:
         """Apply the transform to the source patch and the target (optional).
 
         Parameters
         ----------
         patch : np.ndarray
             Patch, 2D or 3D, shape C(Z)YX.
-        target : Optional[np.ndarray], optional
-            Target for the patch, by default None.
 
         Returns
         -------
-        Tuple[np.ndarray, Optional[np.ndarray]]
-            Transformed patch and target.
+        np.ndarray
+            Transformed patch.
         """
         norm_patch = self._apply(patch)
-        norm_target = self._apply(target) if target is not None else None
 
-        return norm_patch, norm_target
+        return norm_patch
 
     def _apply(self, patch: np.ndarray) -> np.ndarray:
         """
diff --git a/src/careamics/utils/ram.py b/src/careamics/utils/ram.py
index 258ebc824..dfa84456a 100644
--- a/src/careamics/utils/ram.py
+++ b/src/careamics/utils/ram.py
@@ -5,11 +5,11 @@
 
 def get_ram_size() -> int:
     """
-    Get RAM size in bytes.
+    Get the available RAM size in mbytes.
 
     Returns
     -------
     int
         RAM size in mbytes.
""" - return psutil.virtual_memory().total / 1024**2 + return psutil.virtual_memory().available / 1024**2 diff --git a/tests/config/test_tile_information.py b/tests/config/test_tile_information.py index 78b24cc80..1f10b24b5 100644 --- a/tests/config/test_tile_information.py +++ b/tests/config/test_tile_information.py @@ -6,45 +6,60 @@ def test_defaults(): """Test instantiating time information with defaults.""" - tile_info = TileInformation(array_shape=np.zeros((6, 6)).shape) - - assert tile_info.array_shape == (6, 6) - assert not tile_info.tiled - assert not tile_info.last_tile - assert tile_info.overlap_crop_coords is None - assert tile_info.stitch_coords is None - - -def test_tiled(): - """Test instantiating time information with parameters.""" tile_info = TileInformation( array_shape=np.zeros((6, 6)).shape, - tiled=True, - last_tile=True, overlap_crop_coords=((1, 2),), stitch_coords=((3, 4),), ) assert tile_info.array_shape == (6, 6) - assert tile_info.tiled - assert tile_info.last_tile - assert tile_info.overlap_crop_coords == ((1, 2),) - assert tile_info.stitch_coords == ((3, 4),) - - -def test_validation_last_tile(): - """Test that last tile is only set if tiled is set.""" - tile_info = TileInformation(array_shape=(6, 6), last_tile=True) assert not tile_info.last_tile def test_error_on_coords(): - """Test than an error is raised if it is tiled but not coordinates are given.""" + """Test than an error is raised if no coordinates are given.""" with pytest.raises(ValueError): - TileInformation(array_shape=(6, 6), tiled=True) + TileInformation(array_shape=(6, 6)) def test_error_on_singleton_dims(): """Test that an error is raised if the array shape contains singleton dimensions.""" with pytest.raises(ValueError): - TileInformation(array_shape=(2, 1, 6, 6)) + TileInformation( + array_shape=(2, 1, 6, 6), + overlap_crop_coords=((1, 2),), + stitch_coords=((3, 4),), + ) + + +def test_tile_equality(): + """Test whether two tile information objects are equal.""" + t1 = TileInformation( + array_shape=(6, 6), + last_tile=True, + overlap_crop_coords=((1, 2),), + stitch_coords=((3, 4),), + ) + t2 = TileInformation( + array_shape=(6, 6), + last_tile=True, + overlap_crop_coords=((1, 2),), + stitch_coords=((3, 4),), + ) + assert t1 == t2 + + # inequality + t2.array_shape = (7, 7) + assert t1 != t2 + + t2.array_shape = (6, 6) + t2.last_tile = False + assert t1 != t2 + + t2.last_tile = True + t2.overlap_crop_coords = ((2, 3),) + assert t1 != t2 + + t2.overlap_crop_coords = ((1, 2),) + t2.stitch_coords = ((4, 5),) + assert t1 != t2 diff --git a/tests/dataset/dataset_utils/test_list_files.py b/tests/dataset/dataset_utils/test_list_files.py index 595fa4610..971a99ace 100644 --- a/tests/dataset/dataset_utils/test_list_files.py +++ b/tests/dataset/dataset_utils/test_list_files.py @@ -92,6 +92,10 @@ def test_list_multiple_files_tiff(tmp_path: Path): assert len(files) == 3 assert set(files) == set(ref_files) + # test that the files are sorted + assert files != ref_files + assert files == sorted(ref_files) + def test_list_single_file_custom(tmp_path): """Test listing a single custom file.""" diff --git a/tests/dataset/prediction/test_collate_tiles.py b/tests/dataset/prediction/test_collate_tiles.py new file mode 100644 index 000000000..483391fed --- /dev/null +++ b/tests/dataset/prediction/test_collate_tiles.py @@ -0,0 +1,75 @@ +import pytest + +from careamics.dataset.tiling import collate_tiles, extract_tiles + + +@pytest.mark.parametrize("n_channels", [1, 4]) +@pytest.mark.parametrize("batch", [1, 3]) +def 
test_collate_tiles_2d(ordered_array, n_channels, batch): + """Test that the collate tiles function collates tile information correctly.""" + tile_size = (4, 4) + tile_overlap = (2, 2) + shape = (1, n_channels, 8, 8) + + # create array + array = ordered_array(shape) + + # extract tiles + tiles = list(extract_tiles(array, tile_size=tile_size, overlaps=tile_overlap)) + + tiles_used = 0 + n_tiles = len(tiles) + while tiles_used < n_tiles: + # get a batch of tiles + batch_tiles = tiles[tiles_used : tiles_used + batch] + tiles_used += batch + + # collate the tiles + collated_tiles = collate_tiles(batch_tiles) + + # check the collated tiles + assert collated_tiles[0].shape == (batch, n_channels) + tile_size + + # check the tile info + tile_infos = collated_tiles[1] + assert len(tile_infos) == batch + + for i, t in enumerate(tile_infos): + for j in range(i + 1, len(tile_infos)): + assert t != tile_infos[j] + + +@pytest.mark.parametrize("n_channels", [1, 4]) +@pytest.mark.parametrize("batch", [1, 3]) +def test_collate_tiles_3d(ordered_array, n_channels, batch): + """Test that the collate tiles function collates tile information correctly.""" + tile_size = (4, 4, 4) + tile_overlap = (2, 2, 2) + shape = (1, n_channels, 8, 8, 8) + + # create array + array = ordered_array(shape) + + # extract tiles + tiles = list(extract_tiles(array, tile_size=tile_size, overlaps=tile_overlap)) + + tiles_used = 0 + n_tiles = len(tiles) + while tiles_used < n_tiles: + # get a batch of tiles + batch_tiles = tiles[tiles_used : tiles_used + batch] + tiles_used += batch + + # collate the tiles + collated_tiles = collate_tiles(batch_tiles) + + # check the collated tiles + assert collated_tiles[0].shape == (batch, n_channels) + tile_size + + # check the tile info + tile_infos = collated_tiles[1] + assert len(tile_infos) == batch + + for i, t in enumerate(tile_infos): + for j in range(i + 1, len(tile_infos)): + assert t != tile_infos[j] diff --git a/tests/dataset/prediction/test_stitch_prediction.py b/tests/dataset/prediction/test_stitch_prediction.py new file mode 100644 index 000000000..062f48ec3 --- /dev/null +++ b/tests/dataset/prediction/test_stitch_prediction.py @@ -0,0 +1,51 @@ +import numpy as np +import pytest + +from careamics.dataset.tiling import extract_tiles, stitch_prediction + + +@pytest.mark.parametrize( + "input_shape, tile_size, overlaps", + [ + ((1, 1, 8, 8), (4, 4), (2, 2)), + ((1, 2, 8, 8), (4, 4), (2, 2)), + ((2, 1, 8, 8), (4, 4), (2, 2)), + ((2, 2, 8, 8), (4, 4), (2, 2)), + ((1, 1, 7, 9), (4, 4), (2, 2)), + ((1, 3, 7, 9), (4, 4), (2, 2)), + ((1, 1, 9, 7, 8), (4, 4, 4), (2, 2, 2)), + ((1, 1, 321, 481), (256, 256), (48, 48)), + ((2, 1, 321, 481), (256, 256), (48, 48)), + ((1, 4, 321, 481), (256, 256), (48, 48)), + ((4, 3, 321, 481), (256, 256), (48, 48)), + ], +) +def test_stitch_tiles(ordered_array, input_shape, tile_size, overlaps): + """Test stitching tiles back together.""" + arr = ordered_array(input_shape, dtype=int) + n_samples = input_shape[0] + + # extract tiles + all_tiles = list(extract_tiles(arr, tile_size, overlaps)) + + tiles = [] + tile_infos = [] + sample_id = 0 + for tile, tile_info in all_tiles: + # create lists mimicking the output of the prediction loop + tiles.append(tile) + tile_infos.append(tile_info) + + # if we reached the last tile + if tile_info.last_tile: + result = stitch_prediction(tiles, tile_infos) + + # check equality with the correct sample + assert np.array_equal(result, arr[sample_id].squeeze()) + sample_id += 1 + + # clear the lists + tiles.clear() + tile_infos.clear() 
+ + assert sample_id == n_samples diff --git a/tests/dataset/patching/test_tiled_patching.py b/tests/dataset/prediction/test_tiled_patching.py similarity index 98% rename from tests/dataset/patching/test_tiled_patching.py rename to tests/dataset/prediction/test_tiled_patching.py index a7e135d4f..920fa907a 100644 --- a/tests/dataset/patching/test_tiled_patching.py +++ b/tests/dataset/prediction/test_tiled_patching.py @@ -2,7 +2,7 @@ import pytest from careamics.config.tile_information import TileInformation -from careamics.dataset.patching.tiled_patching import ( +from careamics.dataset.tiling.tiled_patching import ( _compute_crop_and_stitch_coords_1d, extract_tiles, ) diff --git a/tests/dataset/test_in_memory_pred_dataset.py b/tests/dataset/test_in_memory_pred_dataset.py new file mode 100644 index 000000000..46099cbe5 --- /dev/null +++ b/tests/dataset/test_in_memory_pred_dataset.py @@ -0,0 +1,64 @@ +import numpy as np +import pytest + +from careamics.config import InferenceConfig +from careamics.dataset import InMemoryPredDataset + + +@pytest.mark.parametrize( + "shape, axes, expected_shape", + [ + ((16, 16), "YX", (1, 1, 16, 16)), + ((3, 16, 16), "CYX", (1, 3, 16, 16)), + ((8, 16, 16), "ZYX", (1, 1, 8, 16, 16)), + ((3, 8, 16, 16), "CZYX", (1, 3, 8, 16, 16)), + ((4, 16, 16), "SYX", (1, 1, 16, 16)), + ((4, 3, 16, 16), "SCYX", (1, 3, 16, 16)), + ((4, 3, 8, 16, 16), "SCZYX", (1, 3, 8, 16, 16)), + ], +) +def test_correct_normalized_outputs(shape, axes, expected_shape): + """Test that the dataset returns normalized images with singleton + sample dimension.""" + rng = np.random.default_rng(42) + + # check expected length + if "S" in axes: + # find index of S and check shape + idx = axes.index("S") + n_patches = shape[idx] + else: + n_patches = 1 + + # create array + array = 255 * rng.random(shape) + + # create config + config = InferenceConfig( + data_type="array", + axes=axes, + mean=np.mean(array), + std=np.std(array), + ) + + # create dataset + dataset = InMemoryPredDataset(config, array) + + # check length + assert len(dataset) == n_patches + + # check that the dataset returns normalized images + for i in range(len(dataset)): + img = dataset[i] + + # check that it has the correct shape + assert img.shape == expected_shape + + # check that the image is normalized + assert np.isclose(np.mean(img), 0, atol=0.1) + assert np.isclose(np.std(img), 1, atol=0.1) + + # check that they are independent slices + for j in range(i + 1, len(dataset)): + img2 = dataset[j] + assert not np.allclose(img, img2) diff --git a/tests/dataset/test_in_memory_tiled_pred_dataset.py b/tests/dataset/test_in_memory_tiled_pred_dataset.py new file mode 100644 index 000000000..0d4f53817 --- /dev/null +++ b/tests/dataset/test_in_memory_tiled_pred_dataset.py @@ -0,0 +1,86 @@ +import numpy as np +import pytest + +from careamics.config import InferenceConfig +from careamics.dataset import InMemoryTiledPredDataset + + +# TODO extract tiles is returning C(Z)YX and no singleton S! 
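+# Editor's sketch of the tiling arithmetic used for n_tiles below (assumption:
+# extract_tiles slides a window of size T with stride T - O for overlap O, and
+# shifts the last tile back to fit), per spatial dimension:
+#
+#     n_tiles = ceil((D - O) / (T - O))
+#
+# e.g. D=16, T=8, O=4 -> ceil(12 / 4) = 3 tiles per dimension.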
+@pytest.mark.parametrize( + "shape, axes, expected_shape", + [ + ((16, 16), "YX", (1, 16, 16)), + ((3, 16, 16), "CYX", (3, 16, 16)), + ((16, 16, 16), "ZYX", (1, 16, 16, 16)), + ((3, 16, 16, 16), "CZYX", (3, 16, 16, 16)), + ((4, 16, 16), "SYX", (1, 16, 16)), + ((4, 3, 16, 16), "SCYX", (3, 16, 16)), + ((4, 3, 16, 16, 16), "SCZYX", (3, 16, 16, 16)), + ], +) +def test_correct_normalized_outputs(shape, axes, expected_shape): + """Test that the dataset returns normalized images with singleton + sample dimension.""" + rng = np.random.default_rng(42) + + tile_size = (8, 8, 8) if "Z" in axes else (8, 8) + tile_overlap = (4, 4, 4) if "Z" in axes else (4, 4) + + # check expected length + n_tiles = np.prod( + np.ceil( + (expected_shape[1:] - np.array(tile_overlap)) + / (np.array(tile_size) - np.array(tile_overlap)) + ) + ).astype(int) + + # check number of samples + if "S" in axes: + # get index + idx = axes.index("S") + n_samples = shape[idx] + else: + n_samples = 1 + + # check number of channels + if "C" in axes: + # get index + idx = axes.index("C") + n_channels = shape[idx] + else: + n_channels = 1 + + # create array + array = 255 * rng.random(shape) + + # create config + config = InferenceConfig( + data_type="array", + axes=axes, + mean=np.mean(array), + std=np.std(array), + tile_size=tile_size, + tile_overlap=tile_overlap, + ) + + # create dataset + dataset = InMemoryTiledPredDataset(config, array) + + # check length + assert len(dataset) == n_samples * n_tiles + + # check that the dataset returns normalized images + for i in range(len(dataset)): + img, _ = dataset[i] + + # check that it has the correct shape + assert img.shape == (n_channels,) + tile_size + + # check that the image is normalized + assert np.isclose(np.mean(img), 0, atol=0.25) + assert np.isclose(np.std(img), 1, atol=0.2) + + # check that they are independent slices + for j in range(i + 1, len(dataset)): + img2, _ = dataset[j] + assert not np.allclose(img, img2) diff --git a/tests/dataset/test_iterable_pred_dataset.py b/tests/dataset/test_iterable_pred_dataset.py new file mode 100644 index 000000000..c267e3ba4 --- /dev/null +++ b/tests/dataset/test_iterable_pred_dataset.py @@ -0,0 +1,82 @@ +import numpy as np +import pytest +import tifffile + +from careamics.config import InferenceConfig +from careamics.dataset import IterablePredDataset + + +@pytest.mark.parametrize( + "n_files, shape, axes, expected_shape", + [ + (1, (16, 16), "YX", (1, 1, 16, 16)), + (1, (3, 16, 16), "CYX", (1, 3, 16, 16)), + (1, (8, 16, 16), "ZYX", (1, 1, 8, 16, 16)), + (1, (3, 8, 16, 16), "CZYX", (1, 3, 8, 16, 16)), + (1, (4, 16, 16), "SYX", (1, 1, 16, 16)), + (1, (4, 3, 16, 16), "SCYX", (1, 3, 16, 16)), + (1, (4, 3, 8, 16, 16), "SCZYX", (1, 3, 8, 16, 16)), + (3, (16, 16), "YX", (1, 1, 16, 16)), + (3, (3, 16, 16), "CYX", (1, 3, 16, 16)), + (3, (8, 16, 16), "ZYX", (1, 1, 8, 16, 16)), + (3, (3, 8, 16, 16), "CZYX", (1, 3, 8, 16, 16)), + (3, (4, 16, 16), "SYX", (1, 1, 16, 16)), + (3, (4, 3, 16, 16), "SCYX", (1, 3, 16, 16)), + (3, (4, 3, 8, 16, 16), "SCZYX", (1, 3, 8, 16, 16)), + ], +) +def test_correct_normalized_outputs(tmp_path, n_files, shape, axes, expected_shape): + """Test that the dataset returns normalized images with singleton + sample dimension.""" + rng = np.random.default_rng(42) + + # check expected length + if "S" in axes: + # find index of S and check shape + idx = axes.index("S") + n_patches = shape[idx] + else: + n_patches = 1 + + # create array + new_shape = (n_files,) + shape + array = 255 * rng.random(new_shape) + + # create config + config 
= InferenceConfig( + data_type="tiff", + axes=axes, + mean=np.mean(array), + std=np.std(array), + ) + + files = [] + for i in range(n_files): + file = tmp_path / f"file_{i}.tif" + tifffile.imwrite(file, array[i]) + files.append(file) + + # create dataset + dataset = IterablePredDataset(config, files) + + # get all images + dataset = list(dataset) + + # check length + assert len(dataset) == n_files * n_patches + + # check that the dataset returns normalized images + for i in range(len(dataset)): + img = dataset[i] + + # check that it has the correct shape + assert img.shape == expected_shape + + # check that the image is normalized + assert np.isclose(np.mean(img), 0, atol=0.1) + assert np.isclose(np.std(img), 1, atol=0.1) + + # check that they are independent slices + for j in range(i + 1, len(dataset)): + img2 = dataset[j] + assert not np.allclose(img, img2) diff --git a/tests/dataset/test_iterable_tiled_pred_dataset.py b/tests/dataset/test_iterable_tiled_pred_dataset.py new file mode 100644 index 000000000..dcc174571 --- /dev/null +++ b/tests/dataset/test_iterable_tiled_pred_dataset.py @@ -0,0 +1,104 @@ +import numpy as np +import pytest +import tifffile + +from careamics.config import InferenceConfig +from careamics.dataset import IterableTiledPredDataset + + +# TODO extract tiles is returning C(Z)YX and no singleton S! +@pytest.mark.parametrize( + "n_files, shape, axes, expected_shape", + [ + (1, (16, 16), "YX", (1, 16, 16)), + (1, (3, 16, 16), "CYX", (3, 16, 16)), + (1, (8, 16, 16), "ZYX", (1, 8, 16, 16)), + (1, (3, 8, 16, 16), "CZYX", (3, 8, 16, 16)), + (1, (4, 16, 16), "SYX", (1, 16, 16)), + (1, (4, 3, 16, 16), "SCYX", (3, 16, 16)), + (1, (4, 3, 8, 16, 16), "SCZYX", (3, 8, 16, 16)), + (3, (16, 16), "YX", (1, 16, 16)), + (3, (3, 16, 16), "CYX", (3, 16, 16)), + (3, (8, 16, 16), "ZYX", (1, 8, 16, 16)), + (3, (3, 8, 16, 16), "CZYX", (3, 8, 16, 16)), + (3, (4, 16, 16), "SYX", (1, 16, 16)), + (3, (4, 3, 16, 16), "SCYX", (3, 16, 16)), + (3, (4, 3, 8, 16, 16), "SCZYX", (3, 8, 16, 16)), + ], +) +def test_correct_normalized_outputs(tmp_path, n_files, shape, axes, expected_shape): + """Test that the dataset returns normalized images with singleton + sample dimension.""" + rng = np.random.default_rng(42) + + tile_size = (8, 8, 8) if "Z" in axes else (8, 8) + tile_overlap = (4, 4, 4) if "Z" in axes else (4, 4) + + # check expected length + n_tiles = np.prod( + np.ceil( + (expected_shape[1:] - np.array(tile_overlap)) + / (np.array(tile_size) - np.array(tile_overlap)) + ) + ).astype(int) + + # check number of samples + if "S" in axes: + # get index + idx = axes.index("S") + n_samples = shape[idx] + else: + n_samples = 1 + + # check number of channels + if "C" in axes: + # get index + idx = axes.index("C") + n_channels = shape[idx] + else: + n_channels = 1 + + # create array + new_shape = (n_files,) + shape + array = 255 * rng.random(new_shape) + + # create config + config = InferenceConfig( + data_type="tiff", + axes=axes, + mean=np.mean(array), + std=np.std(array), + tile_size=tile_size, + tile_overlap=tile_overlap, + ) + + files = [] + for i in range(n_files): + file = tmp_path / f"file_{i}.tif" + tifffile.imwrite(file, array[i]) + files.append(file) + + # create dataset + dataset = IterableTiledPredDataset(config, files) + + # get all images + dataset = list(dataset) + + # check length + assert len(dataset) == n_files * n_samples * n_tiles + + # check that the dataset returns normalized images + for i in range(len(dataset)): + img, _ = dataset[i] + + # check that it has the correct shape + 
assert img.shape == (n_channels,) + tile_size
+
+        # check that the image is normalized
+        assert np.isclose(np.mean(img), 0, atol=0.25)
+        assert np.isclose(np.std(img), 1, atol=0.2)
+
+        # check that they are independent slices
+        for j in range(i + 1, len(dataset)):
+            img2, _ = dataset[j]
+            assert not np.allclose(img, img2)
diff --git a/tests/model_io/test_bmz_io.py b/tests/model_io/test_bmz_io.py
index a031aa0d2..bf48b2e11 100644
--- a/tests/model_io/test_bmz_io.py
+++ b/tests/model_io/test_bmz_io.py
@@ -33,7 +33,7 @@ def test_state_dict_io(tmp_path, pre_trained):
 def test_bmz_io(tmp_path, pre_trained):
     """Test exporting and loading to the BMZ."""
     # training data
-    train_array = np.ones((32, 32), dtype=np.float32)
+    train_array = np.ones((16, 16), dtype=np.float32)
 
     # instantiate CAREamist
     careamist = CAREamist(source=pre_trained, work_dir=tmp_path)
diff --git a/tests/prediction/test_stitch_prediction.py b/tests/prediction/test_stitch_prediction.py
deleted file mode 100644
index 4908af233..000000000
--- a/tests/prediction/test_stitch_prediction.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import pytest
-from torch import from_numpy, tensor
-
-from careamics.dataset.patching.tiled_patching import extract_tiles
-from careamics.prediction.stitch_prediction import stitch_prediction
-
-
-@pytest.mark.parametrize(
-    "input_shape, tile_size, overlaps",
-    [
-        ((1, 1, 8, 8), (4, 4), (2, 2)),
-        ((1, 1, 8, 8), (4, 4), (2, 2)),
-        ((1, 1, 7, 9), (4, 4), (2, 2)),
-        ((1, 1, 9, 7, 8), (4, 4, 4), (2, 2, 2)),
-        ((1, 1, 321, 481), (256, 256), (48, 48)),
-    ],
-)
-def test_stitch_prediction(ordered_array, input_shape, tile_size, overlaps):
-    """Test calculating stitching coordinates."""
-    arr = ordered_array(input_shape, dtype=int)
-    tiles = []
-    stitching_data = []
-
-    # extract tiles
-    tile_generator = extract_tiles(arr, tile_size, overlaps)
-
-    # Assemble all tiles as it is done during the prediction stage
-    for tile_data, tile_info in tile_generator:
-        tiles.append(from_numpy(tile_data))  # need to convert to torch.Tensor
-        stitching_data.append(
-            (  # this is way too wacky
-                [tensor(i) for i in input_shape],  # need to convert to torch.Tensor
-                [[tensor([j]) for j in i] for i in tile_info.overlap_crop_coords],
-                [[tensor([j]) for j in i] for i in tile_info.stitch_coords],
-            )
-        )
-
-    # compute stitching coordinates, it returns a torch.Tensor
-    result = stitch_prediction(tiles, stitching_data)
-
-    assert (result.numpy() == arr).all()
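The deleted test above relied on hand-packed torch tensors; its replacement in `tests/dataset/prediction/test_stitch_prediction.py` drives stitching through `TileInformation` objects instead. A minimal sketch of that flow, assuming `extract_tiles` and `stitch_prediction` are exported from `careamics.dataset.tiling` as the new tests suggest:

```python
# Sketch of the TileInformation-driven stitching round trip (exports assumed).
import numpy as np

from careamics.dataset.tiling import extract_tiles, stitch_prediction

arr = np.arange(64).reshape(1, 1, 8, 8)  # S=1, C=1, 8x8 image

tiles, tile_infos = [], []
for tile, tile_info in extract_tiles(arr, tile_size=(4, 4), overlaps=(2, 2)):
    tiles.append(tile)
    tile_infos.append(tile_info)

    if tile_info.last_tile:
        # stitch the accumulated tiles back into the original sample
        result = stitch_prediction(tiles, tile_infos)
        assert np.array_equal(result, arr[0].squeeze())
```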
diff --git a/tests/test_careamist.py b/tests/test_careamist.py
index c2fafa5d5..d1564ed2d 100644
--- a/tests/test_careamist.py
+++ b/tests/test_careamist.py
@@ -550,8 +550,9 @@ def test_predict_arrays_no_tiling(tmp_path: Path, minimum_configuration: dict):
 
     # predict CAREamist
     predicted = careamist.predict(train_array)
+    predicted_squeeze = [p.squeeze() for p in predicted]
 
-    assert predicted.squeeze().shape == train_array.shape
+    assert np.array(predicted_squeeze).shape == train_array.shape
 
     # export to BMZ
     careamist.export_to_bmz(
@@ -563,6 +564,44 @@ def test_predict_arrays_no_tiling(tmp_path: Path, minimum_configuration: dict):
     assert (tmp_path / "model.zip").exists()
 
 
+@pytest.mark.parametrize("independent_channels", [False, True])
+@pytest.mark.parametrize("batch_size", [1, 2])
+def test_predict_tiled_channel(
+    tmp_path: Path,
+    minimum_configuration: dict,
+    independent_channels: bool,
+    batch_size: int,
+):
+    """Test that CAREamics can predict tiled arrays with channels."""
+    # training data
+    train_array = random_array((3, 32, 32))
+    val_array = random_array((3, 32, 32))
+
+    # create configuration
+    config = Configuration(**minimum_configuration)
+    config.training_config.num_epochs = 1
+    config.data_config.axes = "CYX"
+    config.algorithm_config.model.in_channels = 3
+    config.algorithm_config.model.num_classes = 3
+    config.algorithm_config.model.independent_channels = independent_channels
+    config.data_config.batch_size = batch_size
+    config.data_config.data_type = SupportedData.ARRAY.value
+    config.data_config.patch_size = (8, 8)
+
+    # instantiate CAREamist
+    careamist = CAREamist(source=config, work_dir=tmp_path)
+
+    # train CAREamist
+    careamist.train(train_source=train_array, val_source=val_array)
+
+    # predict CAREamist
+    predicted = careamist.predict(
+        train_array, batch_size=batch_size, tile_size=(16, 16), tile_overlap=(4, 4)
+    )
+
+    assert predicted.squeeze().shape == train_array.shape
+
+
 @pytest.mark.parametrize("batch_size", [1, 2])
 def test_predict_path(tmp_path: Path, minimum_configuration: dict, batch_size):
     """Test that CAREamics can predict with tiff files."""
diff --git a/tests/test_lightning_prediction_datamodule.py b/tests/test_lightning_prediction_datamodule.py
index cf9e4881d..03edcd324 100644
--- a/tests/test_lightning_prediction_datamodule.py
+++ b/tests/test_lightning_prediction_datamodule.py
@@ -63,7 +63,7 @@ def test_wrapper_instantiated_with_tiling(simple_array):
     assert len(list(data_module.predict_dataloader())) == 2
 
 
-def test_lwrapper_instantiated_without_tiling(simple_array):
+def test_wrapper_instantiated_without_tiling(simple_array):
     """Test that the data module is created correctly with an array."""
     # create data module
     data_module = PredictDataWrapper(
diff --git a/tests/transforms/test_normalize.py b/tests/transforms/test_normalize.py
index ffdb4bb6f..057444c7c 100644
--- a/tests/transforms/test_normalize.py
+++ b/tests/transforms/test_normalize.py
@@ -26,5 +26,5 @@ def test_normalize_denormalize():
     )
 
     # Apply the denormalize transform
-    denormalized, _ = denorm(patch=normalized)
+    denormalized = denorm(patch=normalized)
     assert np.isclose(denormalized, array).all()
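Finally, the per-field `__eq__` added to `TileInformation` is what lets the collate and stitching tests above compare tile metadata by value. A self-contained sketch, with the import path assumed from the test module it mirrors:

```python
# Sketch: TileInformation instances now compare field by field (import assumed).
from careamics.config.tile_information import TileInformation

kwargs = dict(
    array_shape=(6, 6),
    last_tile=True,
    overlap_crop_coords=((1, 2),),
    stitch_coords=((3, 4),),
)

t1 = TileInformation(**kwargs)
t2 = TileInformation(**kwargs)

assert t1 == t2        # value equality via the new __eq__
assert t1 is not t2    # distinct objects
assert (t1 == "tile") is False  # non-TileInformation comparison falls back
```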