### Description

> **tldr**:
> - Split prediction datasets into tiled and non-tiled
> - Simplify stitching by passing `TileInformation` all along
> - Fix #125
> - Change `total` memory check to `available`

When prediction is not tiled, images are still forced through the tiling pipeline (with a `TileInformation` instance passed along with each image), which makes debugging complex. This PR splits the prediction datasets into tiled and non-tiled variants.

I also changed the `total` memory check to `available` for the switch between in-memory and iterable datasets during training, as it better represents what can actually be loaded into memory. Finally, I simplified the stitching and prediction pipeline by passing `TileInformation` further along.

- **What**: Refactor the prediction datasets into tiled and non-tiled datasets, and simplify stitching.
- **Why**: Avoids forcing non-tiled prediction through the same complex pipeline as the tiled one.
- **How**: Split the two features (tiled and non-tiled prediction) into two sets of datasets.

### Changes Made

- **Added**:
  - *dataset/iterable_pred_dataset.py*
  - *dataset/iterable_tiled_pred_dataset.py*
  - *dataset/in_memory_pred_dataset.py*
  - *dataset/in_memory_tiled_pred_dataset.py*
- **Modified**: `get_ram_size` now looks at available memory (a sketch follows this description).
- **Removed**: useless calls to `sort` in the datasets, and `is_tile` from `TileInformation`.

### Related issues

This PR fixes #125.

### Notes

This PR will create merge conflicts with #134. Currently, two issues remain:

- `extract_tile` returns C(Z)YX, while the non-tiled datasets always return SC(Z)YX with a singleton dimension.
- It is not clear when and where we should cast the `Tensor`s into `np.ndarray`. I tried to make this happen in a single place (the prediction loop), but that is not entirely solved.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [x] PR to the documentation exists (for bug fixes/features)
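The `total` → `available` change is small but worth making concrete. Below is a minimal sketch of what such a check might look like, assuming `psutil`; the helper name and units here are hypothetical and may not match the real `get_ram_size` in careamics:

```python
import psutil


def get_available_ram_mb() -> float:  # hypothetical name, not the careamics API
    """Return currently available system RAM in megabytes."""
    # `available` estimates memory usable without swapping; unlike `total`,
    # it reflects what can actually be loaded into memory right now.
    return psutil.virtual_memory().available / 1024**2


# hypothetical use: pick the in-memory dataset only if the data actually fits
# use_in_memory = dataset_size_mb < get_available_ram_mb()
```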
1 parent 763d965 · commit 9c829b7

Showing 41 changed files with 1,334 additions and 654 deletions.
```diff
@@ -1,6 +1,17 @@
 """Dataset module."""

-__all__ = ["InMemoryDataset", "PathIterableDataset"]
+__all__ = [
+    "InMemoryDataset",
+    "InMemoryPredDataset",
+    "InMemoryTiledPredDataset",
+    "PathIterableDataset",
+    "IterableTiledPredDataset",
+    "IterablePredDataset",
+]

 from .in_memory_dataset import InMemoryDataset
+from .in_memory_pred_dataset import InMemoryPredDataset
+from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
 from .iterable_dataset import PathIterableDataset
+from .iterable_pred_dataset import IterablePredDataset
+from .iterable_tiled_pred_dataset import IterableTiledPredDataset
```
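The diff above updates what appears to be the dataset package's `__init__.py` (inferred from its docstring), exposing the four new prediction datasets. As a hedged illustration of the design, a caller can now pick the tiled or non-tiled dataset up front instead of routing everything through the tiling pipeline. The constructor arguments below are invented for illustration and do not reflect the real signatures:

```python
from careamics.dataset import InMemoryPredDataset, InMemoryTiledPredDataset


def build_pred_dataset(inputs, tile_size=None):
    """Sketch: choose the dataset class based on whether tiling is requested."""
    if tile_size is not None:
        # tiled path: TileInformation travels with each tile for later stitching
        return InMemoryTiledPredDataset(inputs, tile_size=tile_size)  # args hypothetical
    # non-tiled path: whole images, no TileInformation bookkeeping
    return InMemoryPredDataset(inputs)  # args hypothetical
```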
`@@ -0,0 +1,83 @@` (new file):

```python
"""Function to iterate over files."""

from __future__ import annotations

from pathlib import Path
from typing import Callable, Generator, List, Optional, Tuple, Union

import numpy as np
from torch.utils.data import get_worker_info

from careamics.config import DataConfig, InferenceConfig
from careamics.utils.logging import get_logger

from .dataset_utils import reshape_array
from .read_tiff import read_tiff

logger = get_logger(__name__)


def iterate_over_files(
    data_config: Union[DataConfig, InferenceConfig],
    data_files: List[Path],
    target_files: Optional[List[Path]] = None,
    read_source_func: Callable = read_tiff,
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
    """Iterate over data source and yield whole reshaped images.

    Parameters
    ----------
    data_config : Union[DataConfig, InferenceConfig]
        Data configuration.
    data_files : List[Path]
        List of data files.
    target_files : Optional[List[Path]]
        List of target files, by default None.
    read_source_func : Optional[Callable]
        Function to read the source, by default read_tiff.

    Yields
    ------
    np.ndarray
        Image.
    """
    # When num_workers > 0, each worker process will have a different copy of
    # the dataset object. Configuring each copy independently avoids having
    # duplicate data returned from the workers.
    worker_info = get_worker_info()
    worker_id = worker_info.id if worker_info is not None else 0
    num_workers = worker_info.num_workers if worker_info is not None else 1

    # iterate over the files
    for i, filename in enumerate(data_files):
        # retrieve file corresponding to the worker id
        if i % num_workers == worker_id:
            try:
                # read data
                sample = read_source_func(filename, data_config.axes)

                # reshape array
                reshaped_sample = reshape_array(sample, data_config.axes)

                # read target, if available
                if target_files is not None:
                    if filename.name != target_files[i].name:
                        raise ValueError(
                            f"File {filename} does not match target file "
                            f"{target_files[i]}. Have you passed sorted "
                            f"arrays?"
                        )

                    # read target
                    target = read_source_func(target_files[i], data_config.axes)

                    # reshape target
                    reshaped_target = reshape_array(target, data_config.axes)

                    yield reshaped_sample, reshaped_target
                else:
                    yield reshaped_sample, None

            except Exception as e:
                logger.error(f"Error reading file {filename}: {e}")
```
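The round-robin sharding above (`i % num_workers == worker_id`) is what keeps multiple DataLoader workers from yielding duplicate files. Here is a self-contained toy version of the same idea, with invented names, that can be run on its own:

```python
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedFiles(IterableDataset):
    """Toy iterable dataset sharding a file list across DataLoader workers."""

    def __init__(self, files):
        self.files = files

    def __iter__(self):
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        for i, f in enumerate(self.files):
            # each index is claimed by exactly one worker, so the union of
            # all workers' outputs covers every file exactly once
            if i % num_workers == worker_id:
                yield f


if __name__ == "__main__":
    ds = ShardedFiles([f"img_{i}.tif" for i in range(8)])
    # with two workers, worker 0 yields indices 0, 2, 4, 6 and worker 1 yields
    # 1, 3, 5, 7; batch_size=None passes items through without collation
    loader = DataLoader(ds, num_workers=2, batch_size=None)
    print(sorted(loader))
```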