Skip to content

Commit

Permalink
Merge branch 'main' into jd/chore/drop_py8
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored Jun 12, 2024
2 parents 1b3e5f2 + 9c829b7 commit aba79d0
Show file tree
Hide file tree
Showing 41 changed files with 1,346 additions and 664 deletions.
4 changes: 2 additions & 2 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ def _create_data_for_bmz(
denormalize = Denormalize(
mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
)
input_patch, _ = denormalize(input_patch)
input_patch = denormalize(input_patch)

elif self.train_datamodule is not None:
input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
Expand All @@ -710,7 +710,7 @@ def _create_data_for_bmz(
denormalize = Denormalize(
mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
)
input_patch, _ = denormalize(input_patch)
input_patch = denormalize(input_patch)
else:
# create a random input array
input_patch = np.random.normal(
Expand Down
77 changes: 22 additions & 55 deletions src/careamics/config/tile_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

from __future__ import annotations

from typing import Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from pydantic import BaseModel, ConfigDict, field_validator


class TileInformation(BaseModel):
Expand All @@ -13,15 +12,17 @@ class TileInformation(BaseModel):
This model is used to represent the information required to stitch back a tile into
a larger image. It is used throughout the prediction pipeline of CAREamics.
Array shape should be (C)(Z)YX, where C and Z are optional dimensions, and must not
contain singleton dimensions.
"""

model_config = ConfigDict(validate_default=True)

array_shape: tuple[int, ...]
tiled: bool = False
last_tile: bool = False
overlap_crop_coords: Optional[tuple[tuple[int, ...], ...]] = Field(default=None)
stitch_coords: Optional[tuple[tuple[int, ...], ...]] = Field(default=None)
overlap_crop_coords: tuple[tuple[int, ...], ...]
stitch_coords: tuple[tuple[int, ...], ...]

@field_validator("array_shape")
@classmethod
Expand All @@ -31,12 +32,12 @@ def no_singleton_dimensions(cls, v: tuple[int, ...]):
Parameters
----------
v : tuple[int, ...]
v : tuple of int
Array shape to check.
Returns
-------
tuple[int, ...]
tuple of int
The array shape if it does not contain singleton dimensions.
Raises
Expand All @@ -48,59 +49,25 @@ def no_singleton_dimensions(cls, v: tuple[int, ...]):
raise ValueError("Array shape must not contain singleton dimensions.")
return v

@field_validator("last_tile")
@classmethod
def only_if_tiled(cls, v: bool, values: ValidationInfo):
"""
Check that the last tile flag is only set if tiling is enabled.
def __eq__(self, other_tile: object):
"""Check if two tile information objects are equal.
Parameters
----------
v : bool
Last tile flag.
values : ValidationInfo
Validation information.
other_tile : object
Tile information object to compare with.
Returns
-------
bool
The last tile flag.
"""
if not values.data["tiled"]:
return False
return v

@field_validator("overlap_crop_coords", "stitch_coords")
@classmethod
def mandatory_if_tiled(
cls, v: Optional[tuple[int, ...]], values: ValidationInfo
) -> Optional[tuple[int, ...]]:
"""
Check that the coordinates are not `None` if tiling is enabled.
The method also return `None` if tiling is not enabled.
Parameters
----------
v : tuple[int, ...] or None
Coordinates to check.
values : ValidationInfo
Validation information.
Returns
-------
tuple[int, ...] or None
The coordinates if tiling is enabled, otherwise `None`.
Raises
------
ValueError
If the coordinates are `None` and tiling is enabled.
Whether the two tile information objects are equal.
"""
if values.data["tiled"]:
if v is None:
raise ValueError("Value must be specified if tiling is enabled.")

return v
else:
return None
if not isinstance(other_tile, TileInformation):
return NotImplemented

return (
self.array_shape == other_tile.array_shape
and self.last_tile == other_tile.last_tile
and self.overlap_crop_coords == other_tile.overlap_crop_coords
and self.stitch_coords == other_tile.stitch_coords
)
2 changes: 1 addition & 1 deletion src/careamics/config/validators/validator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def value_ge_than_8_power_of_2(
If the value is not a power of 2.
"""
if value < 8:
raise ValueError(f"Value must be non-zero positive (got {value}).")
raise ValueError(f"Value must be greater than 8 (got {value}).")

if (value & (value - 1)) != 0:
raise ValueError(f"Value must be a power of 2 (got {value}).")
Expand Down
13 changes: 12 additions & 1 deletion src/careamics/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
"""Dataset module."""

__all__ = ["InMemoryDataset", "PathIterableDataset"]
__all__ = [
"InMemoryDataset",
"InMemoryPredDataset",
"InMemoryTiledPredDataset",
"PathIterableDataset",
"IterableTiledPredDataset",
"IterablePredDataset",
]

from .in_memory_dataset import InMemoryDataset
from .in_memory_pred_dataset import InMemoryPredDataset
from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
from .iterable_dataset import PathIterableDataset
from .iterable_pred_dataset import IterablePredDataset
from .iterable_tiled_pred_dataset import IterableTiledPredDataset
2 changes: 2 additions & 0 deletions src/careamics/dataset/dataset_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
"read_tiff",
"get_read_func",
"read_zarr",
"iterate_over_files",
]


from .dataset_utils import reshape_array
from .file_utils import get_files_size, list_files, validate_source_target_files
from .iterate_over_files import iterate_over_files
from .read_tiff import read_tiff
from .read_utils import get_read_func
from .read_zarr import read_zarr
2 changes: 1 addition & 1 deletion src/careamics/dataset/dataset_utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def list_files(
data_type: Union[str, SupportedData],
extension_filter: str = "",
) -> List[Path]:
"""Create a recursive list of files in `data_path`.
"""List recursively files in `data_path` and return a sorted list.
If `data_path` is a file, its name is validated against the `data_type` using
`fnmatch`, and the method returns `data_path` itself.
Expand Down
83 changes: 83 additions & 0 deletions src/careamics/dataset/dataset_utils/iterate_over_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Function to iterate over files."""

from __future__ import annotations

from pathlib import Path
from typing import Callable, Generator, List, Optional, Tuple, Union

import numpy as np
from torch.utils.data import get_worker_info

from careamics.config import DataConfig, InferenceConfig
from careamics.utils.logging import get_logger

from .dataset_utils import reshape_array
from .read_tiff import read_tiff

logger = get_logger(__name__)


def iterate_over_files(
    data_config: Union[DataConfig, InferenceConfig],
    data_files: List[Path],
    target_files: Optional[List[Path]] = None,
    read_source_func: Callable = read_tiff,
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
    """Iterate over data source and yield whole reshaped images.

    Parameters
    ----------
    data_config : Union[DataConfig, InferenceConfig]
        Data configuration.
    data_files : List[Path]
        List of data files.
    target_files : Optional[List[Path]]
        List of target files, by default None.
    read_source_func : Callable
        Function used to read the source files, by default read_tiff.

    Yields
    ------
    Tuple[np.ndarray, Optional[np.ndarray]]
        Reshaped sample, and the matching reshaped target, or None when no
        target files were provided.
    """
    # When num_workers > 0, each worker process has its own copy of the
    # dataset object; splitting the file list by worker id avoids yielding
    # duplicate data from different workers.
    worker_info = get_worker_info()
    worker_id = worker_info.id if worker_info is not None else 0
    num_workers = worker_info.num_workers if worker_info is not None else 1

    # iterate over the files, each worker handling files i % num_workers == id
    for i, filename in enumerate(data_files):
        if i % num_workers == worker_id:
            try:
                # read data and reshape it to the configured axes order
                sample = read_source_func(filename, data_config.axes)
                reshaped_sample = reshape_array(sample, data_config.axes)

                # read target, if available
                if target_files is not None:
                    # source and target lists must be sorted identically so
                    # that index i pairs the right files together
                    if filename.name != target_files[i].name:
                        raise ValueError(
                            f"File {filename} does not match target file "
                            f"{target_files[i]}. Have you passed sorted "
                            f"arrays?"
                        )

                    # read and reshape the matching target
                    target = read_source_func(target_files[i], data_config.axes)
                    reshaped_target = reshape_array(target, data_config.axes)

                    yield reshaped_sample, reshaped_target
                else:
                    yield reshaped_sample, None

            except Exception as e:
                # best-effort: log and skip unreadable files rather than
                # aborting the whole iteration. NOTE(review): this also
                # swallows the source/target mismatch ValueError raised
                # above — confirm that is intended.
                logger.error(f"Error reading file {filename}: {e}")
9 changes: 0 additions & 9 deletions src/careamics/dataset/dataset_utils/read_tiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,4 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
else:
raise ValueError(f"File {file_path} is not a valid tiff.")

# check dimensions
# TODO or should this really be done here? probably in the LightningDataModule
# TODO this should also be centralized somewhere else (validate_dimensions)
if len(array.shape) < 2 or len(array.shape) > 6:
raise ValueError(
f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for"
f"file {file_path})."
)

return array
Loading

0 comments on commit aba79d0

Please sign in to comment.