From e474906c0b37a6c07b841f5d0b981d9cfe236c5e Mon Sep 17 00:00:00 2001
From: jdeschamps <6367888+jdeschamps@users.noreply.github.com>
Date: Tue, 4 Jun 2024 17:37:33 +0200
Subject: [PATCH 01/14] (chore): add test for sorting, remove unnecessary
 calls to sort

---
 src/careamics/dataset/dataset_utils/file_utils.py | 2 +-
 src/careamics/dataset/patching/patching.py        | 5 ++---
 tests/dataset/dataset_utils/test_list_files.py    | 4 ++++
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/careamics/dataset/dataset_utils/file_utils.py b/src/careamics/dataset/dataset_utils/file_utils.py
index 949b588e3..a37905a06 100644
--- a/src/careamics/dataset/dataset_utils/file_utils.py
+++ b/src/careamics/dataset/dataset_utils/file_utils.py
@@ -33,7 +33,7 @@ def list_files(
     data_type: Union[str, SupportedData],
     extension_filter: str = "",
 ) -> List[Path]:
-    """Create a recursive list of files in `data_path`.
+    """Recursively list files in `data_path` and return a sorted list.
 
     If `data_path` is a file, its name is validated against the `data_type`
     using `fnmatch`, and the method returns `data_path` itself.
diff --git a/src/careamics/dataset/patching/patching.py b/src/careamics/dataset/patching/patching.py
index d445c0ec3..d4f391cb8 100644
--- a/src/careamics/dataset/patching/patching.py
+++ b/src/careamics/dataset/patching/patching.py
@@ -23,6 +23,8 @@ def prepare_patches_supervised(
     """
     Iterate over data source and create an array of patches and corresponding targets.
 
+    The lists of Paths should be pre-sorted.
+
     Parameters
     ----------
     train_files : List[Path]
@@ -41,9 +43,6 @@ def prepare_patches_supervised(
     np.ndarray
         Array of patches.
     """
-    train_files.sort()
-    target_files.sort()
-
     means, stds, num_samples = 0, 0, 0
     all_patches, all_targets = [], []
     for train_filename, target_filename in zip(train_files, target_files):
diff --git a/tests/dataset/dataset_utils/test_list_files.py b/tests/dataset/dataset_utils/test_list_files.py
index 595fa4610..971a99ace 100644
--- a/tests/dataset/dataset_utils/test_list_files.py
+++ b/tests/dataset/dataset_utils/test_list_files.py
@@ -92,6 +92,10 @@ def test_list_multiple_files_tiff(tmp_path: Path):
     assert len(files) == 3
     assert set(files) == set(ref_files)
 
+    # test that the files are sorted
+    assert files != ref_files
+    assert files == sorted(ref_files)
+
 
 def test_list_single_file_custom(tmp_path):
     """Test listing a single custom file."""

From 22ab00c40b3dba19c51285ada7da98a1255052d2 Mon Sep 17 00:00:00 2001
From: jdeschamps <6367888+jdeschamps@users.noreply.github.com>
Date: Tue, 4 Jun 2024 17:41:40 +0200
Subject: [PATCH 02/14] (fix): use available memory rather than total for in
 memory switch

---
 src/careamics/utils/ram.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/careamics/utils/ram.py b/src/careamics/utils/ram.py
index 258ebc824..dfa84456a 100644
--- a/src/careamics/utils/ram.py
+++ b/src/careamics/utils/ram.py
@@ -5,11 +5,11 @@
 
 def get_ram_size() -> int:
     """
-    Get RAM size in bytes.
+    Get available RAM size in mbytes.
 
     Returns
     -------
     int
         RAM size in mbytes.
""" - return psutil.virtual_memory().total / 1024**2 + return psutil.virtual_memory().available / 1024**2 From 86e7d2283b866e5b9854c4409178a7fec2cb162c Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Tue, 4 Jun 2024 19:10:40 +0200 Subject: [PATCH 03/14] (refac): create tiled prediction datasets --- src/careamics/dataset/__init__.py | 21 ++- src/careamics/dataset/in_memory_dataset.py | 125 +++++++++++-- src/careamics/dataset/iterable_dataset.py | 170 +++++++++++++----- .../lightning_prediction_datamodule.py | 53 ++++-- src/careamics/lightning_prediction_loop.py | 1 + 5 files changed, 293 insertions(+), 77 deletions(-) diff --git a/src/careamics/dataset/__init__.py b/src/careamics/dataset/__init__.py index b3c9cdbaf..b33d8092b 100644 --- a/src/careamics/dataset/__init__.py +++ b/src/careamics/dataset/__init__.py @@ -1,6 +1,21 @@ """Dataset module.""" -__all__ = ["InMemoryDataset", "PathIterableDataset"] +__all__ = [ + "InMemoryDataset", + "InMemoryPredictionDataset", + "InMemoryTiledPredictionDataset", + "PathIterableDataset", + "IterableTiledPredictionDataset", + "IterablePredictionDataset", +] -from .in_memory_dataset import InMemoryDataset -from .iterable_dataset import PathIterableDataset +from .in_memory_dataset import ( + InMemoryDataset, + InMemoryPredictionDataset, + InMemoryTiledPredictionDataset, +) +from .iterable_dataset import ( + IterablePredictionDataset, + IterableTiledPredictionDataset, + PathIterableDataset, +) diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 24ceb1e84..9c938f32e 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -273,8 +273,7 @@ def split_dataset( class InMemoryPredictionDataset(Dataset): - """ - Dataset storing data in memory and allowing generating patches from it. + """Simple prediction dataset returning images along the sample axis. Parameters ---------- @@ -321,6 +320,7 @@ def __init__( self.mean = self.pred_config.mean self.std = self.pred_config.std self.data_target = data_target + self.mean, self.std = self.pred_config.mean, self.pred_config.std # tiling only if both tile size and overlap are provided self.tiling = self.tile_size is not None and self.tile_overlap is not None @@ -328,6 +328,103 @@ def __init__( # read function self.read_source_func = read_source_func + # Reshape data + self.data = reshape_array(self.input_array, self.axes) + + # get transforms + self.patch_transform = Compose( + transform_list=[NormalizeModel(mean=self.mean, std=self.std)], + ) + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + Length of the dataset. + """ + return len(self.data) + + def __getitem__(self, index: int) -> np.ndarray: + """ + Return the patch corresponding to the provided index. + + Parameters + ---------- + index : int + Index of the patch to return. + + Returns + ------- + np.ndarray + Transformed patch. + """ + return self.data[[index]] + + +class InMemoryTiledPredictionDataset(Dataset): + """Prediction dataset storing data in memory and returning tiles of each image. + + Parameters + ---------- + prediction_config : InferenceConfig + Prediction configuration. + inputs : np.ndarray + Input data. + data_target : Optional[np.ndarray], optional + Target data, by default None. + read_source_func : Optional[Callable], optional + Read source function for custom types, by default read_tiff. 
+ """ + + def __init__( + self, + prediction_config: InferenceConfig, + inputs: np.ndarray, + data_target: Optional[np.ndarray] = None, + read_source_func: Optional[Callable] = read_tiff, + ) -> None: + """Constructor. + + Parameters + ---------- + prediction_config : InferenceConfig + Prediction configuration. + inputs : np.ndarray + Input data. + data_target : Optional[np.ndarray], optional + Target data, by default None. + read_source_func : Optional[Callable], optional + Read source function for custom types, by default read_tiff. + + Raises + ------ + ValueError + If data_path is not a directory. + """ + if ( + prediction_config.tile_size is None + or prediction_config.tile_overlap is None + ): + raise ValueError( + "Tile size and overlap must be provided to use the tiled prediction " + "dataset." + ) + + self.pred_config = prediction_config + self.input_array = inputs + self.axes = self.pred_config.axes + self.tile_size = prediction_config.tile_size + self.tile_overlap = prediction_config.tile_overlap + self.mean = self.pred_config.mean + self.std = self.pred_config.std + self.data_target = data_target + + # read function + self.read_source_func = read_source_func + # Generate patches self.data = self._prepare_tiles() self.mean, self.std = self.pred_config.mean, self.pred_config.std @@ -349,22 +446,18 @@ def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]: # reshape array reshaped_sample = reshape_array(self.input_array, self.axes) - if self.tiling and self.tile_size is not None and self.tile_overlap is not None: - # generate patches, which returns a generator - patch_generator = extract_tiles( - arr=reshaped_sample, - tile_size=self.tile_size, - overlaps=self.tile_overlap, - ) - patches_list = list(patch_generator) + # generate patches, which returns a generator + patch_generator = extract_tiles( + arr=reshaped_sample, + tile_size=self.tile_size, + overlaps=self.tile_overlap, + ) + patches_list = list(patch_generator) - if len(patches_list) == 0: - raise ValueError("No tiles generated, ") + if len(patches_list) == 0: + raise ValueError("No tiles generated, ") - return patches_list - else: - array_shape = reshaped_sample.squeeze().shape - return [(reshaped_sample, TileInformation(array_shape=array_shape))] + return patches_list def __len__(self) -> int: """ diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 192e73864..2ac54ca4c 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -28,8 +28,7 @@ def _iterate_over_files( target_files: Optional[List[Path]] = None, read_source_func: Callable = read_tiff, ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]: - """ - Iterate over data source and yield whole image. + """Iterate over data source and yield whole reshaped images. 
Parameters ---------- @@ -63,6 +62,9 @@ def _iterate_over_files( # read data sample = read_source_func(filename, data_config.axes) + # reshape array + reshaped_sample = reshape_array(sample, data_config.axes) + # read target, if available if target_files is not None: if filename.name != target_files[i].name: @@ -75,9 +77,12 @@ def _iterate_over_files( # read target target = read_source_func(target_files[i], data_config.axes) - yield sample, target + # reshape target + reshaped_target = reshape_array(target, data_config.axes) + + yield reshaped_sample, reshaped_target else: - yield sample, None + yield reshaped_sample, None except Exception as e: logger.error(f"Error reading file {filename}: {e}") @@ -206,17 +211,10 @@ def __iter__( for sample_input, sample_target in _iterate_over_files( self.data_config, self.data_files, self.target_files, self.read_source_func ): - reshaped_sample = reshape_array(sample_input, self.data_config.axes) - reshaped_target = ( - None - if sample_target is None - else reshape_array(sample_target, self.data_config.axes) - ) - patches = extract_patches_random( - arr=reshaped_sample, + arr=sample_input, patch_size=self.data_config.patch_size, - target=reshaped_target, + target=sample_target, ) # iterate over patches @@ -320,8 +318,7 @@ def split_dataset( class IterablePredictionDataset(IterableDataset): - """ - Prediction dataset. + """Simple iterable prediction dataset. Parameters ---------- @@ -376,13 +373,8 @@ def __init__( self.prediction_config = prediction_config self.data_files = src_files self.axes = prediction_config.axes - self.tile_size = self.prediction_config.tile_size - self.tile_overlap = self.prediction_config.tile_overlap self.read_source_func = read_source_func - # tile only if both tile size and overlaps are provided - self.tile = self.tile_size is not None and self.tile_overlap is not None - # check mean and std and create normalize transform if self.prediction_config.mean is None or self.prediction_config.std is None: raise ValueError("Mean and std must be provided for prediction.") @@ -401,7 +393,7 @@ def __init__( def __iter__( self, - ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: + ) -> Generator[np.ndarray, None, None]: """ Iterate over data source and yield single patch. @@ -419,27 +411,121 @@ def __iter__( self.data_files, read_source_func=self.read_source_func, ): - # reshape array - reshaped_sample = reshape_array(sample, self.axes) - - if ( - self.tile - and self.tile_size is not None - and self.tile_overlap is not None - ): - # generate patches, return a generator - patch_gen = extract_tiles( - arr=reshaped_sample, - tile_size=self.tile_size, - overlaps=self.tile_overlap, - ) - else: - # just wrap the sample in a generator with default tiling info - array_shape = reshaped_sample.squeeze().shape - patch_gen = ( - (reshaped_sample, TileInformation(array_shape=array_shape)) - for _ in range(1) - ) + # TODO what if S dimensions > 1, should we yield each sample independently? + transformed_sample, _ = self.patch_transform(patch=sample) + yield transformed_sample + + +class IterableTiledPredictionDataset(IterableDataset): + """Tiled prediction dataset. + + Parameters + ---------- + prediction_config : InferenceConfig + Inference configuration. + src_files : List[Path] + List of data files. + read_source_func : Callable, optional + Read source function for custom types, by default read_tiff. + **kwargs : Any + Additional keyword arguments, unused. 
+ + Attributes + ---------- + data_path : Union[str, Path] + Path to the data, must be a directory. + axes : str + Description of axes in format STCZYX. + mean : Optional[float], optional + Expected mean of the dataset, by default None. + std : Optional[float], optional + Expected standard deviation of the dataset, by default None. + patch_transform : Optional[Callable], optional + Patch transform callable, by default None. + """ + + def __init__( + self, + prediction_config: InferenceConfig, + src_files: List[Path], + read_source_func: Callable = read_tiff, + **kwargs: Any, + ) -> None: + """Constructor. + + Parameters + ---------- + prediction_config : InferenceConfig + Inference configuration. + src_files : List[Path] + List of data files. + read_source_func : Callable, optional + Read source function for custom types, by default read_tiff. + **kwargs : Any + Additional keyword arguments, unused. + + Raises + ------ + ValueError + If mean and std are not provided in the inference configuration. + """ + if ( + prediction_config.tile_size is None + or prediction_config.tile_overlap is None + ): + raise ValueError( + "Tile size and overlap must be provided for tiled prediction." + ) + + self.prediction_config = prediction_config + self.data_files = src_files + self.axes = prediction_config.axes + self.tile_size = prediction_config.tile_size + self.tile_overlap = prediction_config.tile_overlap + self.read_source_func = read_source_func + + # check mean and std and create normalize transform + if self.prediction_config.mean is None or self.prediction_config.std is None: + raise ValueError("Mean and std must be provided for prediction.") + else: + self.mean = self.prediction_config.mean + self.std = self.prediction_config.std + + # instantiate normalize transform + self.patch_transform = Compose( + transform_list=[ + NormalizeModel( + mean=prediction_config.mean, std=prediction_config.std + ) + ], + ) + + def __iter__( + self, + ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: + """ + Iterate over data source and yield single patch. + + Yields + ------ + Tuple[pnp.ndarray, TileInformation] + Single tile. 
+        """
+        assert (
+            self.mean is not None and self.std is not None
+        ), "Mean and std must be provided"
+
+        for sample, _ in _iterate_over_files(
+            self.prediction_config,
+            self.data_files,
+            read_source_func=self.read_source_func,
+        ):
+            # generate patches, return a generator
+            patch_gen = extract_tiles(
+                arr=sample,
+                tile_size=self.tile_size,
+                overlaps=self.tile_overlap,
+            )
 
             # apply transform to patches
             for patch_array, tile_info in patch_gen:
                 transformed_patch, _ = self.patch_transform(patch=patch_array)
 
                 yield transformed_patch, tile_info
diff --git a/src/careamics/lightning_prediction_datamodule.py b/src/careamics/lightning_prediction_datamodule.py
index b7a730faa..b22b26a72 100644
--- a/src/careamics/lightning_prediction_datamodule.py
+++ b/src/careamics/lightning_prediction_datamodule.py
@@ -17,13 +17,19 @@
 )
 from careamics.dataset.in_memory_dataset import (
     InMemoryPredictionDataset,
+    InMemoryTiledPredictionDataset,
 )
 from careamics.dataset.iterable_dataset import (
     IterablePredictionDataset,
+    IterableTiledPredictionDataset,
 )
 from careamics.utils import get_logger
 
-PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset]
+PredictDatasetType = Union[
+    InMemoryPredictionDataset,
+    InMemoryTiledPredictionDataset,
+    IterableTiledPredictionDataset,
+]
 
 logger = get_logger(__name__)
@@ -50,10 +56,8 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
     first_tile_info: TileInformation = batch[0][1]
     # if not tiled, then return arrays
     if not first_tile_info.tiled:
-        arrays, _ = zip(*batch)
-
-        return default_collate(arrays)
-    # else we explicit the last_tile flag and coordinates
+        raise ValueError("Collate function should not be called for non-tiled data.")
+    # else make the last_tile flag and coordinates explicit for the default collate
     else:
         new_batch = [
             (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
@@ -182,6 +186,9 @@
         self.tile_size = pred_config.tile_size
         self.tile_overlap = pred_config.tile_overlap
 
+        # check if it is tiled
+        self.tiled = self.tile_size is not None and self.tile_overlap is not None
+
         # read source function
         if pred_config.data_type == SupportedData.CUSTOM:
             # mypy check
@@ -212,17 +219,31 @@ def setup(self, stage: Optional[str] = None) -> None:
         """
         # if numpy array
         if self.data_type == SupportedData.ARRAY:
-            # prediction dataset
-            self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset(
-                prediction_config=self.prediction_config,
-                inputs=self.pred_data,
-            )
+            if self.tiled:
+                self.predict_dataset: PredictDatasetType = (
+                    InMemoryTiledPredictionDataset(
+                        prediction_config=self.prediction_config,
+                        inputs=self.pred_data,
+                    )
+                )
+            else:
+                self.predict_dataset = InMemoryPredictionDataset(
+                    prediction_config=self.prediction_config,
+                    inputs=self.pred_data,
+                )
         else:
-            self.predict_dataset = IterablePredictionDataset(
-                prediction_config=self.prediction_config,
-                src_files=self.pred_files,
-                read_source_func=self.read_source_func,
-            )
+            if self.tiled:
+                self.predict_dataset = IterableTiledPredictionDataset(
+                    prediction_config=self.prediction_config,
+                    src_files=self.pred_files,
+                    read_source_func=self.read_source_func,
+                )
+            else:
+                self.predict_dataset = IterablePredictionDataset(
+                    prediction_config=self.prediction_config,
+                    src_files=self.pred_files,
+                    read_source_func=self.read_source_func,
+                )
 
     def predict_dataloader(self) -> DataLoader:
         """
@@ -236,7 +257,7 @@
         return DataLoader(
             self.predict_dataset,
             batch_size=self.batch_size,
-            collate_fn=_collate_tiles,
+            collate_fn=_collate_tiles if self.tiled else None,
**self.dataloader_params, ) # TODO check workers are used diff --git a/src/careamics/lightning_prediction_loop.py b/src/careamics/lightning_prediction_loop.py index 46d41ca99..af3ad3ed2 100644 --- a/src/careamics/lightning_prediction_loop.py +++ b/src/careamics/lightning_prediction_loop.py @@ -87,6 +87,7 @@ def run(self) -> Optional[_PREDICT_OUTPUT]: ######################################################## ################ CAREamics specific code ############### + # TODO: not compatible with multiple outputs (e.g. muSplit) is_tiled = len(self.predictions[batch_idx]) == 2 if is_tiled: # extract the last tile flag and the coordinates (crop and stitch) From ce8c2629cfc866a31609fb3b0d3e9ed21bba4fe2 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:41:54 +0200 Subject: [PATCH 04/14] (chore): add todo --- src/careamics/careamist.py | 1 + tests/model_io/test_bmz_io.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index 62312131e..a3eba2c69 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -682,6 +682,7 @@ def export_to_bmz( data_description : Optional[str], optional Description of the data, by default None. """ + # TODO data from dataloader will be normalized, which is an issue with BMZ! if input_array is None: # generate images, priority is given to the prediction data module if self.pred_datamodule is not None: diff --git a/tests/model_io/test_bmz_io.py b/tests/model_io/test_bmz_io.py index a031aa0d2..bf48b2e11 100644 --- a/tests/model_io/test_bmz_io.py +++ b/tests/model_io/test_bmz_io.py @@ -33,7 +33,7 @@ def test_state_dict_io(tmp_path, pre_trained): def test_bmz_io(tmp_path, pre_trained): """Test exporting and loading to the BMZ.""" # training data - train_array = np.ones((32, 32), dtype=np.float32) + train_array = np.ones((16, 16), dtype=np.float32) # instantiate CAREamist careamist = CAREamist(source=pre_trained, work_dir=tmp_path) From 68ab47dd945a9f0705ae07fe6b34085e123b6c06 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Wed, 5 Jun 2024 12:00:02 +0200 Subject: [PATCH 05/14] (fix): add transforms in prediction dataset, fix tests --- src/careamics/dataset/in_memory_dataset.py | 4 +++- src/careamics/dataset/iterable_dataset.py | 1 + tests/test_careamist.py | 3 ++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 9c938f32e..3e937c124 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -361,7 +361,9 @@ def __getitem__(self, index: int) -> np.ndarray: np.ndarray Transformed patch. """ - return self.data[[index]] + transformed_patch, _ = self.patch_transform(patch=self.data[[index]]) + + return transformed_patch class InMemoryTiledPredictionDataset(Dataset): diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 2ac54ca4c..f4df67906 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -413,6 +413,7 @@ def __iter__( ): # TODO what if S dimensions > 1, should we yield each sample independently? 
             transformed_sample, _ = self.patch_transform(patch=sample)
+
             yield transformed_sample
diff --git a/tests/test_careamist.py b/tests/test_careamist.py
index d22062cb4..469eb02cf 100644
--- a/tests/test_careamist.py
+++ b/tests/test_careamist.py
@@ -546,8 +546,9 @@ def test_predict_arrays_no_tiling(tmp_path: Path, minimum_configuration: dict):
 
     # predict CAREamist
     predicted = careamist.predict(train_array)
+    predicted_squeeze = [p.squeeze() for p in predicted]
 
-    assert predicted.squeeze().shape == train_array.shape
+    assert np.array(predicted_squeeze).shape == train_array.shape
 
     # export to BMZ
     careamist.export_to_bmz(

From 5770d7c23329eb949ffb3a4e8f1dade0b72e4add Mon Sep 17 00:00:00 2001
From: jdeschamps <6367888+jdeschamps@users.noreply.github.com>
Date: Wed, 5 Jun 2024 16:03:47 +0200
Subject: [PATCH 06/14] (refac): split the datasets into their own modules

---
 src/careamics/dataset/__init__.py             |  18 +-
 .../dataset/dataset_utils/__init__.py         |   2 +
 .../dataset_utils/iterate_over_files.py       |  83 +++++
 src/careamics/dataset/in_memory_dataset.py    | 229 +-------------
 .../dataset/in_memory_pred_dataset.py         | 108 +++++++
 .../dataset/in_memory_tiled_pred_dataset.py   | 144 +++++++++
 src/careamics/dataset/iterable_dataset.py     | 299 +-----------------
 .../dataset/iterable_pred_dataset.py          | 115 +++++++
 .../dataset/iterable_tiled_pred_dataset.py    | 135 ++++++++
 .../lightning_prediction_datamodule.py        |  19 +-
 10 files changed, 611 insertions(+), 541 deletions(-)
 create mode 100644 src/careamics/dataset/dataset_utils/iterate_over_files.py
 create mode 100644 src/careamics/dataset/in_memory_pred_dataset.py
 create mode 100644 src/careamics/dataset/in_memory_tiled_pred_dataset.py
 create mode 100644 src/careamics/dataset/iterable_pred_dataset.py
 create mode 100644 src/careamics/dataset/iterable_tiled_pred_dataset.py

diff --git a/src/careamics/dataset/__init__.py b/src/careamics/dataset/__init__.py
index b33d8092b..c29d1097d 100644
--- a/src/careamics/dataset/__init__.py
+++ b/src/careamics/dataset/__init__.py
@@ -2,20 +2,16 @@
 """Dataset module."""
 
 __all__ = [
     "InMemoryDataset",
-    "InMemoryPredictionDataset",
+    "InMemoryPredDataset",
     "InMemoryTiledPredictionDataset",
     "PathIterableDataset",
     "IterableTiledPredictionDataset",
     "IterablePredictionDataset",
 ]
 
-from .in_memory_dataset import (
-    InMemoryDataset,
-    InMemoryPredictionDataset,
-    InMemoryTiledPredictionDataset,
-)
-from .iterable_dataset import (
-    IterablePredictionDataset,
-    IterableTiledPredictionDataset,
-    PathIterableDataset,
-)
+from .in_memory_dataset import InMemoryDataset
+from .in_memory_pred_dataset import InMemoryPredDataset
+from .in_memory_tiled_pred_dataset import InMemoryTiledPredictionDataset
+from .iterable_dataset import PathIterableDataset
+from .iterable_pred_dataset import IterablePredictionDataset
+from .iterable_tiled_pred_dataset import IterableTiledPredictionDataset
diff --git a/src/careamics/dataset/dataset_utils/__init__.py b/src/careamics/dataset/dataset_utils/__init__.py
index 242405769..e8e93692e 100644
--- a/src/careamics/dataset/dataset_utils/__init__.py
+++ b/src/careamics/dataset/dataset_utils/__init__.py
@@ -8,11 +8,13 @@
     "read_tiff",
     "get_read_func",
     "read_zarr",
+    "iterate_over_files",
 ]
 
 
 from .dataset_utils import reshape_array
 from .file_utils import get_files_size, list_files, validate_source_target_files
+from .iterate_over_files import iterate_over_files
 from .read_tiff import read_tiff
 from .read_utils import get_read_func
 from .read_zarr import read_zarr
diff --git a/src/careamics/dataset/dataset_utils/iterate_over_files.py b/src/careamics/dataset/dataset_utils/iterate_over_files.py
new file mode 100644
index 000000000..b3e413f9b
--- /dev/null
+++ b/src/careamics/dataset/dataset_utils/iterate_over_files.py
@@ -0,0 +1,83 @@
+"""Function to iterate over files."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable, Generator, List, Optional, Tuple, Union
+
+import numpy as np
+from torch.utils.data import get_worker_info
+
+from careamics.config import DataConfig, InferenceConfig
+from careamics.utils.logging import get_logger
+
+from .dataset_utils import reshape_array
+from .read_tiff import read_tiff
+
+logger = get_logger(__name__)
+
+
+def iterate_over_files(
+    data_config: Union[DataConfig, InferenceConfig],
+    data_files: List[Path],
+    target_files: Optional[List[Path]] = None,
+    read_source_func: Callable = read_tiff,
+) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+    """Iterate over data source and yield whole reshaped images.
+
+    Parameters
+    ----------
+    data_config : Union[DataConfig, InferenceConfig]
+        Data configuration.
+    data_files : List[Path]
+        List of data files.
+    target_files : Optional[List[Path]]
+        List of target files, by default None.
+    read_source_func : Optional[Callable]
+        Function to read the source, by default read_tiff.
+
+    Yields
+    ------
+    np.ndarray
+        Image.
+    """
+    # When num_workers > 0, each worker process will have a different copy of the
+    # dataset object
+    # Configuring each copy independently to avoid having duplicate data returned
+    # from the workers
+    worker_info = get_worker_info()
+    worker_id = worker_info.id if worker_info is not None else 0
+    num_workers = worker_info.num_workers if worker_info is not None else 1
+
+    # iterate over the files
+    for i, filename in enumerate(data_files):
+        # retrieve file corresponding to the worker id
+        if i % num_workers == worker_id:
+            try:
+                # read data
+                sample = read_source_func(filename, data_config.axes)
+
+                # reshape array
+                reshaped_sample = reshape_array(sample, data_config.axes)
+
+                # read target, if available
+                if target_files is not None:
+                    if filename.name != target_files[i].name:
+                        raise ValueError(
+                            f"File {filename} does not match target file "
+                            f"{target_files[i]}. Have you passed sorted "
+                            f"arrays?"
+ ) + + # read target + target = read_source_func(target_files[i], data_config.axes) + + # reshape target + reshaped_target = reshape_array(target, data_config.axes) + + yield reshaped_sample, reshaped_target + else: + yield reshaped_sample, None + + except Exception as e: + logger.error(f"Error reading file {filename}: {e}") diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 3e937c124..97c6bd04f 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -11,18 +11,15 @@ from careamics.transforms import Compose -from ..config import DataConfig, InferenceConfig -from ..config.tile_information import TileInformation -from ..config.transformations import NormalizeModel +from ..config import DataConfig from ..utils.logging import get_logger -from .dataset_utils import read_tiff, reshape_array +from .dataset_utils import read_tiff from .patching.patching import ( prepare_patches_supervised, prepare_patches_supervised_array, prepare_patches_unsupervised, prepare_patches_unsupervised_array, ) -from .patching.tiled_patching import extract_tiles logger = get_logger(__name__) @@ -270,225 +267,3 @@ def split_dataset( dataset.patch_targets = val_targets return dataset - - -class InMemoryPredictionDataset(Dataset): - """Simple prediction dataset returning images along the sample axis. - - Parameters - ---------- - prediction_config : InferenceConfig - Prediction configuration. - inputs : np.ndarray - Input data. - data_target : Optional[np.ndarray], optional - Target data, by default None. - read_source_func : Optional[Callable], optional - Read source function for custom types, by default read_tiff. - """ - - def __init__( - self, - prediction_config: InferenceConfig, - inputs: np.ndarray, - data_target: Optional[np.ndarray] = None, - read_source_func: Optional[Callable] = read_tiff, - ) -> None: - """Constructor. - - Parameters - ---------- - prediction_config : InferenceConfig - Prediction configuration. - inputs : np.ndarray - Input data. - data_target : Optional[np.ndarray], optional - Target data, by default None. - read_source_func : Optional[Callable], optional - Read source function for custom types, by default read_tiff. - - Raises - ------ - ValueError - If data_path is not a directory. - """ - self.pred_config = prediction_config - self.input_array = inputs - self.axes = self.pred_config.axes - self.tile_size = self.pred_config.tile_size - self.tile_overlap = self.pred_config.tile_overlap - self.mean = self.pred_config.mean - self.std = self.pred_config.std - self.data_target = data_target - self.mean, self.std = self.pred_config.mean, self.pred_config.std - - # tiling only if both tile size and overlap are provided - self.tiling = self.tile_size is not None and self.tile_overlap is not None - - # read function - self.read_source_func = read_source_func - - # Reshape data - self.data = reshape_array(self.input_array, self.axes) - - # get transforms - self.patch_transform = Compose( - transform_list=[NormalizeModel(mean=self.mean, std=self.std)], - ) - - def __len__(self) -> int: - """ - Return the length of the dataset. - - Returns - ------- - int - Length of the dataset. - """ - return len(self.data) - - def __getitem__(self, index: int) -> np.ndarray: - """ - Return the patch corresponding to the provided index. - - Parameters - ---------- - index : int - Index of the patch to return. - - Returns - ------- - np.ndarray - Transformed patch. 
- """ - transformed_patch, _ = self.patch_transform(patch=self.data[[index]]) - - return transformed_patch - - -class InMemoryTiledPredictionDataset(Dataset): - """Prediction dataset storing data in memory and returning tiles of each image. - - Parameters - ---------- - prediction_config : InferenceConfig - Prediction configuration. - inputs : np.ndarray - Input data. - data_target : Optional[np.ndarray], optional - Target data, by default None. - read_source_func : Optional[Callable], optional - Read source function for custom types, by default read_tiff. - """ - - def __init__( - self, - prediction_config: InferenceConfig, - inputs: np.ndarray, - data_target: Optional[np.ndarray] = None, - read_source_func: Optional[Callable] = read_tiff, - ) -> None: - """Constructor. - - Parameters - ---------- - prediction_config : InferenceConfig - Prediction configuration. - inputs : np.ndarray - Input data. - data_target : Optional[np.ndarray], optional - Target data, by default None. - read_source_func : Optional[Callable], optional - Read source function for custom types, by default read_tiff. - - Raises - ------ - ValueError - If data_path is not a directory. - """ - if ( - prediction_config.tile_size is None - or prediction_config.tile_overlap is None - ): - raise ValueError( - "Tile size and overlap must be provided to use the tiled prediction " - "dataset." - ) - - self.pred_config = prediction_config - self.input_array = inputs - self.axes = self.pred_config.axes - self.tile_size = prediction_config.tile_size - self.tile_overlap = prediction_config.tile_overlap - self.mean = self.pred_config.mean - self.std = self.pred_config.std - self.data_target = data_target - - # read function - self.read_source_func = read_source_func - - # Generate patches - self.data = self._prepare_tiles() - self.mean, self.std = self.pred_config.mean, self.pred_config.std - - # get transforms - self.patch_transform = Compose( - transform_list=[NormalizeModel(mean=self.mean, std=self.std)], - ) - - def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]: - """ - Iterate over data source and create an array of patches. - - Returns - ------- - List[XArrayTile] - List of tiles. - """ - # reshape array - reshaped_sample = reshape_array(self.input_array, self.axes) - - # generate patches, which returns a generator - patch_generator = extract_tiles( - arr=reshaped_sample, - tile_size=self.tile_size, - overlaps=self.tile_overlap, - ) - patches_list = list(patch_generator) - - if len(patches_list) == 0: - raise ValueError("No tiles generated, ") - - return patches_list - - def __len__(self) -> int: - """ - Return the length of the dataset. - - Returns - ------- - int - Length of the dataset. - """ - return len(self.data) - - def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]: - """ - Return the patch corresponding to the provided index. - - Parameters - ---------- - index : int - Index of the patch to return. - - Returns - ------- - Tuple[np.ndarray, TileInformation] - Transformed patch. 
- """ - tile_array, tile_info = self.data[index] - - # Apply transforms - transformed_tile, _ = self.patch_transform(patch=tile_array) - - return transformed_tile, tile_info diff --git a/src/careamics/dataset/in_memory_pred_dataset.py b/src/careamics/dataset/in_memory_pred_dataset.py new file mode 100644 index 000000000..5d3bbcb17 --- /dev/null +++ b/src/careamics/dataset/in_memory_pred_dataset.py @@ -0,0 +1,108 @@ +"""In-memory prediction dataset.""" + +from __future__ import annotations + +from typing import Callable, Optional + +import numpy as np +from torch.utils.data import Dataset + +from careamics.transforms import Compose + +from ..config import InferenceConfig +from ..config.transformations import NormalizeModel +from .dataset_utils import read_tiff, reshape_array + + +class InMemoryPredDataset(Dataset): + """Simple prediction dataset returning images along the sample axis. + + Parameters + ---------- + prediction_config : InferenceConfig + Prediction configuration. + inputs : np.ndarray + Input data. + data_target : Optional[np.ndarray], optional + Target data, by default None. + read_source_func : Optional[Callable], optional + Read source function for custom types, by default read_tiff. + """ + + def __init__( + self, + prediction_config: InferenceConfig, + inputs: np.ndarray, + data_target: Optional[np.ndarray] = None, + read_source_func: Optional[Callable] = read_tiff, + ) -> None: + """Constructor. + + Parameters + ---------- + prediction_config : InferenceConfig + Prediction configuration. + inputs : np.ndarray + Input data. + data_target : Optional[np.ndarray], optional + Target data, by default None. + read_source_func : Optional[Callable], optional + Read source function for custom types, by default read_tiff. + + Raises + ------ + ValueError + If data_path is not a directory. + """ + self.pred_config = prediction_config + self.input_array = inputs + self.axes = self.pred_config.axes + self.tile_size = self.pred_config.tile_size + self.tile_overlap = self.pred_config.tile_overlap + self.mean = self.pred_config.mean + self.std = self.pred_config.std + self.data_target = data_target + self.mean, self.std = self.pred_config.mean, self.pred_config.std + + # tiling only if both tile size and overlap are provided + self.tiling = self.tile_size is not None and self.tile_overlap is not None + + # read function + self.read_source_func = read_source_func + + # Reshape data + self.data = reshape_array(self.input_array, self.axes) + + # get transforms + self.patch_transform = Compose( + transform_list=[NormalizeModel(mean=self.mean, std=self.std)], + ) + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + Length of the dataset. + """ + return len(self.data) + + def __getitem__(self, index: int) -> np.ndarray: + """ + Return the patch corresponding to the provided index. + + Parameters + ---------- + index : int + Index of the patch to return. + + Returns + ------- + np.ndarray + Transformed patch. 
+ """ + transformed_patch, _ = self.patch_transform(patch=self.data[[index]]) + + return transformed_patch diff --git a/src/careamics/dataset/in_memory_tiled_pred_dataset.py b/src/careamics/dataset/in_memory_tiled_pred_dataset.py new file mode 100644 index 000000000..507feb8c1 --- /dev/null +++ b/src/careamics/dataset/in_memory_tiled_pred_dataset.py @@ -0,0 +1,144 @@ +"""In-memory tiled prediction dataset.""" + +from __future__ import annotations + +from typing import Callable, List, Optional, Tuple + +import numpy as np +from torch.utils.data import Dataset + +from careamics.transforms import Compose + +from ..config import InferenceConfig +from ..config.tile_information import TileInformation +from ..config.transformations import NormalizeModel +from .dataset_utils import read_tiff, reshape_array +from .patching.tiled_patching import extract_tiles + + +class InMemoryTiledPredictionDataset(Dataset): + """Prediction dataset storing data in memory and returning tiles of each image. + + Parameters + ---------- + prediction_config : InferenceConfig + Prediction configuration. + inputs : np.ndarray + Input data. + data_target : Optional[np.ndarray], optional + Target data, by default None. + read_source_func : Optional[Callable], optional + Read source function for custom types, by default read_tiff. + """ + + def __init__( + self, + prediction_config: InferenceConfig, + inputs: np.ndarray, + data_target: Optional[np.ndarray] = None, + read_source_func: Optional[Callable] = read_tiff, + ) -> None: + """Constructor. + + Parameters + ---------- + prediction_config : InferenceConfig + Prediction configuration. + inputs : np.ndarray + Input data. + data_target : Optional[np.ndarray], optional + Target data, by default None. + read_source_func : Optional[Callable], optional + Read source function for custom types, by default read_tiff. + + Raises + ------ + ValueError + If data_path is not a directory. + """ + if ( + prediction_config.tile_size is None + or prediction_config.tile_overlap is None + ): + raise ValueError( + "Tile size and overlap must be provided to use the tiled prediction " + "dataset." + ) + + self.pred_config = prediction_config + self.input_array = inputs + self.axes = self.pred_config.axes + self.tile_size = prediction_config.tile_size + self.tile_overlap = prediction_config.tile_overlap + self.mean = self.pred_config.mean + self.std = self.pred_config.std + self.data_target = data_target + + # read function + self.read_source_func = read_source_func + + # Generate patches + self.data = self._prepare_tiles() + self.mean, self.std = self.pred_config.mean, self.pred_config.std + + # get transforms + self.patch_transform = Compose( + transform_list=[NormalizeModel(mean=self.mean, std=self.std)], + ) + + def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]: + """ + Iterate over data source and create an array of patches. + + Returns + ------- + List[XArrayTile] + List of tiles. + """ + # reshape array + reshaped_sample = reshape_array(self.input_array, self.axes) + + # generate patches, which returns a generator + patch_generator = extract_tiles( + arr=reshaped_sample, + tile_size=self.tile_size, + overlaps=self.tile_overlap, + ) + patches_list = list(patch_generator) + + if len(patches_list) == 0: + raise ValueError("No tiles generated, ") + + return patches_list + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + Length of the dataset. 
+ """ + return len(self.data) + + def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]: + """ + Return the patch corresponding to the provided index. + + Parameters + ---------- + index : int + Index of the patch to return. + + Returns + ------- + Tuple[np.ndarray, TileInformation] + Transformed patch. + """ + tile_array, tile_info = self.data[index] + + # Apply transforms + transformed_tile, _ = self.patch_transform(patch=tile_array) + + return transformed_tile, tile_info diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index f4df67906..6babe7568 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -4,90 +4,21 @@ import copy from pathlib import Path -from typing import Any, Callable, Generator, List, Optional, Tuple, Union +from typing import Callable, Generator, List, Optional, Tuple import numpy as np -from torch.utils.data import IterableDataset, get_worker_info +from torch.utils.data import IterableDataset +from careamics.config import DataConfig from careamics.transforms import Compose -from ..config import DataConfig, InferenceConfig -from ..config.tile_information import TileInformation -from ..config.transformations import NormalizeModel from ..utils.logging import get_logger -from .dataset_utils import read_tiff, reshape_array +from .dataset_utils import iterate_over_files, read_tiff from .patching.random_patching import extract_patches_random -from .patching.tiled_patching import extract_tiles logger = get_logger(__name__) -def _iterate_over_files( - data_config: Union[DataConfig, InferenceConfig], - data_files: List[Path], - target_files: Optional[List[Path]] = None, - read_source_func: Callable = read_tiff, -) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]: - """Iterate over data source and yield whole reshaped images. - - Parameters - ---------- - data_config : Union[DataConfig, InferenceConfig] - Data configuration. - data_files : List[Path] - List of data files. - target_files : Optional[List[Path]] - List of target files, by default None. - read_source_func : Optional[Callable] - Function to read the source, by default read_tiff. - - Yields - ------ - np.ndarray - Image. - """ - # When num_workers > 0, each worker process will have a different copy of the - # dataset object - # Configuring each copy independently to avoid having duplicate data returned - # from the workers - worker_info = get_worker_info() - worker_id = worker_info.id if worker_info is not None else 0 - num_workers = worker_info.num_workers if worker_info is not None else 1 - - # iterate over the files - for i, filename in enumerate(data_files): - # retrieve file corresponding to the worker id - if i % num_workers == worker_id: - try: - # read data - sample = read_source_func(filename, data_config.axes) - - # reshape array - reshaped_sample = reshape_array(sample, data_config.axes) - - # read target, if available - if target_files is not None: - if filename.name != target_files[i].name: - raise ValueError( - f"File {filename} does not match target file " - f"{target_files[i]}. Have you passed sorted " - f"arrays?" 
- ) - - # read target - target = read_source_func(target_files[i], data_config.axes) - - # reshape target - reshaped_target = reshape_array(target, data_config.axes) - - yield reshaped_sample, reshaped_target - else: - yield reshaped_sample, None - - except Exception as e: - logger.error(f"Error reading file {filename}: {e}") - - class PathIterableDataset(IterableDataset): """ Dataset allowing extracting patches w/o loading whole data into memory. @@ -175,7 +106,7 @@ def _calculate_mean_and_std(self) -> Tuple[float, float]: means, stds = 0, 0 num_samples = 0 - for sample, _ in _iterate_over_files( + for sample, _ in iterate_over_files( self.data_config, self.data_files, self.target_files, self.read_source_func ): means += sample.mean() @@ -208,7 +139,7 @@ def __iter__( ), "Mean and std must be provided" # iterate over files - for sample_input, sample_target in _iterate_over_files( + for sample_input, sample_target in iterate_over_files( self.data_config, self.data_files, self.target_files, self.read_source_func ): patches = extract_patches_random( @@ -315,221 +246,3 @@ def split_dataset( dataset.target_files = val_target_files return dataset - - -class IterablePredictionDataset(IterableDataset): - """Simple iterable prediction dataset. - - Parameters - ---------- - prediction_config : InferenceConfig - Inference configuration. - src_files : List[Path] - List of data files. - read_source_func : Callable, optional - Read source function for custom types, by default read_tiff. - **kwargs : Any - Additional keyword arguments, unused. - - Attributes - ---------- - data_path : Union[str, Path] - Path to the data, must be a directory. - axes : str - Description of axes in format STCZYX. - mean : Optional[float], optional - Expected mean of the dataset, by default None. - std : Optional[float], optional - Expected standard deviation of the dataset, by default None. - patch_transform : Optional[Callable], optional - Patch transform callable, by default None. - """ - - def __init__( - self, - prediction_config: InferenceConfig, - src_files: List[Path], - read_source_func: Callable = read_tiff, - **kwargs: Any, - ) -> None: - """Constructor. - - Parameters - ---------- - prediction_config : InferenceConfig - Inference configuration. - src_files : List[Path] - List of data files. - read_source_func : Callable, optional - Read source function for custom types, by default read_tiff. - **kwargs : Any - Additional keyword arguments, unused. - - Raises - ------ - ValueError - If mean and std are not provided in the inference configuration. - """ - self.prediction_config = prediction_config - self.data_files = src_files - self.axes = prediction_config.axes - self.read_source_func = read_source_func - - # check mean and std and create normalize transform - if self.prediction_config.mean is None or self.prediction_config.std is None: - raise ValueError("Mean and std must be provided for prediction.") - else: - self.mean = self.prediction_config.mean - self.std = self.prediction_config.std - - # instantiate normalize transform - self.patch_transform = Compose( - transform_list=[ - NormalizeModel( - mean=prediction_config.mean, std=prediction_config.std - ) - ], - ) - - def __iter__( - self, - ) -> Generator[np.ndarray, None, None]: - """ - Iterate over data source and yield single patch. - - Yields - ------ - np.ndarray - Single patch. 
- """ - assert ( - self.mean is not None and self.std is not None - ), "Mean and std must be provided" - - for sample, _ in _iterate_over_files( - self.prediction_config, - self.data_files, - read_source_func=self.read_source_func, - ): - # TODO what if S dimensions > 1, should we yield each sample independently? - transformed_sample, _ = self.patch_transform(patch=sample) - - yield transformed_sample - - -class IterableTiledPredictionDataset(IterableDataset): - """Tiled prediction dataset. - - Parameters - ---------- - prediction_config : InferenceConfig - Inference configuration. - src_files : List[Path] - List of data files. - read_source_func : Callable, optional - Read source function for custom types, by default read_tiff. - **kwargs : Any - Additional keyword arguments, unused. - - Attributes - ---------- - data_path : Union[str, Path] - Path to the data, must be a directory. - axes : str - Description of axes in format STCZYX. - mean : Optional[float], optional - Expected mean of the dataset, by default None. - std : Optional[float], optional - Expected standard deviation of the dataset, by default None. - patch_transform : Optional[Callable], optional - Patch transform callable, by default None. - """ - - def __init__( - self, - prediction_config: InferenceConfig, - src_files: List[Path], - read_source_func: Callable = read_tiff, - **kwargs: Any, - ) -> None: - """Constructor. - - Parameters - ---------- - prediction_config : InferenceConfig - Inference configuration. - src_files : List[Path] - List of data files. - read_source_func : Callable, optional - Read source function for custom types, by default read_tiff. - **kwargs : Any - Additional keyword arguments, unused. - - Raises - ------ - ValueError - If mean and std are not provided in the inference configuration. - """ - if ( - prediction_config.tile_size is None - or prediction_config.tile_overlap is None - ): - raise ValueError( - "Tile size and overlap must be provided for tiled prediction." - ) - - self.prediction_config = prediction_config - self.data_files = src_files - self.axes = prediction_config.axes - self.tile_size = prediction_config.tile_size - self.tile_overlap = prediction_config.tile_overlap - self.read_source_func = read_source_func - - # check mean and std and create normalize transform - if self.prediction_config.mean is None or self.prediction_config.std is None: - raise ValueError("Mean and std must be provided for prediction.") - else: - self.mean = self.prediction_config.mean - self.std = self.prediction_config.std - - # instantiate normalize transform - self.patch_transform = Compose( - transform_list=[ - NormalizeModel( - mean=prediction_config.mean, std=prediction_config.std - ) - ], - ) - - def __iter__( - self, - ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: - """ - Iterate over data source and yield single patch. - - Yields - ------ - Tuple[pnp.ndarray, TileInformation] - Single tile. 
- """ - assert ( - self.mean is not None and self.std is not None - ), "Mean and std must be provided" - - for sample, _ in _iterate_over_files( - self.prediction_config, - self.data_files, - read_source_func=self.read_source_func, - ): - # generate patches, return a generator - patch_gen = extract_tiles( - arr=sample, - tile_size=self.tile_size, - overlaps=self.tile_overlap, - ) - - # apply transform to patches - for patch_array, tile_info in patch_gen: - transformed_patch, _ = self.patch_transform(patch=patch_array) - - yield transformed_patch, tile_info diff --git a/src/careamics/dataset/iterable_pred_dataset.py b/src/careamics/dataset/iterable_pred_dataset.py new file mode 100644 index 000000000..d80e02c91 --- /dev/null +++ b/src/careamics/dataset/iterable_pred_dataset.py @@ -0,0 +1,115 @@ +"""Iterable prediction dataset used to load data file by file.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable, Generator, List + +import numpy as np +from torch.utils.data import IterableDataset + +from careamics.transforms import Compose + +from ..config import InferenceConfig +from ..config.transformations import NormalizeModel +from .dataset_utils import iterate_over_files, read_tiff + + +class IterablePredictionDataset(IterableDataset): + """Simple iterable prediction dataset. + + Parameters + ---------- + prediction_config : InferenceConfig + Inference configuration. + src_files : List[Path] + List of data files. + read_source_func : Callable, optional + Read source function for custom types, by default read_tiff. + **kwargs : Any + Additional keyword arguments, unused. + + Attributes + ---------- + data_path : Union[str, Path] + Path to the data, must be a directory. + axes : str + Description of axes in format STCZYX. + mean : Optional[float], optional + Expected mean of the dataset, by default None. + std : Optional[float], optional + Expected standard deviation of the dataset, by default None. + patch_transform : Optional[Callable], optional + Patch transform callable, by default None. + """ + + def __init__( + self, + prediction_config: InferenceConfig, + src_files: List[Path], + read_source_func: Callable = read_tiff, + **kwargs: Any, + ) -> None: + """Constructor. + + Parameters + ---------- + prediction_config : InferenceConfig + Inference configuration. + src_files : List[Path] + List of data files. + read_source_func : Callable, optional + Read source function for custom types, by default read_tiff. + **kwargs : Any + Additional keyword arguments, unused. + + Raises + ------ + ValueError + If mean and std are not provided in the inference configuration. + """ + self.prediction_config = prediction_config + self.data_files = src_files + self.axes = prediction_config.axes + self.read_source_func = read_source_func + + # check mean and std and create normalize transform + if self.prediction_config.mean is None or self.prediction_config.std is None: + raise ValueError("Mean and std must be provided for prediction.") + else: + self.mean = self.prediction_config.mean + self.std = self.prediction_config.std + + # instantiate normalize transform + self.patch_transform = Compose( + transform_list=[ + NormalizeModel( + mean=prediction_config.mean, std=prediction_config.std + ) + ], + ) + + def __iter__( + self, + ) -> Generator[np.ndarray, None, None]: + """ + Iterate over data source and yield single patch. + + Yields + ------ + np.ndarray + Single patch. 
+ """ + assert ( + self.mean is not None and self.std is not None + ), "Mean and std must be provided" + + for sample, _ in iterate_over_files( + self.prediction_config, + self.data_files, + read_source_func=self.read_source_func, + ): + # TODO what if S dimensions > 1, should we yield each sample independently? + transformed_sample, _ = self.patch_transform(patch=sample) + + yield transformed_sample diff --git a/src/careamics/dataset/iterable_tiled_pred_dataset.py b/src/careamics/dataset/iterable_tiled_pred_dataset.py new file mode 100644 index 000000000..a8f888b8a --- /dev/null +++ b/src/careamics/dataset/iterable_tiled_pred_dataset.py @@ -0,0 +1,135 @@ +"""Iterable tiled prediction dataset used to load data file by file.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable, Generator, List, Tuple + +import numpy as np +from torch.utils.data import IterableDataset + +from careamics.transforms import Compose + +from ..config import InferenceConfig +from ..config.tile_information import TileInformation +from ..config.transformations import NormalizeModel +from .dataset_utils import iterate_over_files, read_tiff +from .patching.tiled_patching import extract_tiles + + +class IterableTiledPredictionDataset(IterableDataset): + """Tiled prediction dataset. + + Parameters + ---------- + prediction_config : InferenceConfig + Inference configuration. + src_files : List[Path] + List of data files. + read_source_func : Callable, optional + Read source function for custom types, by default read_tiff. + **kwargs : Any + Additional keyword arguments, unused. + + Attributes + ---------- + data_path : Union[str, Path] + Path to the data, must be a directory. + axes : str + Description of axes in format STCZYX. + mean : Optional[float], optional + Expected mean of the dataset, by default None. + std : Optional[float], optional + Expected standard deviation of the dataset, by default None. + patch_transform : Optional[Callable], optional + Patch transform callable, by default None. + """ + + def __init__( + self, + prediction_config: InferenceConfig, + src_files: List[Path], + read_source_func: Callable = read_tiff, + **kwargs: Any, + ) -> None: + """Constructor. + + Parameters + ---------- + prediction_config : InferenceConfig + Inference configuration. + src_files : List[Path] + List of data files. + read_source_func : Callable, optional + Read source function for custom types, by default read_tiff. + **kwargs : Any + Additional keyword arguments, unused. + + Raises + ------ + ValueError + If mean and std are not provided in the inference configuration. + """ + if ( + prediction_config.tile_size is None + or prediction_config.tile_overlap is None + ): + raise ValueError( + "Tile size and overlap must be provided for tiled prediction." 
+            )
+
+        self.prediction_config = prediction_config
+        self.data_files = src_files
+        self.axes = prediction_config.axes
+        self.tile_size = prediction_config.tile_size
+        self.tile_overlap = prediction_config.tile_overlap
+        self.read_source_func = read_source_func
+
+        # check mean and std and create normalize transform
+        if self.prediction_config.mean is None or self.prediction_config.std is None:
+            raise ValueError("Mean and std must be provided for prediction.")
+        else:
+            self.mean = self.prediction_config.mean
+            self.std = self.prediction_config.std
+
+        # instantiate normalize transform
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(
+                    mean=prediction_config.mean, std=prediction_config.std
+                )
+            ],
+        )
+
+    def __iter__(
+        self,
+    ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
+        """
+        Iterate over data source and yield single tile.
+
+        Yields
+        ------
+        Tuple[np.ndarray, TileInformation]
+            Single tile.
+        """
+        assert (
+            self.mean is not None and self.std is not None
+        ), "Mean and std must be provided"
+
+        for sample, _ in iterate_over_files(
+            self.prediction_config,
+            self.data_files,
+            read_source_func=self.read_source_func,
+        ):
+            # generate patches, return a generator
+            patch_gen = extract_tiles(
+                arr=sample,
+                tile_size=self.tile_size,
+                overlaps=self.tile_overlap,
+            )
+
+            # apply transform to patches
+            for patch_array, tile_info in patch_gen:
+                transformed_patch, _ = self.patch_transform(patch=patch_array)
+
+                yield transformed_patch, tile_info
diff --git a/src/careamics/lightning_prediction_datamodule.py b/src/careamics/lightning_prediction_datamodule.py
index b22b26a72..4423d6294 100644
--- a/src/careamics/lightning_prediction_datamodule.py
+++ b/src/careamics/lightning_prediction_datamodule.py
@@ -11,23 +11,22 @@
 from careamics.config import InferenceConfig
 from careamics.config.support import SupportedData
 from careamics.config.tile_information import TileInformation
-from careamics.dataset.dataset_utils import (
-    get_read_func,
-    list_files,
-)
-from careamics.dataset.in_memory_dataset import (
-    InMemoryPredictionDataset,
+from careamics.dataset import (
+    InMemoryPredDataset,
     InMemoryTiledPredictionDataset,
-)
-from careamics.dataset.iterable_dataset import (
     IterablePredictionDataset,
     IterableTiledPredictionDataset,
 )
+from careamics.dataset.dataset_utils import (
+    get_read_func,
+    list_files,
+)
 from careamics.utils import get_logger
 
 PredictDatasetType = Union[
-    InMemoryPredictionDataset,
+    InMemoryPredDataset,
     InMemoryTiledPredictionDataset,
+    IterablePredictionDataset,
     IterableTiledPredictionDataset,
 ]
 
@@ -227,7 +226,7 @@
                 )
             else:
-                self.predict_dataset = InMemoryPredictionDataset(
+                self.predict_dataset = InMemoryPredDataset(
                     prediction_config=self.prediction_config,
                     inputs=self.pred_data,
                 )

From c2e29ee30cb7507e52daa771e41bd7febc12220f Mon Sep 17 00:00:00 2001
From: jdeschamps <6367888+jdeschamps@users.noreply.github.com>
Date: Wed, 5 Jun 2024 16:32:34 +0200
Subject: [PATCH 07/14] (chore): remove unused parameters

---
 .../dataset/in_memory_pred_dataset.py         | 20 +------------------
 .../dataset/in_memory_tiled_pred_dataset.py   | 18 ++----------------
 2 files changed, 3 insertions(+), 35 deletions(-)

diff --git a/src/careamics/dataset/in_memory_pred_dataset.py b/src/careamics/dataset/in_memory_pred_dataset.py
index 5d3bbcb17..2a0dc1ffd 100644
--- a/src/careamics/dataset/in_memory_pred_dataset.py
+++ b/src/careamics/dataset/in_memory_pred_dataset.py
@@ -2,8 +2,6 @@
 
 from __future__ import annotations
 
-from typing import Callable, Optional
-
 import numpy as np
 from torch.utils.data import Dataset
 
@@ -11,7 +9,7 @@
 
 from ..config import InferenceConfig
 from ..config.transformations import NormalizeModel
-from .dataset_utils import read_tiff, reshape_array
+from .dataset_utils import reshape_array
 
 
 class InMemoryPredDataset(Dataset):
@@ -23,18 +21,12 @@ class InMemoryPredDataset(Dataset):
         Prediction configuration.
     inputs : np.ndarray
         Input data.
-    data_target : Optional[np.ndarray], optional
-        Target data, by default None.
-    read_source_func : Optional[Callable], optional
-        Read source function for custom types, by default read_tiff.
     """
 
     def __init__(
         self,
         prediction_config: InferenceConfig,
         inputs: np.ndarray,
-        data_target: Optional[np.ndarray] = None,
-        read_source_func: Optional[Callable] = read_tiff,
     ) -> None:
         """Constructor.
 
@@ -44,10 +36,6 @@ def __init__(
             Prediction configuration.
         inputs : np.ndarray
             Input data.
-        data_target : Optional[np.ndarray], optional
-            Target data, by default None.
-        read_source_func : Optional[Callable], optional
-            Read source function for custom types, by default read_tiff.
 
Raises ------ @@ -72,10 +62,6 @@ def __init__( self.tile_overlap = prediction_config.tile_overlap self.mean = self.pred_config.mean self.std = self.pred_config.std - self.data_target = data_target - - # read function - self.read_source_func = read_source_func # Generate patches self.data = self._prepare_tiles() From d0584309379420947cb9e66509243949322650e8 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Wed, 5 Jun 2024 19:09:49 +0200 Subject: [PATCH 08/14] (refac): rename classes --- src/careamics/dataset/__init__.py | 12 +++++----- .../dataset/in_memory_pred_dataset.py | 5 ---- .../dataset/in_memory_tiled_pred_dataset.py | 2 +- .../dataset/iterable_pred_dataset.py | 2 +- .../dataset/iterable_tiled_pred_dataset.py | 2 +- .../lightning_prediction_datamodule.py | 24 +++++++++---------- 6 files changed, 20 insertions(+), 27 deletions(-) diff --git a/src/careamics/dataset/__init__.py b/src/careamics/dataset/__init__.py index c29d1097d..43c39a3ef 100644 --- a/src/careamics/dataset/__init__.py +++ b/src/careamics/dataset/__init__.py @@ -3,15 +3,15 @@ __all__ = [ "InMemoryDataset", "InMemoryPredDataset", - "InMemoryTiledPredictionDataset", + "InMemoryTiledPredDataset", "PathIterableDataset", - "IterableTiledPredictionDataset", - "IterablePredictionDataset", + "IterableTiledPredDataset", + "IterablePredDataset", ] from .in_memory_dataset import InMemoryDataset from .in_memory_pred_dataset import InMemoryPredDataset -from .in_memory_tiled_pred_dataset import InMemoryTiledPredictionDataset +from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset from .iterable_dataset import PathIterableDataset -from .iterable_pred_dataset import IterablePredictionDataset -from .iterable_tiled_pred_dataset import IterableTiledPredictionDataset +from .iterable_pred_dataset import IterablePredDataset +from .iterable_tiled_pred_dataset import IterableTiledPredDataset diff --git a/src/careamics/dataset/in_memory_pred_dataset.py b/src/careamics/dataset/in_memory_pred_dataset.py index 254c2c824..2a0dc1ffd 100644 --- a/src/careamics/dataset/in_memory_pred_dataset.py +++ b/src/careamics/dataset/in_memory_pred_dataset.py @@ -45,13 +45,8 @@ def __init__( self.pred_config = prediction_config self.input_array = inputs self.axes = self.pred_config.axes - self.tile_size = self.pred_config.tile_size - self.tile_overlap = self.pred_config.tile_overlap self.mean, self.std = self.pred_config.mean, self.pred_config.std - # tiling only if both tile size and overlap are provided - self.tiling = self.tile_size is not None and self.tile_overlap is not None - # Reshape data self.data = reshape_array(self.input_array, self.axes) diff --git a/src/careamics/dataset/in_memory_tiled_pred_dataset.py b/src/careamics/dataset/in_memory_tiled_pred_dataset.py index 3141657d7..7f2a17d24 100644 --- a/src/careamics/dataset/in_memory_tiled_pred_dataset.py +++ b/src/careamics/dataset/in_memory_tiled_pred_dataset.py @@ -16,7 +16,7 @@ from .patching.tiled_patching import extract_tiles -class InMemoryTiledPredictionDataset(Dataset): +class InMemoryTiledPredDataset(Dataset): """Prediction dataset storing data in memory and returning tiles of each image. 
Parameters diff --git a/src/careamics/dataset/iterable_pred_dataset.py b/src/careamics/dataset/iterable_pred_dataset.py index d80e02c91..f676550a6 100644 --- a/src/careamics/dataset/iterable_pred_dataset.py +++ b/src/careamics/dataset/iterable_pred_dataset.py @@ -15,7 +15,7 @@ from .dataset_utils import iterate_over_files, read_tiff -class IterablePredictionDataset(IterableDataset): +class IterablePredDataset(IterableDataset): """Simple iterable prediction dataset. Parameters diff --git a/src/careamics/dataset/iterable_tiled_pred_dataset.py b/src/careamics/dataset/iterable_tiled_pred_dataset.py index a8f888b8a..0ccc901f8 100644 --- a/src/careamics/dataset/iterable_tiled_pred_dataset.py +++ b/src/careamics/dataset/iterable_tiled_pred_dataset.py @@ -17,7 +17,7 @@ from .patching.tiled_patching import extract_tiles -class IterableTiledPredictionDataset(IterableDataset): +class IterableTiledPredDataset(IterableDataset): """Tiled prediction dataset. Parameters diff --git a/src/careamics/lightning_prediction_datamodule.py b/src/careamics/lightning_prediction_datamodule.py index 4423d6294..1a32125fe 100644 --- a/src/careamics/lightning_prediction_datamodule.py +++ b/src/careamics/lightning_prediction_datamodule.py @@ -13,9 +13,9 @@ from careamics.config.tile_information import TileInformation from careamics.dataset import ( InMemoryPredDataset, - InMemoryTiledPredictionDataset, - IterablePredictionDataset, - IterableTiledPredictionDataset, + InMemoryTiledPredDataset, + IterablePredDataset, + IterableTiledPredDataset, ) from careamics.dataset.dataset_utils import ( get_read_func, @@ -25,9 +25,9 @@ PredictDatasetType = Union[ InMemoryPredDataset, - InMemoryTiledPredictionDataset, - IterablePredictionDataset, - IterableTiledPredictionDataset, + InMemoryTiledPredDataset, + IterablePredDataset, + IterableTiledPredDataset, ] logger = get_logger(__name__) @@ -219,11 +219,9 @@ def setup(self, stage: Optional[str] = None) -> None: # if numpy array if self.data_type == SupportedData.ARRAY: if self.tiled: - self.predict_dataset: PredictDatasetType = ( - InMemoryTiledPredictionDataset( - prediction_config=self.prediction_config, - inputs=self.pred_data, - ) + self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset( + prediction_config=self.prediction_config, + inputs=self.pred_data, ) else: self.predict_dataset = InMemoryPredDataset( @@ -232,13 +230,13 @@ def setup(self, stage: Optional[str] = None) -> None: ) else: if self.tiled: - self.predict_dataset = IterableTiledPredictionDataset( + self.predict_dataset = IterableTiledPredDataset( prediction_config=self.prediction_config, src_files=self.pred_files, read_source_func=self.read_source_func, ) else: - self.predict_dataset = IterablePredictionDataset( + self.predict_dataset = IterablePredDataset( prediction_config=self.prediction_config, src_files=self.pred_files, read_source_func=self.read_source_func, From 30adb7b1eaf71bbc5c4c9c68dad5cd48cb99c9d9 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Wed, 5 Jun 2024 19:17:16 +0200 Subject: [PATCH 09/14] (chore): improve error message --- src/careamics/config/validators/validator_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/careamics/config/validators/validator_utils.py b/src/careamics/config/validators/validator_utils.py index da5eb0ae7..a8d88e782 100644 --- a/src/careamics/config/validators/validator_utils.py +++ b/src/careamics/config/validators/validator_utils.py @@ -72,7 +72,7 @@ def value_ge_than_8_power_of_2( If the 
value is not a power of 2.
     """
     if value < 8:
-        raise ValueError(f"Value must be non-zero positive (got {value}).")
+        raise ValueError(f"Value must be greater than or equal to 8 (got {value}).")
 
     if (value & (value - 1)) != 0:
         raise ValueError(f"Value must be a power of 2 (got {value}).")

From 35d46b227766ecea43ae6b5c7c64a04d829efc3f Mon Sep 17 00:00:00 2001
From: jdeschamps <6367888+jdeschamps@users.noreply.github.com>
Date: Fri, 7 Jun 2024 17:06:54 +0200
Subject: [PATCH 10/14] (refac): refactor collate and stitching

---
 src/careamics/careamist.py                    |  4 +-
 src/careamics/config/tile_information.py      | 32 +++++++-
 src/careamics/dataset/in_memory_dataset.py    |  2 +-
 src/careamics/dataset/iterable_dataset.py     |  2 +-
 src/careamics/dataset/tiling/__init__.py      | 11 +++
 src/careamics/dataset/tiling/collate_tiles.py | 41 ++++++++++
 .../dataset/tiling/stitch_prediction.py       | 55 ++++++++++++++
 .../{patching => tiling}/tiled_patching.py    |  6 +-
 src/careamics/lightning_module.py             |  4 +-
 .../lightning_prediction_datamodule.py        | 42 +---------
 src/careamics/lightning_prediction_loop.py    | 44 ++++++-----
 src/careamics/prediction/__init__.py          |  7 --
 src/careamics/prediction/stitch_prediction.py | 70 -----------------
 src/careamics/transforms/normalize.py         | 17 ++---
 tests/config/test_tile_information.py         | 39 ++++++++++
 .../dataset/prediction/test_collate_tiles.py  | 75 +++++++++++++++++++
 .../prediction/test_stitch_prediction.py      | 51 +++++++++++++
 .../test_tiled_patching.py                    |  2 +-
 tests/prediction/test_stitch_prediction.py    | 41 ----------
 tests/test_lightning_prediction_datamodule.py |  2 +-
 tests/transforms/test_normalize.py            |  2 +-
 21 files changed, 347 insertions(+), 202 deletions(-)
 create mode 100644 src/careamics/dataset/tiling/__init__.py
 create mode 100644 src/careamics/dataset/tiling/collate_tiles.py
 create mode 100644 src/careamics/dataset/tiling/stitch_prediction.py
 rename src/careamics/dataset/{patching => tiling}/tiled_patching.py (96%)
 delete mode 100644 src/careamics/prediction/__init__.py
 delete mode 100644 src/careamics/prediction/stitch_prediction.py
 create mode 100644 tests/dataset/prediction/test_collate_tiles.py
 create mode 100644 tests/dataset/prediction/test_stitch_prediction.py
 rename tests/dataset/{patching => prediction}/test_tiled_patching.py (98%)
 delete mode 100644 tests/prediction/test_stitch_prediction.py

diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py
index aea217dd8..220dde228 100644
--- a/src/careamics/careamist.py
+++ b/src/careamics/careamist.py
@@ -700,7 +700,7 @@ def _create_data_for_bmz(
             denormalize = Denormalize(
                 mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
             )
-            input_patch, _ = denormalize(input_patch)
+            input_patch = denormalize(input_patch)
 
         elif self.train_datamodule is not None:
             input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
@@ -710,7 +710,7 @@ def _create_data_for_bmz(
             denormalize = Denormalize(
                 mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
             )
-            input_patch, _ = denormalize(input_patch)
+            input_patch = denormalize(input_patch)
         else:
             # create a random input array
             input_patch = np.random.normal(
diff --git a/src/careamics/config/tile_information.py b/src/careamics/config/tile_information.py
index 3fc6a3468..6cac5eeaf 100644
--- a/src/careamics/config/tile_information.py
+++ b/src/careamics/config/tile_information.py
@@ -4,7 +4,7 @@
 
 from typing import Optional, Tuple
 
-from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
+from pydantic import BaseModel, ConfigDict, ValidationInfo, 
field_validator class TileInformation(BaseModel): @@ -18,10 +18,10 @@ class TileInformation(BaseModel): model_config = ConfigDict(validate_default=True) array_shape: Tuple[int, ...] - tiled: bool = False + tiled: bool = False # TODO remove last_tile: bool = False - overlap_crop_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None) - stitch_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None) + overlap_crop_coords: Tuple[Tuple[int, ...], ...] + stitch_coords: Tuple[Tuple[int, ...], ...] @field_validator("array_shape") @classmethod @@ -104,3 +104,27 @@ def mandatory_if_tiled( return v else: return None + + def __eq__(self, other_tile: object): + """Check if two tile information objects are equal. + + Parameters + ---------- + other_tile : object + Tile information object to compare with. + + Returns + ------- + bool + Whether the two tile information objects are equal. + """ + if not isinstance(other_tile, TileInformation): + return NotImplemented + + return ( + self.array_shape == other_tile.array_shape + and self.tiled == other_tile.tiled + and self.last_tile == other_tile.last_tile + and self.overlap_crop_coords == other_tile.overlap_crop_coords + and self.stitch_coords == other_tile.stitch_coords + ) diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 24ceb1e84..0747984d6 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -22,7 +22,7 @@ prepare_patches_unsupervised, prepare_patches_unsupervised_array, ) -from .patching.tiled_patching import extract_tiles +from .tiling.tiled_patching import extract_tiles logger = get_logger(__name__) diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 192e73864..7c53f8ec3 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -17,7 +17,7 @@ from ..utils.logging import get_logger from .dataset_utils import read_tiff, reshape_array from .patching.random_patching import extract_patches_random -from .patching.tiled_patching import extract_tiles +from .tiling.tiled_patching import extract_tiles logger = get_logger(__name__) diff --git a/src/careamics/dataset/tiling/__init__.py b/src/careamics/dataset/tiling/__init__.py new file mode 100644 index 000000000..f7b9643a1 --- /dev/null +++ b/src/careamics/dataset/tiling/__init__.py @@ -0,0 +1,11 @@ +"""Tiling functions.""" + +__all__ = [ + "stitch_prediction", + "extract_tiles", + "collate_tiles", +] + +from .collate_tiles import collate_tiles +from .stitch_prediction import stitch_prediction +from .tiled_patching import extract_tiles diff --git a/src/careamics/dataset/tiling/collate_tiles.py b/src/careamics/dataset/tiling/collate_tiles.py new file mode 100644 index 000000000..e8a10c518 --- /dev/null +++ b/src/careamics/dataset/tiling/collate_tiles.py @@ -0,0 +1,41 @@ +"""Collate function for tiling.""" + +from typing import Any, List, Tuple + +import numpy as np +from torch.utils.data.dataloader import default_collate + +from careamics.config.tile_information import TileInformation + + +def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any: + """ + Collate tiles received from CAREamics prediction dataloader. + + CAREamics prediction dataloader returns tuples of arrays and TileInformation. In + case of non-tiled data, this function will return the arrays. 
In case of tiled data,
+    it will return the collated arrays together with a list of the corresponding
+    TileInformation objects.
+
+    Parameters
+    ----------
+    batch : List[Tuple[np.ndarray, TileInformation], ...]
+        Batch of tiles.
+
+    Returns
+    -------
+    Any
+        Collated batch.
+    """
+    first_tile_info: TileInformation = batch[0][1]
+    # if not tiled, then return arrays
+    if not first_tile_info.tiled:
+        arrays, _ = zip(*batch)
+
+        return default_collate(arrays)
+    # else we explicit the last_tile flag and coordinates
+    else:
+        new_batch = [tile for tile, _ in batch]
+        tiles_batch = [tile_info for _, tile_info in batch]
+
+        return default_collate(new_batch), tiles_batch
diff --git a/src/careamics/dataset/tiling/stitch_prediction.py b/src/careamics/dataset/tiling/stitch_prediction.py
new file mode 100644
index 000000000..54f946042
--- /dev/null
+++ b/src/careamics/dataset/tiling/stitch_prediction.py
@@ -0,0 +1,55 @@
+"""Prediction utility functions."""
+
+from typing import List
+
+import numpy as np
+
+from careamics.config.tile_information import TileInformation
+
+
+def stitch_prediction(
+    tiles: List[np.ndarray],
+    tile_infos: List[TileInformation],
+) -> np.ndarray:
+    """Stitch tiles back together to form a full image.
+
+    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
+    singleton dimension.
+
+    Parameters
+    ----------
+    tiles : List[np.ndarray]
+        Cropped tiles and their respective stitching coordinates.
+    tile_infos : List[TileInformation]
+        List of information and coordinates obtained from
+        `dataset.tiled_patching.extract_tiles`.
+
+    Returns
+    -------
+    np.ndarray
+        Full image.
+    """
+    # retrieve whole array size
+    input_shape = tile_infos[0].array_shape
+    predicted_image = np.zeros(input_shape, dtype=np.float32)
+
+    for tile, tile_info in zip(tiles, tile_infos):
+        n_channels = tile.shape[0]
+
+        # Compute coordinates for cropping predicted tile
+        slices = (slice(0, n_channels),) + tuple(
+            [slice(c[0], c[1]) for c in tile_info.overlap_crop_coords]
+        )
+
+        # Crop predicted tile according to overlap coordinates
+        cropped_tile = tile[slices]
+
+        # Insert cropped tile into predicted image using stitch coordinates
+        predicted_image[
+            (
+                ...,
+                *[slice(c[0], c[1]) for c in tile_info.stitch_coords],
+            )
+        ] = cropped_tile.astype(np.float32)
+
+    return predicted_image
diff --git a/src/careamics/dataset/patching/tiled_patching.py b/src/careamics/dataset/tiling/tiled_patching.py
similarity index 96%
rename from src/careamics/dataset/patching/tiled_patching.py
rename to src/careamics/dataset/tiling/tiled_patching.py
index 890c7f616..83b61e342 100644
--- a/src/careamics/dataset/patching/tiled_patching.py
+++ b/src/careamics/dataset/tiling/tiled_patching.py
@@ -84,15 +84,15 @@ def extract_tiles(
     tile_size: Union[List[int], Tuple[int, ...]],
     overlaps: Union[List[int], Tuple[int, ...]],
 ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
-    """
-    Generate tiles from the input array with specified overlap.
+    """Generate tiles from the input array with specified overlap.
 
     The tiles cover the whole array. The method returns a generator that yields
     tuples of array and tile information, the latter includes whether the tile is the
     last one, the coordinates of the overlap crop, and the coordinates of the stitched
     tile.
 
-    The array has shape C(Z)YX, where C can be a singleton.
+    Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
+    where C can be a singleton. 
Parameters ---------- diff --git a/src/careamics/lightning_module.py b/src/careamics/lightning_module.py index e16ea39e4..57848c582 100644 --- a/src/careamics/lightning_module.py +++ b/src/careamics/lightning_module.py @@ -168,9 +168,9 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any: mean=self._trainer.datamodule.predict_dataset.mean, std=self._trainer.datamodule.predict_dataset.std, ) - denormalized_output, _ = denorm(patch=output) + denormalized_output = denorm(patch=output.cpu()) - if len(aux) > 0: + if len(aux) > 0: # aux can be tiling information return denormalized_output, aux else: return denormalized_output diff --git a/src/careamics/lightning_prediction_datamodule.py b/src/careamics/lightning_prediction_datamodule.py index b7a730faa..f6290c8b3 100644 --- a/src/careamics/lightning_prediction_datamodule.py +++ b/src/careamics/lightning_prediction_datamodule.py @@ -1,16 +1,14 @@ """Prediction Lightning data modules.""" from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union import numpy as np import pytorch_lightning as L from torch.utils.data import DataLoader -from torch.utils.data.dataloader import default_collate from careamics.config import InferenceConfig from careamics.config.support import SupportedData -from careamics.config.tile_information import TileInformation from careamics.dataset.dataset_utils import ( get_read_func, list_files, @@ -21,6 +19,7 @@ from careamics.dataset.iterable_dataset import ( IterablePredictionDataset, ) +from careamics.dataset.tiling.collate_tiles import collate_tiles from careamics.utils import get_logger PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset] @@ -28,41 +27,6 @@ logger = get_logger(__name__) -def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any: - """ - Collate tiles received from CAREamics prediction dataloader. - - CAREamics prediction dataloader returns tuples of arrays and TileInformation. In - case of non-tiled data, this function will return the arrays. In case of tiled data, - it will return the arrays, the last tile flag, the overlap crop coordinates and the - stitch coordinates. - - Parameters - ---------- - batch : List[Tuple[np.ndarray, TileInformation], ...] - Batch of tiles. - - Returns - ------- - Any - Collated batch. - """ - first_tile_info: TileInformation = batch[0][1] - # if not tiled, then return arrays - if not first_tile_info.tiled: - arrays, _ = zip(*batch) - - return default_collate(arrays) - # else we explicit the last_tile flag and coordinates - else: - new_batch = [ - (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords) - for tile, t in batch - ] - - return default_collate(new_batch) - - class CAREamicsPredictData(L.LightningDataModule): """ CAREamics Lightning prediction data module. 
@@ -236,7 +200,7 @@ def predict_dataloader(self) -> DataLoader: return DataLoader( self.predict_dataset, batch_size=self.batch_size, - collate_fn=_collate_tiles, + collate_fn=collate_tiles, **self.dataloader_params, ) # TODO check workers are used diff --git a/src/careamics/lightning_prediction_loop.py b/src/careamics/lightning_prediction_loop.py index ab44a17a8..9298a780e 100644 --- a/src/careamics/lightning_prediction_loop.py +++ b/src/careamics/lightning_prediction_loop.py @@ -1,14 +1,16 @@ """Lithning prediction loop allowing tiling.""" -from typing import Optional +from typing import List, Optional +import numpy as np import pytorch_lightning as L from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher from pytorch_lightning.loops.utilities import _no_grad_context from pytorch_lightning.trainer import call from pytorch_lightning.utilities.types import _PREDICT_OUTPUT -from careamics.prediction import stitch_prediction +from careamics.config.tile_information import TileInformation +from careamics.dataset.tiling import stitch_prediction class CAREamicsPredictionLoop(L.loops._PredictionLoop): @@ -37,11 +39,10 @@ def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: ######################################################## ################ CAREamics specific code ############### if len(self.predicted_array) == 1: - # TODO does this make sense to here? (force numpy array) - return self.predicted_array[0].numpy() + # single array, already a numpy array + return self.predicted_array[0] # todo why not return the list here? else: - # TODO revisit logic - return [element.numpy() for element in self.predicted_array] + return self.predicted_array ######################################################## return None @@ -65,8 +66,8 @@ def run(self) -> Optional[_PREDICT_OUTPUT]: assert data_fetcher is not None self.predicted_array = [] - self.tiles = [] - self.stitching_data = [] + self.tiles: List[np.ndarray] = [] + self.tile_information: List[TileInformation] = [] while True: try: @@ -87,27 +88,34 @@ def run(self) -> Optional[_PREDICT_OUTPUT]: ######################################################## ################ CAREamics specific code ############### - # TODO: next line is not compatible with muSplit is_tiled = len(self.predictions[batch_idx]) == 2 if is_tiled: - # extract the last tile flag and the coordinates (crop and stitch) - last_tile, *stitch_data = self.predictions[batch_idx][1] + # a numpy array of shape BC(Z)YX + tile_batch = self.predictions[batch_idx][0] - # append the tile and the coordinates to the lists - self.tiles.append(self.predictions[batch_idx][0]) - self.stitching_data.append(stitch_data) + # split the tiles into C(Z)YX (skip singleton S) and + # add them to the tiles list + self.tiles.extend( + np.split(tile_batch.numpy(), tile_batch.shape[0], axis=0)[0] + ) + + # tile information is passed as a list of list of TileInformation + # TODO why list of list? 
+ tile_info = self.predictions[batch_idx][1][0] + self.tile_information.extend(tile_info) # if last tile, stitch the tiles and add array to the prediction - if any(last_tile): + last_tiles = [t.last_tile for t in self.tile_information] + if any(last_tiles): predicted_batches = stitch_prediction( - self.tiles, self.stitching_data + self.tiles, self.tile_information ) self.predicted_array.append(predicted_batches) self.tiles.clear() - self.stitching_data.clear() + self.tile_information.clear() else: # simply add the prediction to the list - self.predicted_array.append(self.predictions[batch_idx]) + self.predicted_array.append(self.predictions[batch_idx].numpy()) ######################################################## except StopIteration: break diff --git a/src/careamics/prediction/__init__.py b/src/careamics/prediction/__init__.py deleted file mode 100644 index 852e65de1..000000000 --- a/src/careamics/prediction/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Prediction functions.""" - -__all__ = [ - "stitch_prediction", -] - -from .stitch_prediction import stitch_prediction diff --git a/src/careamics/prediction/stitch_prediction.py b/src/careamics/prediction/stitch_prediction.py deleted file mode 100644 index 1b4fe9690..000000000 --- a/src/careamics/prediction/stitch_prediction.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Prediction utility functions.""" - -from typing import List - -import numpy as np -import torch - - -def stitch_prediction( - tiles: List[torch.Tensor], - stitching_data: List[List[torch.Tensor]], -) -> torch.Tensor: - """ - Stitch tiles back together to form a full image. - - Parameters - ---------- - tiles : List[torch.Tensor] - Cropped tiles and their respective stitching coordinates. - stitching_data : List - List of information and coordinates obtained from - `dataset.tiled_patching.extract_tiles`. - - Returns - ------- - np.ndarray - Full image. - """ - # retrieve whole array size, there is two cases to consider: - # 1. the tiles are stored in a list - # 2. the tiles are stored in a list with batches along the first dim - if tiles[0].shape[0] > 1: - input_shape = np.array( - [el.numpy() for el in stitching_data[0][0][0]], dtype=int - ).squeeze() - else: - input_shape = np.array( - [el.numpy() for el in stitching_data[0][0]], dtype=int - ).squeeze() - - # TODO should use torch.zeros instead of np.zeros - predicted_image = torch.Tensor(np.zeros(input_shape, dtype=np.float32)) - - for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip( - tiles, stitching_data - ): - for batch_idx in range(tile_batch.shape[0]): - # Compute coordinates for cropping predicted tile - slices = tuple( - [ - slice(c[0][batch_idx], c[1][batch_idx]) - for c in overlap_crop_coords_batch - ] - ) - - # Crop predited tile according to overlap coordinates - cropped_tile = tile_batch[batch_idx].squeeze()[slices] - - # Insert cropped tile into predicted image using stitch coordinates - predicted_image[ - ( - ..., - *[ - slice(c[0][batch_idx], c[1][batch_idx]) - for c in stitch_coords_batch - ], - ) - ] = cropped_tile.to(torch.float32) - - return predicted_image diff --git a/src/careamics/transforms/normalize.py b/src/careamics/transforms/normalize.py index 1e24afd5b..10a1d7423 100644 --- a/src/careamics/transforms/normalize.py +++ b/src/careamics/transforms/normalize.py @@ -91,12 +91,12 @@ def _apply(self, patch: np.ndarray) -> np.ndarray: class Denormalize: """ - Denormalize an image or image patch. + Denormalize an image. 
Denormalization is performed expecting a zero mean and unit variance input. This transform expects C(Z)YX dimensions. - Not that an epsilon value of 1e-6 is added to the standard deviation to avoid + Note that an epsilon value of 1e-6 is added to the standard deviation to avoid division by zero during the normalization step, which is taken into account during denormalization. @@ -133,27 +133,22 @@ def __init__( self.std = std self.eps = 1e-6 - def __call__( - self, patch: np.ndarray, target: Optional[np.ndarray] = None - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + def __call__(self, patch: np.ndarray) -> np.ndarray: """Apply the transform to the source patch and the target (optional). Parameters ---------- patch : np.ndarray Patch, 2D or 3D, shape C(Z)YX. - target : Optional[np.ndarray], optional - Target for the patch, by default None. Returns ------- - Tuple[np.ndarray, Optional[np.ndarray]] - Transformed patch and target. + np.ndarray + Transformed patch. """ norm_patch = self._apply(patch) - norm_target = self._apply(target) if target is not None else None - return norm_patch, norm_target + return norm_patch def _apply(self, patch: np.ndarray) -> np.ndarray: """ diff --git a/tests/config/test_tile_information.py b/tests/config/test_tile_information.py index 78b24cc80..218a8caed 100644 --- a/tests/config/test_tile_information.py +++ b/tests/config/test_tile_information.py @@ -48,3 +48,42 @@ def test_error_on_singleton_dims(): """Test that an error is raised if the array shape contains singleton dimensions.""" with pytest.raises(ValueError): TileInformation(array_shape=(2, 1, 6, 6)) + + +def test_tile_equality(): + """Test whether two tile information objects are equal.""" + t1 = TileInformation( + array_shape=(6, 6), + tiled=True, + last_tile=True, + overlap_crop_coords=((1, 2),), + stitch_coords=((3, 4),), + ) + t2 = TileInformation( + array_shape=(6, 6), + tiled=True, + last_tile=True, + overlap_crop_coords=((1, 2),), + stitch_coords=((3, 4),), + ) + assert t1 == t2 + + # inequality + t2.array_shape = (7, 7) + assert t1 != t2 + + t2.array_shape = (6, 6) + t2.tiled = False + assert t1 != t2 + + t2.tiled = True + t2.last_tile = False + assert t1 != t2 + + t2.last_tile = True + t2.overlap_crop_coords = ((2, 3),) + assert t1 != t2 + + t2.overlap_crop_coords = ((1, 2),) + t2.stitch_coords = ((4, 5),) + assert t1 != t2 diff --git a/tests/dataset/prediction/test_collate_tiles.py b/tests/dataset/prediction/test_collate_tiles.py new file mode 100644 index 000000000..483391fed --- /dev/null +++ b/tests/dataset/prediction/test_collate_tiles.py @@ -0,0 +1,75 @@ +import pytest + +from careamics.dataset.tiling import collate_tiles, extract_tiles + + +@pytest.mark.parametrize("n_channels", [1, 4]) +@pytest.mark.parametrize("batch", [1, 3]) +def test_collate_tiles_2d(ordered_array, n_channels, batch): + """Test that the collate tiles function collates tile information correctly.""" + tile_size = (4, 4) + tile_overlap = (2, 2) + shape = (1, n_channels, 8, 8) + + # create array + array = ordered_array(shape) + + # extract tiles + tiles = list(extract_tiles(array, tile_size=tile_size, overlaps=tile_overlap)) + + tiles_used = 0 + n_tiles = len(tiles) + while tiles_used < n_tiles: + # get a batch of tiles + batch_tiles = tiles[tiles_used : tiles_used + batch] + tiles_used += batch + + # collate the tiles + collated_tiles = collate_tiles(batch_tiles) + + # check the collated tiles + assert collated_tiles[0].shape == (batch, n_channels) + tile_size + + # check the tile info + tile_infos = 
collated_tiles[1] + assert len(tile_infos) == batch + + for i, t in enumerate(tile_infos): + for j in range(i + 1, len(tile_infos)): + assert t != tile_infos[j] + + +@pytest.mark.parametrize("n_channels", [1, 4]) +@pytest.mark.parametrize("batch", [1, 3]) +def test_collate_tiles_3d(ordered_array, n_channels, batch): + """Test that the collate tiles function collates tile information correctly.""" + tile_size = (4, 4, 4) + tile_overlap = (2, 2, 2) + shape = (1, n_channels, 8, 8, 8) + + # create array + array = ordered_array(shape) + + # extract tiles + tiles = list(extract_tiles(array, tile_size=tile_size, overlaps=tile_overlap)) + + tiles_used = 0 + n_tiles = len(tiles) + while tiles_used < n_tiles: + # get a batch of tiles + batch_tiles = tiles[tiles_used : tiles_used + batch] + tiles_used += batch + + # collate the tiles + collated_tiles = collate_tiles(batch_tiles) + + # check the collated tiles + assert collated_tiles[0].shape == (batch, n_channels) + tile_size + + # check the tile info + tile_infos = collated_tiles[1] + assert len(tile_infos) == batch + + for i, t in enumerate(tile_infos): + for j in range(i + 1, len(tile_infos)): + assert t != tile_infos[j] diff --git a/tests/dataset/prediction/test_stitch_prediction.py b/tests/dataset/prediction/test_stitch_prediction.py new file mode 100644 index 000000000..062f48ec3 --- /dev/null +++ b/tests/dataset/prediction/test_stitch_prediction.py @@ -0,0 +1,51 @@ +import numpy as np +import pytest + +from careamics.dataset.tiling import extract_tiles, stitch_prediction + + +@pytest.mark.parametrize( + "input_shape, tile_size, overlaps", + [ + ((1, 1, 8, 8), (4, 4), (2, 2)), + ((1, 2, 8, 8), (4, 4), (2, 2)), + ((2, 1, 8, 8), (4, 4), (2, 2)), + ((2, 2, 8, 8), (4, 4), (2, 2)), + ((1, 1, 7, 9), (4, 4), (2, 2)), + ((1, 3, 7, 9), (4, 4), (2, 2)), + ((1, 1, 9, 7, 8), (4, 4, 4), (2, 2, 2)), + ((1, 1, 321, 481), (256, 256), (48, 48)), + ((2, 1, 321, 481), (256, 256), (48, 48)), + ((1, 4, 321, 481), (256, 256), (48, 48)), + ((4, 3, 321, 481), (256, 256), (48, 48)), + ], +) +def test_stitch_tiles(ordered_array, input_shape, tile_size, overlaps): + """Test stitching tiles back together.""" + arr = ordered_array(input_shape, dtype=int) + n_samples = input_shape[0] + + # extract tiles + all_tiles = list(extract_tiles(arr, tile_size, overlaps)) + + tiles = [] + tile_infos = [] + sample_id = 0 + for tile, tile_info in all_tiles: + # create lists mimicking the output of the prediction loop + tiles.append(tile) + tile_infos.append(tile_info) + + # if we reached the last tile + if tile_info.last_tile: + result = stitch_prediction(tiles, tile_infos) + + # check equality with the correct sample + assert np.array_equal(result, arr[sample_id].squeeze()) + sample_id += 1 + + # clear the lists + tiles.clear() + tile_infos.clear() + + assert sample_id == n_samples diff --git a/tests/dataset/patching/test_tiled_patching.py b/tests/dataset/prediction/test_tiled_patching.py similarity index 98% rename from tests/dataset/patching/test_tiled_patching.py rename to tests/dataset/prediction/test_tiled_patching.py index a7e135d4f..920fa907a 100644 --- a/tests/dataset/patching/test_tiled_patching.py +++ b/tests/dataset/prediction/test_tiled_patching.py @@ -2,7 +2,7 @@ import pytest from careamics.config.tile_information import TileInformation -from careamics.dataset.patching.tiled_patching import ( +from careamics.dataset.tiling.tiled_patching import ( _compute_crop_and_stitch_coords_1d, extract_tiles, ) diff --git a/tests/prediction/test_stitch_prediction.py 
b/tests/prediction/test_stitch_prediction.py deleted file mode 100644 index 4908af233..000000000 --- a/tests/prediction/test_stitch_prediction.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -from torch import from_numpy, tensor - -from careamics.dataset.patching.tiled_patching import extract_tiles -from careamics.prediction.stitch_prediction import stitch_prediction - - -@pytest.mark.parametrize( - "input_shape, tile_size, overlaps", - [ - ((1, 1, 8, 8), (4, 4), (2, 2)), - ((1, 1, 8, 8), (4, 4), (2, 2)), - ((1, 1, 7, 9), (4, 4), (2, 2)), - ((1, 1, 9, 7, 8), (4, 4, 4), (2, 2, 2)), - ((1, 1, 321, 481), (256, 256), (48, 48)), - ], -) -def test_stitch_prediction(ordered_array, input_shape, tile_size, overlaps): - """Test calculating stitching coordinates.""" - arr = ordered_array(input_shape, dtype=int) - tiles = [] - stitching_data = [] - - # extract tiles - tile_generator = extract_tiles(arr, tile_size, overlaps) - - # Assemble all tiles as it is done during the prediction stage - for tile_data, tile_info in tile_generator: - tiles.append(from_numpy(tile_data)) # need to convert to torch.Tensor - stitching_data.append( - ( # this is way too wacky - [tensor(i) for i in input_shape], # need to convert to torch.Tensor - [[tensor([j]) for j in i] for i in tile_info.overlap_crop_coords], - [[tensor([j]) for j in i] for i in tile_info.stitch_coords], - ) - ) - - # compute stitching coordinates, it returns a torch.Tensor - result = stitch_prediction(tiles, stitching_data) - - assert (result.numpy() == arr).all() diff --git a/tests/test_lightning_prediction_datamodule.py b/tests/test_lightning_prediction_datamodule.py index cf9e4881d..03edcd324 100644 --- a/tests/test_lightning_prediction_datamodule.py +++ b/tests/test_lightning_prediction_datamodule.py @@ -63,7 +63,7 @@ def test_wrapper_instantiated_with_tiling(simple_array): assert len(list(data_module.predict_dataloader())) == 2 -def test_lwrapper_instantiated_without_tiling(simple_array): +def test_wrapper_instantiated_without_tiling(simple_array): """Test that the data module is created correctly with an array.""" # create data module data_module = PredictDataWrapper( diff --git a/tests/transforms/test_normalize.py b/tests/transforms/test_normalize.py index ffdb4bb6f..057444c7c 100644 --- a/tests/transforms/test_normalize.py +++ b/tests/transforms/test_normalize.py @@ -26,5 +26,5 @@ def test_normalize_denormalize(): ) # Apply the denormalize transform - denormalized, _ = denorm(patch=normalized) + denormalized = denorm(patch=normalized) assert np.isclose(denormalized, array).all() From 8a5c003ca559e373d8dfb32b15cb57ab823ed8ca Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Fri, 7 Jun 2024 17:16:00 +0200 Subject: [PATCH 11/14] (chore): add tests for #125 --- .../patching/validate_patch_dimension.py | 8 ++-- tests/config/test_tile_information.py | 15 ++++++-- tests/test_careamist.py | 38 +++++++++++++++++++ 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/src/careamics/dataset/patching/validate_patch_dimension.py b/src/careamics/dataset/patching/validate_patch_dimension.py index 8174493a4..56fd6d698 100644 --- a/src/careamics/dataset/patching/validate_patch_dimension.py +++ b/src/careamics/dataset/patching/validate_patch_dimension.py @@ -45,18 +45,20 @@ def validate_patch_dimensions( if len(patch_size) != len(arr.shape[2:]): raise ValueError( f"There must be a patch size for each spatial dimensions " - f"(got {patch_size} patches for dims {arr.shape})." 
+ f"(got {patch_size} patches for dims {arr.shape}). Check the axes order." ) # Sanity checks on patch sizes versus array dimension if is_3d_patch and patch_size[0] > arr.shape[-3]: raise ValueError( f"Z patch size is inconsistent with image shape " - f"(got {patch_size[0]} patches for dim {arr.shape[1]})." + f"(got {patch_size[0]} patches for dim {arr.shape[1]}). Check the axes " + f"order." ) if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]: raise ValueError( f"At least one of YX patch dimensions is larger than the corresponding " - f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})." + f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). " + f"Check the axes order." ) diff --git a/tests/config/test_tile_information.py b/tests/config/test_tile_information.py index 218a8caed..17558dac9 100644 --- a/tests/config/test_tile_information.py +++ b/tests/config/test_tile_information.py @@ -6,13 +6,15 @@ def test_defaults(): """Test instantiating time information with defaults.""" - tile_info = TileInformation(array_shape=np.zeros((6, 6)).shape) + tile_info = TileInformation( + array_shape=np.zeros((6, 6)).shape, + overlap_crop_coords=((1, 2),), + stitch_coords=((3, 4),), + ) assert tile_info.array_shape == (6, 6) assert not tile_info.tiled assert not tile_info.last_tile - assert tile_info.overlap_crop_coords is None - assert tile_info.stitch_coords is None def test_tiled(): @@ -34,7 +36,12 @@ def test_tiled(): def test_validation_last_tile(): """Test that last tile is only set if tiled is set.""" - tile_info = TileInformation(array_shape=(6, 6), last_tile=True) + tile_info = TileInformation( + array_shape=(6, 6), + last_tile=True, + overlap_crop_coords=((1, 2),), + stitch_coords=((3, 4),), + ) assert not tile_info.last_tile diff --git a/tests/test_careamist.py b/tests/test_careamist.py index c2fafa5d5..9d7fb6af5 100644 --- a/tests/test_careamist.py +++ b/tests/test_careamist.py @@ -563,6 +563,44 @@ def test_predict_arrays_no_tiling(tmp_path: Path, minimum_configuration: dict): assert (tmp_path / "model.zip").exists() +@pytest.mark.parametrize("independent_channels", [False, True]) +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_predict_tiled_channel( + tmp_path: Path, + minimum_configuration: dict, + independent_channels: bool, + batch_size: int, +): + """Test that CAREamics can be trained on arrays with channels.""" + # training data + train_array = random_array((3, 32, 32)) + val_array = random_array((3, 32, 32)) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "CYX" + config.algorithm_config.model.in_channels = 3 + config.algorithm_config.model.num_classes = 3 + config.algorithm_config.model.independent_channels = independent_channels + config.data_config.batch_size = batch_size + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array, val_source=val_array) + + # predict CAREamist + predicted = careamist.predict( + train_array, batch_size=batch_size, tile_size=(16, 16), tile_overlap=(4, 4) + ) + + assert predicted.squeeze().shape == train_array.shape + + @pytest.mark.parametrize("batch_size", [1, 2]) def test_predict_path(tmp_path: Path, minimum_configuration: dict, batch_size): """Test that CAREamics can predict with tiff files.""" From 
572f9f3cb82a305ff188dae8dcd6b2872e29a0f2 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Fri, 7 Jun 2024 18:02:43 +0200 Subject: [PATCH 12/14] (refac): remove unused tile parameter --- src/careamics/config/tile_information.py | 66 ++----------------- src/careamics/dataset/tiling/collate_tiles.py | 16 ++--- .../dataset/tiling/tiled_patching.py | 1 - tests/config/test_tile_information.py | 19 ++---- 4 files changed, 16 insertions(+), 86 deletions(-) diff --git a/src/careamics/config/tile_information.py b/src/careamics/config/tile_information.py index 6cac5eeaf..4d955339a 100644 --- a/src/careamics/config/tile_information.py +++ b/src/careamics/config/tile_information.py @@ -2,9 +2,9 @@ from __future__ import annotations -from typing import Optional, Tuple +from typing import Tuple -from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator +from pydantic import BaseModel, ConfigDict, field_validator class TileInformation(BaseModel): @@ -13,12 +13,14 @@ class TileInformation(BaseModel): This model is used to represent the information required to stitch back a tile into a larger image. It is used throughout the prediction pipeline of CAREamics. + + Array shape should be (C)(Z)YX, where C and Z are optional dimensions, and must not + contain singleton dimensions. """ model_config = ConfigDict(validate_default=True) array_shape: Tuple[int, ...] - tiled: bool = False # TODO remove last_tile: bool = False overlap_crop_coords: Tuple[Tuple[int, ...], ...] stitch_coords: Tuple[Tuple[int, ...], ...] @@ -48,63 +50,6 @@ def no_singleton_dimensions(cls, v: Tuple[int, ...]): raise ValueError("Array shape must not contain singleton dimensions.") return v - @field_validator("last_tile") - @classmethod - def only_if_tiled(cls, v: bool, values: ValidationInfo): - """ - Check that the last tile flag is only set if tiling is enabled. - - Parameters - ---------- - v : bool - Last tile flag. - values : ValidationInfo - Validation information. - - Returns - ------- - bool - The last tile flag. - """ - if not values.data["tiled"]: - return False - return v - - @field_validator("overlap_crop_coords", "stitch_coords") - @classmethod - def mandatory_if_tiled( - cls, v: Optional[Tuple[int, ...]], values: ValidationInfo - ) -> Optional[Tuple[int, ...]]: - """ - Check that the coordinates are not `None` if tiling is enabled. - - The method also return `None` if tiling is not enabled. - - Parameters - ---------- - v : Optional[Tuple[int, ...]] - Coordinates to check. - values : ValidationInfo - Validation information. - - Returns - ------- - Optional[Tuple[int, ...]] - The coordinates if tiling is enabled, otherwise `None`. - - Raises - ------ - ValueError - If the coordinates are `None` and tiling is enabled. - """ - if values.data["tiled"]: - if v is None: - raise ValueError("Value must be specified if tiling is enabled.") - - return v - else: - return None - def __eq__(self, other_tile: object): """Check if two tile information objects are equal. 
@@ -123,7 +68,6 @@ def __eq__(self, other_tile: object):
 
         return (
             self.array_shape == other_tile.array_shape
-            and self.tiled == other_tile.tiled
             and self.last_tile == other_tile.last_tile
             and self.overlap_crop_coords == other_tile.overlap_crop_coords
             and self.stitch_coords == other_tile.stitch_coords
diff --git a/src/careamics/dataset/tiling/collate_tiles.py b/src/careamics/dataset/tiling/collate_tiles.py
index e8a10c518..ceefc601f 100644
--- a/src/careamics/dataset/tiling/collate_tiles.py
+++ b/src/careamics/dataset/tiling/collate_tiles.py
@@ -27,15 +27,7 @@ def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
     Any
         Collated batch.
     """
-    first_tile_info: TileInformation = batch[0][1]
-    # if not tiled, then return arrays
-    if not first_tile_info.tiled:
-        arrays, _ = zip(*batch)
-
-        return default_collate(arrays)
-    # else we explicit the last_tile flag and coordinates
-    else:
-        new_batch = [tile for tile, _ in batch]
-        tiles_batch = [tile_info for _, tile_info in batch]
-
-        return default_collate(new_batch), tiles_batch
+    new_batch = [tile for tile, _ in batch]
+    tiles_batch = [tile_info for _, tile_info in batch]
+
+    return default_collate(new_batch), tiles_batch
diff --git a/src/careamics/dataset/tiling/tiled_patching.py b/src/careamics/dataset/tiling/tiled_patching.py
index 83b61e342..10fe695cd 100644
--- a/src/careamics/dataset/tiling/tiled_patching.py
+++ b/src/careamics/dataset/tiling/tiled_patching.py
@@ -155,7 +155,6 @@ def extract_tiles(
             # create tile information
             tile_info = TileInformation(
                 array_shape=sample.squeeze().shape,
-                tiled=True,
                 last_tile=last_tile,
                 overlap_crop_coords=overlap_crop_coords,
                 stitch_coords=stitch_coords,
diff --git a/tests/config/test_tile_information.py b/tests/config/test_tile_information.py
index 17558dac9..6b445fcb9 100644
--- a/tests/config/test_tile_information.py
+++ b/tests/config/test_tile_information.py
@@ -13,7 +13,6 @@ def test_defaults():
     )
 
     assert tile_info.array_shape == (6, 6)
-    assert not tile_info.tiled
     assert not tile_info.last_tile
 
 
@@ -21,14 +20,12 @@ def test_tiled():
     """Test instantiating time information with parameters."""
     tile_info = TileInformation(
         array_shape=np.zeros((6, 6)).shape,
-        tiled=True,
        last_tile=True,
         overlap_crop_coords=((1, 2),),
         stitch_coords=((3, 4),),
     )
 
     assert tile_info.array_shape == (6, 6)
-    assert tile_info.tiled
     assert tile_info.last_tile
     assert tile_info.overlap_crop_coords == ((1, 2),)
     assert tile_info.stitch_coords == ((3, 4),)
@@ -46,29 +43,31 @@ def test_error_on_coords():
-    """Test than an error is raised if it is tiled but not coordinates are given."""
+    """Test that an error is raised if no coordinates are given."""
     with pytest.raises(ValueError):
-        TileInformation(array_shape=(6, 6), tiled=True)
+        TileInformation(array_shape=(6, 6))
 
 
 def test_error_on_singleton_dims():
     """Test that an error is raised if the array shape contains singleton dimensions."""
     with pytest.raises(ValueError):
-        TileInformation(array_shape=(2, 1, 6, 6))
+        TileInformation(
+            array_shape=(2, 1, 6, 6),
+            overlap_crop_coords=((1, 2),),
+            stitch_coords=((3, 4),),
+        )
 
 
 def test_tile_equality():
     """Test whether two tile information objects are equal."""
     t1 = TileInformation(
         array_shape=(6, 6),
-        tiled=True,
         last_tile=True,
         overlap_crop_coords=((1, 2),),
         stitch_coords=((3, 4),),
     )
     t2 = TileInformation(
         array_shape=(6, 6),
-        tiled=True,
         last_tile=True,
         overlap_crop_coords=((1, 2),),
         stitch_coords=((3, 4),),
     )
     assert t1 == t2
 
     # inequality
     t2.array_shape = (7, 7)
     assert t1 != t2
 
     t2.array_shape = (6, 6)
-    t2.tiled = False
-    assert t1 != t2
-
-    t2.tiled = True
     t2.last_tile = False
     assert t1 != t2
 

From 6e3c9900604bfda00ef6f819d13a7718f04a9752 Mon Sep 17 00:00:00 2001
From: jdeschamps <6367888+jdeschamps@users.noreply.github.com>
Date: Fri, 7 Jun 2024 18:08:04 +0200
Subject: [PATCH 13/14] (chore): add test for in memory pred dataset and remove useless tests

---
 tests/config/test_tile_information.py         | 26 --------
 tests/dataset/test_in_memory_pred_dataset.py  | 64 +++++++++++++++++++
 .../test_in_memory_tiled_pred_dataset.py      |  0
 tests/dataset/test_iterable_pred_dataset.py   |  0
 .../test_iterable_tiled_pred_dataset.py       |  0
 5 files changed, 64 insertions(+), 26 deletions(-)
 create mode 100644 tests/dataset/test_in_memory_pred_dataset.py
 create mode 100644 tests/dataset/test_in_memory_tiled_pred_dataset.py
 create mode 100644 tests/dataset/test_iterable_pred_dataset.py
 create mode 100644 tests/dataset/test_iterable_tiled_pred_dataset.py

diff --git a/tests/config/test_tile_information.py b/tests/config/test_tile_information.py
index 6b445fcb9..1f10b24b5 100644
--- a/tests/config/test_tile_information.py
+++ b/tests/config/test_tile_information.py
@@ -16,32 +16,6 @@ def test_defaults():
     assert tile_info.array_shape == (6, 6)
     assert not tile_info.last_tile
 
 
-def test_tiled():
-    """Test instantiating time information with parameters."""
-    tile_info = TileInformation(
-        array_shape=np.zeros((6, 6)).shape,
-        last_tile=True,
-        overlap_crop_coords=((1, 2),),
-        stitch_coords=((3, 4),),
-    )
-
-    assert tile_info.array_shape == (6, 6)
-    assert tile_info.last_tile
-    assert tile_info.overlap_crop_coords == ((1, 2),)
-    assert tile_info.stitch_coords == ((3, 4),)
-
-
-def test_validation_last_tile():
-    """Test that last tile is only set if tiled is set."""
-    tile_info = TileInformation(
-        array_shape=(6, 6),
-        last_tile=True,
-        overlap_crop_coords=((1, 2),),
-        stitch_coords=((3, 4),),
-    )
-    assert not tile_info.last_tile
-
-
 def test_error_on_coords():
     """Test that an error is raised if no coordinates are given."""
     with pytest.raises(ValueError):
diff --git a/tests/dataset/test_in_memory_pred_dataset.py b/tests/dataset/test_in_memory_pred_dataset.py
new file mode 100644
index 000000000..2f737b26b
--- /dev/null
+++ b/tests/dataset/test_in_memory_pred_dataset.py
@@ -0,0 +1,64 @@
+import numpy as np
+import pytest
+
+from careamics.config import InferenceConfig
+from careamics.dataset import InMemoryPredDataset
+
+
+@pytest.mark.parametrize(
+    "shape, axes, expected_shape",
+    [
+        ((16, 16), "YX", (1, 1, 16, 16)),
+        ((3, 16, 16), "CYX", (1, 3, 16, 16)),
+        ((8, 16, 16), "ZYX", (1, 1, 8, 16, 16)),
+        ((3, 8, 16, 16), "CZYX", (1, 3, 8, 16, 16)),
+        ((4, 16, 16), "SYX", (1, 1, 16, 16)),
+        ((4, 3, 16, 16), "SCYX", (1, 3, 16, 16)),
+        ((4, 3, 8, 16, 16), "SCZYX", (1, 3, 8, 16, 16)),
+    ],
+)
+def test_in_memory_pred_dataset(shape, axes, expected_shape):
+    """Test that the dataset returns normalized images with singleton
+    sample dimension."""
+    rng = np.random.default_rng(42)
+
+    # check expected length
+    if "S" in axes:
+        # find index of S and check shape
+        idx = axes.index("S")
+        n_patches = shape[idx]
+    else:
+        n_patches = 1
+
+    # create array
+    array = 255 * rng.random(shape)
+
+    # create config
+    config = InferenceConfig(
+        data_type="array",
+        axes=axes,
+        mean=np.mean(array),
+        std=np.std(array),
+    )
+
+    # create dataset
+    dataset = InMemoryPredDataset(config, array)
+
+    # check length
+    assert len(dataset) == n_patches
+
+    # check that the dataset returns normalized images
+    for i in range(len(dataset)):
+        img = dataset[i]
+
+        # check that it has the 
correct shape + assert img.shape == expected_shape + + # check that the image is normalized + assert np.isclose(np.mean(img), 0, atol=0.1) + assert np.isclose(np.std(img), 1, atol=0.1) + + # check that they are independent slices + for j in range(i + 1, len(dataset)): + img2 = dataset[j] + assert not np.allclose(img, img2) diff --git a/tests/dataset/test_in_memory_tiled_pred_dataset.py b/tests/dataset/test_in_memory_tiled_pred_dataset.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/dataset/test_iterable_pred_dataset.py b/tests/dataset/test_iterable_pred_dataset.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/dataset/test_iterable_tiled_pred_dataset.py b/tests/dataset/test_iterable_tiled_pred_dataset.py new file mode 100644 index 000000000..e69de29bb From 616fac0dc968d0bd07a6e4f88501d775fc712cf5 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Fri, 7 Jun 2024 19:10:07 +0200 Subject: [PATCH 14/14] (chore): add prediction datasets tests --- .../dataset/dataset_utils/read_tiff.py | 9 -- .../dataset/iterable_pred_dataset.py | 8 +- .../dataset/iterable_tiled_pred_dataset.py | 2 +- tests/dataset/test_in_memory_pred_dataset.py | 2 +- .../test_in_memory_tiled_pred_dataset.py | 86 +++++++++++++++ tests/dataset/test_iterable_pred_dataset.py | 82 ++++++++++++++ .../test_iterable_tiled_pred_dataset.py | 104 ++++++++++++++++++ 7 files changed, 279 insertions(+), 14 deletions(-) diff --git a/src/careamics/dataset/dataset_utils/read_tiff.py b/src/careamics/dataset/dataset_utils/read_tiff.py index ab557f2f9..0cea0f695 100644 --- a/src/careamics/dataset/dataset_utils/read_tiff.py +++ b/src/careamics/dataset/dataset_utils/read_tiff.py @@ -53,13 +53,4 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray: else: raise ValueError(f"File {file_path} is not a valid tiff.") - # check dimensions - # TODO or should this really be done here? probably in the LightningDataModule - # TODO this should also be centralized somewhere else (validate_dimensions) - if len(array.shape) < 2 or len(array.shape) > 6: - raise ValueError( - f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for" - f"file {file_path})." - ) - return array diff --git a/src/careamics/dataset/iterable_pred_dataset.py b/src/careamics/dataset/iterable_pred_dataset.py index f676550a6..792dde1d2 100644 --- a/src/careamics/dataset/iterable_pred_dataset.py +++ b/src/careamics/dataset/iterable_pred_dataset.py @@ -109,7 +109,9 @@ def __iter__( self.data_files, read_source_func=self.read_source_func, ): - # TODO what if S dimensions > 1, should we yield each sample independently? 
diff --git a/src/careamics/dataset/iterable_pred_dataset.py b/src/careamics/dataset/iterable_pred_dataset.py
index f676550a6..792dde1d2 100644
--- a/src/careamics/dataset/iterable_pred_dataset.py
+++ b/src/careamics/dataset/iterable_pred_dataset.py
@@ -109,7 +109,9 @@ def __iter__(
             self.data_files,
             read_source_func=self.read_source_func,
         ):
-            # TODO what if S dimensions > 1, should we yield each sample independently?
-            transformed_sample, _ = self.patch_transform(patch=sample)
-            yield transformed_sample
+            # sample has S dimension
+            for i in range(sample.shape[0]):
+                transformed_sample, _ = self.patch_transform(patch=sample[[i]])
+
+                yield transformed_sample
diff --git a/src/careamics/dataset/iterable_tiled_pred_dataset.py b/src/careamics/dataset/iterable_tiled_pred_dataset.py
index bedc3b45b..fa2783f49 100644
--- a/src/careamics/dataset/iterable_tiled_pred_dataset.py
+++ b/src/careamics/dataset/iterable_tiled_pred_dataset.py
@@ -121,7 +121,7 @@ def __iter__(
             self.data_files,
             read_source_func=self.read_source_func,
         ):
-            # generate patches, return a generator
+            # generate patches, return a generator of single tiles
             patch_gen = extract_tiles(
                 arr=sample,
                 tile_size=self.tile_size,
diff --git a/tests/dataset/test_in_memory_pred_dataset.py b/tests/dataset/test_in_memory_pred_dataset.py
index 2f737b26b..46099cbe5 100644
--- a/tests/dataset/test_in_memory_pred_dataset.py
+++ b/tests/dataset/test_in_memory_pred_dataset.py
@@ -17,7 +17,7 @@
         ((4, 3, 8, 16, 16), "SCZYX", (1, 3, 8, 16, 16)),
     ],
 )
-def test_in_memory_pred_dataset(shape, axes, expected_shape):
+def test_correct_normalized_outputs(shape, axes, expected_shape):
     """Test that the dataset returns normalized images with singleton
     sample dimension."""
     rng = np.random.default_rng(42)
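The per-sample loop introduced above hinges on a small numpy indexing detail: integer indexing drops the sample axis, while indexing with a one-element list keeps it, which is what lets each yielded item retain a singleton S dimension. A minimal illustration in plain numpy, independent of any CAREamics code:

import numpy as np

sample = np.zeros((4, 3, 16, 16))  # SCYX, four samples

print(sample[0].shape)    # (3, 16, 16): integer indexing drops S
print(sample[[0]].shape)  # (1, 3, 16, 16): list indexing keeps a singleton S
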
diff --git a/tests/dataset/test_in_memory_tiled_pred_dataset.py b/tests/dataset/test_in_memory_tiled_pred_dataset.py
index e69de29bb..0d4f53817 100644
--- a/tests/dataset/test_in_memory_tiled_pred_dataset.py
+++ b/tests/dataset/test_in_memory_tiled_pred_dataset.py
@@ -0,0 +1,86 @@
+import numpy as np
+import pytest
+
+from careamics.config import InferenceConfig
+from careamics.dataset import InMemoryTiledPredDataset
+
+
+# TODO extract tiles is returning C(Z)YX and no singleton S!
+@pytest.mark.parametrize(
+    "shape, axes, expected_shape",
+    [
+        ((16, 16), "YX", (1, 16, 16)),
+        ((3, 16, 16), "CYX", (3, 16, 16)),
+        ((16, 16, 16), "ZYX", (1, 16, 16, 16)),
+        ((3, 16, 16, 16), "CZYX", (3, 16, 16, 16)),
+        ((4, 16, 16), "SYX", (1, 16, 16)),
+        ((4, 3, 16, 16), "SCYX", (3, 16, 16)),
+        ((4, 3, 16, 16, 16), "SCZYX", (3, 16, 16, 16)),
+    ],
+)
+def test_correct_normalized_outputs(shape, axes, expected_shape):
+    """Test that the dataset returns normalized tiles of shape C(Z)YX,
+    without a singleton sample dimension."""
+    rng = np.random.default_rng(42)
+
+    tile_size = (8, 8, 8) if "Z" in axes else (8, 8)
+    tile_overlap = (4, 4, 4) if "Z" in axes else (4, 4)
+
+    # check expected length
+    n_tiles = np.prod(
+        np.ceil(
+            (expected_shape[1:] - np.array(tile_overlap))
+            / (np.array(tile_size) - np.array(tile_overlap))
+        )
+    ).astype(int)
+
+    # check number of samples
+    if "S" in axes:
+        # get index
+        idx = axes.index("S")
+        n_samples = shape[idx]
+    else:
+        n_samples = 1
+
+    # check number of channels
+    if "C" in axes:
+        # get index
+        idx = axes.index("C")
+        n_channels = shape[idx]
+    else:
+        n_channels = 1
+
+    # create array
+    array = 255 * rng.random(shape)
+
+    # create config
+    config = InferenceConfig(
+        data_type="array",
+        axes=axes,
+        mean=np.mean(array),
+        std=np.std(array),
+        tile_size=tile_size,
+        tile_overlap=tile_overlap,
+    )
+
+    # create dataset
+    dataset = InMemoryTiledPredDataset(config, array)
+
+    # check length
+    assert len(dataset) == n_samples * n_tiles
+
+    # check that the dataset returns normalized images
+    for i in range(len(dataset)):
+        img, _ = dataset[i]
+
+        # check that it has the correct shape
+        assert img.shape == (n_channels,) + tile_size
+
+        # check that the image is normalized
+        assert np.isclose(np.mean(img), 0, atol=0.25)
+        assert np.isclose(np.std(img), 1, atol=0.2)
+
+        # check that they are independent slices
+        for j in range(i + 1, len(dataset)):
+            img2, _ = dataset[j]
+            assert not np.allclose(img, img2)
diff --git a/tests/dataset/test_iterable_pred_dataset.py b/tests/dataset/test_iterable_pred_dataset.py
index e69de29bb..c267e3ba4 100644
--- a/tests/dataset/test_iterable_pred_dataset.py
+++ b/tests/dataset/test_iterable_pred_dataset.py
@@ -0,0 +1,82 @@
+import numpy as np
+import pytest
+import tifffile
+
+from careamics.config import InferenceConfig
+from careamics.dataset import IterablePredDataset
+
+
+@pytest.mark.parametrize(
+    "n_files, shape, axes, expected_shape",
+    [
+        (1, (16, 16), "YX", (1, 1, 16, 16)),
+        (1, (3, 16, 16), "CYX", (1, 3, 16, 16)),
+        (1, (8, 16, 16), "ZYX", (1, 1, 8, 16, 16)),
+        (1, (3, 8, 16, 16), "CZYX", (1, 3, 8, 16, 16)),
+        (1, (4, 16, 16), "SYX", (1, 1, 16, 16)),
+        (1, (4, 3, 16, 16), "SCYX", (1, 3, 16, 16)),
+        (1, (4, 3, 8, 16, 16), "SCZYX", (1, 3, 8, 16, 16)),
+        (3, (16, 16), "YX", (1, 1, 16, 16)),
+        (3, (3, 16, 16), "CYX", (1, 3, 16, 16)),
+        (3, (8, 16, 16), "ZYX", (1, 1, 8, 16, 16)),
+        (3, (3, 8, 16, 16), "CZYX", (1, 3, 8, 16, 16)),
+        (3, (4, 16, 16), "SYX", (1, 1, 16, 16)),
+        (3, (4, 3, 16, 16), "SCYX", (1, 3, 16, 16)),
+        (3, (4, 3, 8, 16, 16), "SCZYX", (1, 3, 8, 16, 16)),
+    ],
+)
+def test_correct_normalized_outputs(tmp_path, n_files, shape, axes, expected_shape):
+    """Test that the dataset returns normalized images with singleton
+    sample dimension."""
+    rng = np.random.default_rng(42)
+
+    # check expected length
+    if "S" in axes:
+        # find index of S and check shape
+        idx = axes.index("S")
+        n_patches = shape[idx]
+    else:
+        n_patches = 1
+
+    # create array
+    new_shape = (n_files,) + shape
+    array = 255 * rng.random(new_shape)
+
+    # create config
+    config = InferenceConfig(
+        data_type="tiff",
+        axes=axes,
+        mean=np.mean(array),
+        std=np.std(array),
+    )
+
+    # write one file per sample stack
+    files = []
+    for i in range(n_files):
+        file = tmp_path / f"file_{i}.tif"
+        tifffile.imwrite(file, array[i])
+        files.append(file)
+
+    # create dataset
+    dataset = IterablePredDataset(config, files)
+
+    # get all images
+    dataset = list(dataset)
+
+    # check length
+    assert len(dataset) == n_files * n_patches
+
+    # check that the dataset returns normalized images
+    for i in range(len(dataset)):
+        img = dataset[i]
+
+        # check that it has the correct shape
+        assert img.shape == expected_shape
+
+        # check that the image is normalized
+        assert np.isclose(np.mean(img), 0, atol=0.1)
+        assert np.isclose(np.std(img), 1, atol=0.1)
+
+        # check that they are independent slices
+        for j in range(i + 1, len(dataset)):
+            img2 = dataset[j]
+            assert not np.allclose(img, img2)
diff --git a/tests/dataset/test_iterable_tiled_pred_dataset.py b/tests/dataset/test_iterable_tiled_pred_dataset.py
index e69de29bb..dcc174571 100644
--- a/tests/dataset/test_iterable_tiled_pred_dataset.py
+++ b/tests/dataset/test_iterable_tiled_pred_dataset.py
@@ -0,0 +1,104 @@
+import numpy as np
+import pytest
+import tifffile
+
+from careamics.config import InferenceConfig
+from careamics.dataset import IterableTiledPredDataset
+
+
+# TODO extract tiles is returning C(Z)YX and no singleton S!
+@pytest.mark.parametrize(
+    "n_files, shape, axes, expected_shape",
+    [
+        (1, (16, 16), "YX", (1, 16, 16)),
+        (1, (3, 16, 16), "CYX", (3, 16, 16)),
+        (1, (8, 16, 16), "ZYX", (1, 8, 16, 16)),
+        (1, (3, 8, 16, 16), "CZYX", (3, 8, 16, 16)),
+        (1, (4, 16, 16), "SYX", (1, 16, 16)),
+        (1, (4, 3, 16, 16), "SCYX", (3, 16, 16)),
+        (1, (4, 3, 8, 16, 16), "SCZYX", (3, 8, 16, 16)),
+        (3, (16, 16), "YX", (1, 16, 16)),
+        (3, (3, 16, 16), "CYX", (3, 16, 16)),
+        (3, (8, 16, 16), "ZYX", (1, 8, 16, 16)),
+        (3, (3, 8, 16, 16), "CZYX", (3, 8, 16, 16)),
+        (3, (4, 16, 16), "SYX", (1, 16, 16)),
+        (3, (4, 3, 16, 16), "SCYX", (3, 16, 16)),
+        (3, (4, 3, 8, 16, 16), "SCZYX", (3, 8, 16, 16)),
+    ],
+)
+def test_correct_normalized_outputs(tmp_path, n_files, shape, axes, expected_shape):
+    """Test that the dataset returns normalized tiles of shape C(Z)YX,
+    without a singleton sample dimension."""
+    rng = np.random.default_rng(42)
+
+    tile_size = (8, 8, 8) if "Z" in axes else (8, 8)
+    tile_overlap = (4, 4, 4) if "Z" in axes else (4, 4)
+
+    # check expected length
+    n_tiles = np.prod(
+        np.ceil(
+            (expected_shape[1:] - np.array(tile_overlap))
+            / (np.array(tile_size) - np.array(tile_overlap))
+        )
+    ).astype(int)
+
+    # check number of samples
+    if "S" in axes:
+        # get index
+        idx = axes.index("S")
+        n_samples = shape[idx]
+    else:
+        n_samples = 1
+
+    # check number of channels
+    if "C" in axes:
+        # get index
+        idx = axes.index("C")
+        n_channels = shape[idx]
+    else:
+        n_channels = 1
+
+    # create array
+    new_shape = (n_files,) + shape
+    array = 255 * rng.random(new_shape)
+
+    # create config
+    config = InferenceConfig(
+        data_type="tiff",
+        axes=axes,
+        mean=np.mean(array),
+        std=np.std(array),
+        tile_size=tile_size,
+        tile_overlap=tile_overlap,
+    )
+
+    # write one file per sample stack
+    files = []
+    for i in range(n_files):
+        file = tmp_path / f"file_{i}.tif"
+        tifffile.imwrite(file, array[i])
+        files.append(file)
+
+    # create dataset
+    dataset = IterableTiledPredDataset(config, files)
+
+    # get all images
+    dataset = list(dataset)
+
+    # check length
+    assert len(dataset) == n_files * n_samples * n_tiles
+
+    # check that the dataset returns normalized images
+    for i in range(len(dataset)):
+        img, _ = dataset[i]
+
+        # check that it has the correct shape
+        assert img.shape == (n_channels,) + tile_size
+
+        # check that the image is normalized
+        assert np.isclose(np.mean(img), 0, atol=0.25)
+        assert np.isclose(np.std(img), 1, atol=0.2)
+
+        # check that they are independent slices
+        for j in range(i + 1, len(dataset)):
+            img2, _ = dataset[j]
+            assert not np.allclose(img, img2)
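As a sanity check on the expected-length arithmetic used in both tiled tests: with tile size t and overlap o, each spatial dimension of size d contributes ceil((d - o) / (t - o)) tile positions, and the per-dimension counts multiply. A small self-contained check mirroring the tests' formula; expected_tile_count is a hypothetical helper written for this illustration:

import numpy as np

def expected_tile_count(spatial_shape, tile_size, tile_overlap):
    # number of tile positions per dimension, multiplied together
    steps = np.ceil(
        (np.array(spatial_shape) - np.array(tile_overlap))
        / (np.array(tile_size) - np.array(tile_overlap))
    )
    return int(np.prod(steps))

# a 16x16 image with 8x8 tiles overlapping by 4 pixels -> 3 x 3 = 9 tiles
assert expected_tile_count((16, 16), (8, 8), (4, 4)) == 9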