(refac): split the datasets into their own modules
jdeschamps committed Jun 5, 2024
1 parent 2c4f655 commit 5770d7c
Showing 10 changed files with 611 additions and 541 deletions.
18 changes: 7 additions & 11 deletions src/careamics/dataset/__init__.py
@@ -2,20 +2,16 @@

 __all__ = [
     "InMemoryDataset",
-    "InMemoryPredictionDataset",
+    "InMemoryPredDataset",
     "InMemoryTiledPredictionDataset",
     "PathIterableDataset",
     "IterableTiledPredictionDataset",
     "IterablePredictionDataset",
 ]

-from .in_memory_dataset import (
-    InMemoryDataset,
-    InMemoryPredictionDataset,
-    InMemoryTiledPredictionDataset,
-)
-from .iterable_dataset import (
-    IterablePredictionDataset,
-    IterableTiledPredictionDataset,
-    PathIterableDataset,
-)
+from .in_memory_dataset import InMemoryDataset
+from .in_memory_pred_dataset import InMemoryPredDataset
+from .in_memory_tiled_pred_dataset import InMemoryTiledPredictionDataset
+from .iterable_dataset import PathIterableDataset
+from .iterable_pred_dataset import IterablePredictionDataset
+from .iterable_tiled_pred_dataset import IterableTiledPredictionDataset
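
For downstream code, the net effect of this hunk is one rename plus re-exports: everything that was importable from the package root still is. A quick sketch of imports that keep working after the split (names taken directly from the new __all__ above):

# Package-root imports are unchanged by the refactoring, since
# src/careamics/dataset/__init__.py re-exports every split module.
from careamics.dataset import (
    InMemoryDataset,
    InMemoryPredDataset,
    InMemoryTiledPredictionDataset,
    IterablePredictionDataset,
    IterableTiledPredictionDataset,
    PathIterableDataset,
)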
2 changes: 2 additions & 0 deletions src/careamics/dataset/dataset_utils/__init__.py
@@ -8,11 +8,13 @@
"read_tiff",
"get_read_func",
"read_zarr",
"iterate_over_files",
]


from .dataset_utils import reshape_array
from .file_utils import get_files_size, list_files, validate_source_target_files
from .iterate_over_files import iterate_over_files
from .read_tiff import read_tiff
from .read_utils import get_read_func
from .read_zarr import read_zarr
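
With the re-export in place, the new helper is importable from the dataset_utils subpackage. A minimal smoke-test sketch; the generator defined in the new file below only reads data_config.axes, so a duck-typed stand-in for DataConfig (an assumption for illustration, not a supported API) is enough here:

from pathlib import Path
from types import SimpleNamespace

from careamics.dataset.dataset_utils import iterate_over_files

# Stand-in for a DataConfig: the generator only reads `.axes`,
# so a bare namespace works for a quick illustration.
config = SimpleNamespace(axes="YX")
files = sorted(Path("data").glob("*.tiff"))

for sample, target in iterate_over_files(config, files):
    # `target` is None because no target files were passed
    print(sample.shape, target)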
83 changes: 83 additions & 0 deletions src/careamics/dataset/dataset_utils/iterate_over_files.py
@@ -0,0 +1,83 @@
"""Function to iterate over files."""

from __future__ import annotations

from pathlib import Path
from typing import Callable, Generator, List, Optional, Tuple, Union

import numpy as np
from torch.utils.data import get_worker_info

from careamics.config import DataConfig, InferenceConfig
from careamics.utils.logging import get_logger

from .dataset_utils import reshape_array
from .read_tiff import read_tiff

logger = get_logger(__name__)


def iterate_over_files(
    data_config: Union[DataConfig, InferenceConfig],
    data_files: List[Path],
    target_files: Optional[List[Path]] = None,
    read_source_func: Callable = read_tiff,
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
"""Iterate over data source and yield whole reshaped images.
Parameters
----------
data_config : Union[DataConfig, InferenceConfig]
Data configuration.
data_files : List[Path]
List of data files.
target_files : Optional[List[Path]]
List of target files, by default None.
read_source_func : Optional[Callable]
Function to read the source, by default read_tiff.
Yields
------
np.ndarray
Image.
"""
    # When num_workers > 0, each worker process receives its own copy of the
    # dataset object. Each copy is configured independently so that the
    # workers do not return duplicate data.
    worker_info = get_worker_info()
    worker_id = worker_info.id if worker_info is not None else 0
    num_workers = worker_info.num_workers if worker_info is not None else 1

    # iterate over the files
    for i, filename in enumerate(data_files):
        # retrieve the files corresponding to this worker id
        if i % num_workers == worker_id:
            try:
                # read data
                sample = read_source_func(filename, data_config.axes)

                # reshape array
                reshaped_sample = reshape_array(sample, data_config.axes)

                # read target, if available
                if target_files is not None:
                    if filename.name != target_files[i].name:
                        raise ValueError(
                            f"File {filename} does not match target file "
                            f"{target_files[i]}. Have you passed sorted "
                            f"lists of files?"
                        )

                    # read target
                    target = read_source_func(target_files[i], data_config.axes)

                    # reshape target
                    reshaped_target = reshape_array(target, data_config.axes)

                    yield reshaped_sample, reshaped_target
                else:
                    yield reshaped_sample, None

            except Exception as e:
                logger.error(f"Error reading file {filename}: {e}")
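
The worker-aware indexing above is what makes this generator safe to drive from a torch DataLoader with num_workers > 0. A self-contained sketch of the same round-robin split (the class name and file list are hypothetical, for illustration only):

from pathlib import Path
from typing import Iterator, List

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class FileShardingDataset(IterableDataset):
    """Toy dataset reproducing the round-robin file split used above."""

    def __init__(self, files: List[Path]) -> None:
        self.files = files

    def __iter__(self) -> Iterator[Path]:
        # In the main process (num_workers=0) get_worker_info() is None,
        # so a single "worker" sees every file.
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1

        for i, f in enumerate(self.files):
            # file i is yielded only by worker i % num_workers,
            # so no file is returned twice across workers
            if i % num_workers == worker_id:
                yield f


if __name__ == "__main__":
    files = [Path(f"img_{i}.tiff") for i in range(6)]
    loader = DataLoader(FileShardingDataset(files), num_workers=2, batch_size=None)
    print(sorted(str(p) for p in loader))  # every file appears exactly once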
229 changes: 2 additions & 227 deletions src/careamics/dataset/in_memory_dataset.py
@@ -11,18 +11,15 @@

 from careamics.transforms import Compose

-from ..config import DataConfig, InferenceConfig
-from ..config.tile_information import TileInformation
-from ..config.transformations import NormalizeModel
+from ..config import DataConfig
 from ..utils.logging import get_logger
-from .dataset_utils import read_tiff, reshape_array
+from .dataset_utils import read_tiff
 from .patching.patching import (
     prepare_patches_supervised,
     prepare_patches_supervised_array,
     prepare_patches_unsupervised,
     prepare_patches_unsupervised_array,
 )
-from .patching.tiled_patching import extract_tiles

 logger = get_logger(__name__)

@@ -270,225 +267,3 @@ def split_dataset(
        dataset.patch_targets = val_targets

        return dataset


class InMemoryPredictionDataset(Dataset):
    """Simple prediction dataset returning images along the sample axis.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : np.ndarray
        Input data.
    data_target : Optional[np.ndarray], optional
        Target data, by default None.
    read_source_func : Optional[Callable], optional
        Read source function for custom types, by default read_tiff.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: np.ndarray,
        data_target: Optional[np.ndarray] = None,
        read_source_func: Optional[Callable] = read_tiff,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration.
        inputs : np.ndarray
            Input data.
        data_target : Optional[np.ndarray], optional
            Target data, by default None.
        read_source_func : Optional[Callable], optional
            Read source function for custom types, by default read_tiff.

        Raises
        ------
        ValueError
            If data_path is not a directory.
        """
        self.pred_config = prediction_config
        self.input_array = inputs
        self.axes = self.pred_config.axes
        self.tile_size = self.pred_config.tile_size
        self.tile_overlap = self.pred_config.tile_overlap
        self.mean = self.pred_config.mean
        self.std = self.pred_config.std
        self.data_target = data_target
        self.mean, self.std = self.pred_config.mean, self.pred_config.std

        # tiling only if both tile size and overlap are provided
        self.tiling = self.tile_size is not None and self.tile_overlap is not None

        # read function
        self.read_source_func = read_source_func

        # Reshape data
        self.data = reshape_array(self.input_array, self.axes)

        # get transforms
        self.patch_transform = Compose(
            transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
        )

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> np.ndarray:
        """
        Return the patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        np.ndarray
            Transformed patch.
        """
        transformed_patch, _ = self.patch_transform(patch=self.data[[index]])

        return transformed_patch


class InMemoryTiledPredictionDataset(Dataset):
    """Prediction dataset storing data in memory and returning tiles of each image.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : np.ndarray
        Input data.
    data_target : Optional[np.ndarray], optional
        Target data, by default None.
    read_source_func : Optional[Callable], optional
        Read source function for custom types, by default read_tiff.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: np.ndarray,
        data_target: Optional[np.ndarray] = None,
        read_source_func: Optional[Callable] = read_tiff,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration.
        inputs : np.ndarray
            Input data.
        data_target : Optional[np.ndarray], optional
            Target data, by default None.
        read_source_func : Optional[Callable], optional
            Read source function for custom types, by default read_tiff.

        Raises
        ------
        ValueError
            If data_path is not a directory.
        """
        if (
            prediction_config.tile_size is None
            or prediction_config.tile_overlap is None
        ):
            raise ValueError(
                "Tile size and overlap must be provided to use the tiled prediction "
                "dataset."
            )

        self.pred_config = prediction_config
        self.input_array = inputs
        self.axes = self.pred_config.axes
        self.tile_size = prediction_config.tile_size
        self.tile_overlap = prediction_config.tile_overlap
        self.mean = self.pred_config.mean
        self.std = self.pred_config.std
        self.data_target = data_target

        # read function
        self.read_source_func = read_source_func

        # Generate patches
        self.data = self._prepare_tiles()
        self.mean, self.std = self.pred_config.mean, self.pred_config.std

        # get transforms
        self.patch_transform = Compose(
            transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
        )

    def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
        """
        Iterate over data source and create an array of patches.

        Returns
        -------
        List[Tuple[np.ndarray, TileInformation]]
            List of tiles.
        """
        # reshape array
        reshaped_sample = reshape_array(self.input_array, self.axes)

        # generate patches, which returns a generator
        patch_generator = extract_tiles(
            arr=reshaped_sample,
            tile_size=self.tile_size,
            overlaps=self.tile_overlap,
        )
        patches_list = list(patch_generator)

        if len(patches_list) == 0:
            raise ValueError("No tiles generated.")

        return patches_list

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]:
        """
        Return the patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        Tuple[np.ndarray, TileInformation]
            Transformed patch.
        """
        tile_array, tile_info = self.data[index]

        # Apply transforms
        transformed_tile, _ = self.patch_transform(patch=tile_array)

        return transformed_tile, tile_info
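
These two classes are removed here, not deleted from the project: per the new imports in src/careamics/dataset/__init__.py at the top of this commit, each moves into its own module, and InMemoryPredictionDataset is renamed along the way. Assuming the classes are otherwise unchanged by the move, only deep imports from the old location need updating:

# before this commit
from careamics.dataset.in_memory_dataset import (
    InMemoryPredictionDataset,
    InMemoryTiledPredictionDataset,
)

# after this commit
from careamics.dataset.in_memory_pred_dataset import InMemoryPredDataset
from careamics.dataset.in_memory_tiled_pred_dataset import InMemoryTiledPredictionDataset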