Refactor prediction pipeline #131

Merged · 16 commits · Jun 11, 2024
4 changes: 2 additions & 2 deletions src/careamics/careamist.py
@@ -700,7 +700,7 @@ def _create_data_for_bmz(
             denormalize = Denormalize(
                 mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
             )
-            input_patch, _ = denormalize(input_patch)
+            input_patch = denormalize(input_patch)
 
         elif self.train_datamodule is not None:
             input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
@@ -710,7 +710,7 @@ def _create_data_for_bmz(
             denormalize = Denormalize(
                 mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
             )
-            input_patch, _ = denormalize(input_patch)
+            input_patch = denormalize(input_patch)
         else:
             # create a random input array
             input_patch = np.random.normal(
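Note on the call-site change: `input_patch, _ = denormalize(...)` becomes `input_patch = denormalize(...)`, implying that `Denormalize.__call__` now returns the denormalized array alone rather than a tuple. A minimal sketch of the new calling convention (the class body is a stand-in, not CAREamics' actual implementation, and the mean/std values are made up):

import numpy as np

class Denormalize:
    """Stand-in for careamics' Denormalize, showing the single-return convention."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean = mean
        self.std = std

    def __call__(self, patch: np.ndarray) -> np.ndarray:
        # invert the normalization (x - mean) / std
        return patch * self.std + self.mean

denormalize = Denormalize(mean=0.5, std=2.0)
input_patch = denormalize(np.zeros((1, 8, 8)))  # no tuple unpacking anymore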
74 changes: 21 additions & 53 deletions src/careamics/config/tile_information.py
@@ -2,9 +2,9 @@
 
 from __future__ import annotations
 
-from typing import Optional, Tuple
+from typing import Tuple
 
-from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
+from pydantic import BaseModel, ConfigDict, field_validator
 
 
 class TileInformation(BaseModel):
@@ -13,15 +13,17 @@
 
     This model is used to represent the information required to stitch back a tile into
     a larger image. It is used throughout the prediction pipeline of CAREamics.
+
+    Array shape should be (C)(Z)YX, where C and Z are optional dimensions, and must not
+    contain singleton dimensions.
     """
 
     model_config = ConfigDict(validate_default=True)
 
     array_shape: Tuple[int, ...]
-    tiled: bool = False
     last_tile: bool = False
-    overlap_crop_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None)
-    stitch_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None)
+    overlap_crop_coords: Tuple[Tuple[int, ...], ...]
+    stitch_coords: Tuple[Tuple[int, ...], ...]
 
     @field_validator("array_shape")
     @classmethod
@@ -48,59 +50,25 @@
             raise ValueError("Array shape must not contain singleton dimensions.")
         return v
 
-    @field_validator("last_tile")
-    @classmethod
-    def only_if_tiled(cls, v: bool, values: ValidationInfo):
-        """
-        Check that the last tile flag is only set if tiling is enabled.
+    def __eq__(self, other_tile: object):
+        """Check if two tile information objects are equal.
 
         Parameters
         ----------
-        v : bool
-            Last tile flag.
-        values : ValidationInfo
-            Validation information.
+        other_tile : object
+            Tile information object to compare with.
 
         Returns
         -------
         bool
-            The last tile flag.
+            Whether the two tile information objects are equal.
         """
-        if not values.data["tiled"]:
-            return False
-        return v
-
-    @field_validator("overlap_crop_coords", "stitch_coords")
-    @classmethod
-    def mandatory_if_tiled(
-        cls, v: Optional[Tuple[int, ...]], values: ValidationInfo
-    ) -> Optional[Tuple[int, ...]]:
-        """
-        Check that the coordinates are not `None` if tiling is enabled.
-
-        The method also return `None` if tiling is not enabled.
-
-        Parameters
-        ----------
-        v : Optional[Tuple[int, ...]]
-            Coordinates to check.
-        values : ValidationInfo
-            Validation information.
-
-        Returns
-        -------
-        Optional[Tuple[int, ...]]
-            The coordinates if tiling is enabled, otherwise `None`.
-
-        Raises
-        ------
-        ValueError
-            If the coordinates are `None` and tiling is enabled.
-        """
-        if values.data["tiled"]:
-            if v is None:
-                raise ValueError("Value must be specified if tiling is enabled.")
-
-            return v
-        else:
-            return None
+        if not isinstance(other_tile, TileInformation):
+            return NotImplemented
+
+        return (
+            self.array_shape == other_tile.array_shape
+            and self.last_tile == other_tile.last_tile
+            and self.overlap_crop_coords == other_tile.overlap_crop_coords
+            and self.stitch_coords == other_tile.stitch_coords
+        )
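With `overlap_crop_coords` and `stitch_coords` now mandatory and the `tiled` flag gone, equality is decided purely by field values. A quick sketch of how the new `__eq__` behaves (the field values here are illustrative only):

from careamics.config.tile_information import TileInformation

tile_a = TileInformation(
    array_shape=(64, 64),
    last_tile=False,
    overlap_crop_coords=((0, 60), (0, 60)),
    stitch_coords=((0, 60), (0, 60)),
)
tile_b = TileInformation(
    array_shape=(64, 64),
    last_tile=False,
    overlap_crop_coords=((0, 60), (0, 60)),
    stitch_coords=((0, 60), (0, 60)),
)

assert tile_a == tile_b                   # all fields match
assert (tile_a == "not a tile") is False  # NotImplemented falls back to False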
2 changes: 1 addition & 1 deletion src/careamics/config/validators/validator_utils.py
@@ -72,7 +72,7 @@ def value_ge_than_8_power_of_2(
         If the value is not a power of 2.
     """
     if value < 8:
-        raise ValueError(f"Value must be non-zero positive (got {value}).")
+        raise ValueError(f"Value must be greater than 8 (got {value}).")
 
     if (value & (value - 1)) != 0:
         raise ValueError(f"Value must be a power of 2 (got {value}).")
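The power-of-2 test relies on the standard bit trick: a positive power of two has exactly one set bit, so `value & (value - 1)` clears that bit and yields zero only for powers of two. A standalone sketch of the combined check (the helper name is hypothetical, not part of the codebase):

def is_at_least_8_and_power_of_2(value: int) -> bool:
    # hypothetical helper mirroring the validator's two conditions
    return value >= 8 and (value & (value - 1)) == 0

assert is_at_least_8_and_power_of_2(8)
assert is_at_least_8_and_power_of_2(64)
assert not is_at_least_8_and_power_of_2(12)  # not a power of 2
assert not is_at_least_8_and_power_of_2(4)   # power of 2, but below 8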
13 changes: 12 additions & 1 deletion src/careamics/dataset/__init__.py
@@ -1,6 +1,17 @@
 """Dataset module."""
 
-__all__ = ["InMemoryDataset", "PathIterableDataset"]
+__all__ = [
+    "InMemoryDataset",
+    "InMemoryPredDataset",
+    "InMemoryTiledPredDataset",
+    "PathIterableDataset",
+    "IterableTiledPredDataset",
+    "IterablePredDataset",
+]
 
 from .in_memory_dataset import InMemoryDataset
+from .in_memory_pred_dataset import InMemoryPredDataset
+from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
 from .iterable_dataset import PathIterableDataset
+from .iterable_pred_dataset import IterablePredDataset
+from .iterable_tiled_pred_dataset import IterableTiledPredDataset
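For illustration, the new prediction datasets can then be imported directly from the package (names taken from the `__all__` above):

from careamics.dataset import (
    InMemoryPredDataset,
    InMemoryTiledPredDataset,
    IterablePredDataset,
    IterableTiledPredDataset,
)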
2 changes: 2 additions & 0 deletions src/careamics/dataset/dataset_utils/__init__.py
@@ -8,11 +8,13 @@
     "read_tiff",
     "get_read_func",
     "read_zarr",
+    "iterate_over_files",
 ]
 
 
 from .dataset_utils import reshape_array
 from .file_utils import get_files_size, list_files, validate_source_target_files
+from .iterate_over_files import iterate_over_files
 from .read_tiff import read_tiff
 from .read_utils import get_read_func
 from .read_zarr import read_zarr
2 changes: 1 addition & 1 deletion src/careamics/dataset/dataset_utils/file_utils.py
@@ -33,7 +33,7 @@ def list_files(
     data_type: Union[str, SupportedData],
     extension_filter: str = "",
 ) -> List[Path]:
-    """Create a recursive list of files in `data_path`.
+    """List recursively files in `data_path` and return a sorted list.
 
     If `data_path` is a file, its name is validated against the `data_type` using
     `fnmatch`, and the method returns `data_path` itself.
83 changes: 83 additions & 0 deletions src/careamics/dataset/dataset_utils/iterate_over_files.py
@@ -0,0 +1,83 @@
+"""Function to iterate over files."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable, Generator, List, Optional, Tuple, Union
+
+import numpy as np
+from torch.utils.data import get_worker_info
+
+from careamics.config import DataConfig, InferenceConfig
+from careamics.utils.logging import get_logger
+
+from .dataset_utils import reshape_array
+from .read_tiff import read_tiff
+
+logger = get_logger(__name__)
+
+
+def iterate_over_files(
+    data_config: Union[DataConfig, InferenceConfig],
+    data_files: List[Path],
+    target_files: Optional[List[Path]] = None,
+    read_source_func: Callable = read_tiff,
+) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+    """Iterate over data source and yield whole reshaped images.
+
+    Parameters
+    ----------
+    data_config : Union[DataConfig, InferenceConfig]
+        Data configuration.
+    data_files : List[Path]
+        List of data files.
+    target_files : Optional[List[Path]]
+        List of target files, by default None.
+    read_source_func : Optional[Callable]
+        Function to read the source, by default read_tiff.
+
+    Yields
+    ------
+    np.ndarray
+        Image.
+    """
+    # When num_workers > 0, each worker process will have a different copy of the
+    # dataset object
+    # Configuring each copy independently to avoid having duplicate data returned
+    # from the workers
+    worker_info = get_worker_info()
+    worker_id = worker_info.id if worker_info is not None else 0
+    num_workers = worker_info.num_workers if worker_info is not None else 1
+
+    # iterate over the files
+    for i, filename in enumerate(data_files):
+        # retrieve file corresponding to the worker id
+        if i % num_workers == worker_id:
+            try:
+                # read data
+                sample = read_source_func(filename, data_config.axes)
+
+                # reshape array
+                reshaped_sample = reshape_array(sample, data_config.axes)
+
+                # read target, if available
+                if target_files is not None:
+                    if filename.name != target_files[i].name:
+                        raise ValueError(
+                            f"File {filename} does not match target file "
+                            f"{target_files[i]}. Have you passed sorted "
+                            f"arrays?"
+                        )
+
+                    # read target
+                    target = read_source_func(target_files[i], data_config.axes)
+
+                    # reshape target
+                    reshaped_target = reshape_array(target, data_config.axes)
+
+                    yield reshaped_sample, reshaped_target
+                else:
+                    yield reshaped_sample, None
+
+            except Exception as e:
+                logger.error(f"Error reading file {filename}: {e}")
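The `i % num_workers == worker_id` guard implements round-robin sharding: each DataLoader worker keeps every num_workers-th file, so no file is yielded twice across workers. A small sketch of the resulting distribution (the file names are hypothetical):

files = [f"img_{i}.tif" for i in range(7)]
num_workers = 3

for worker_id in range(num_workers):
    # each worker sees only its own slice of the file list
    shard = [f for i, f in enumerate(files) if i % num_workers == worker_id]
    print(worker_id, shard)

# 0 ['img_0.tif', 'img_3.tif', 'img_6.tif']
# 1 ['img_1.tif', 'img_4.tif']
# 2 ['img_2.tif', 'img_5.tif']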
9 changes: 0 additions & 9 deletions src/careamics/dataset/dataset_utils/read_tiff.py
@@ -53,13 +53,4 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     else:
         raise ValueError(f"File {file_path} is not a valid tiff.")
 
-    # check dimensions
-    # TODO or should this really be done here? probably in the LightningDataModule
-    # TODO this should also be centralized somewhere else (validate_dimensions)
-    if len(array.shape) < 2 or len(array.shape) > 6:
-        raise ValueError(
-            f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for"
-            f"file {file_path})."
-        )
-
     return array