Commit b8e382f

fix tests

CatEek committed Jun 12, 2024
1 parent 4bcc390 commit b8e382f
Showing 18 changed files with 162 additions and 256 deletions.
48 changes: 29 additions & 19 deletions src/careamics/careamist.py
@@ -25,7 +25,6 @@
 from careamics.lightning_prediction_datamodule import CAREamicsPredictData
 from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
 from careamics.model_io import export_to_bmz, load_pretrained
-from careamics.transforms import Denormalize
 from careamics.utils import check_path_exists, get_logger

 from .callbacks import HyperParametersCallback
@@ -595,6 +594,15 @@ def predict(
                 "No configuration found. Train a model or load from a "
                 "checkpoint before predicting."
             )
+
+        # Reuse batch size if not provided explicitly
+        if batch_size is None:
+            batch_size = (
+                self.train_datamodule.batch_size
+                if self.train_datamodule
+                else self.cfg.data_config.batch_size
+            )
+
         # create predict config, reuse training config if parameters missing
         prediction_config = create_inference_configuration(
             configuration=self.cfg,
@@ -603,11 +611,7 @@ def predict(
             data_type=data_type,
             axes=axes,
             tta_transforms=tta_transforms,
-            batch_size=(
-                batch_size
-                if batch_size is not None
-                else self.train_datamodule.batch_size
-            ),
+            batch_size=batch_size,
         )

         # remove batch from dataloader parameters (priority given to config)
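
Note: taken together, the two hunks above move the batch-size fallback out of the call site and add a config-level default. This also fixes a latent bug: the old expression dereferenced self.train_datamodule.batch_size unconditionally, which would fail when predicting from a loaded checkpoint with no training datamodule. A minimal sketch of the resulting resolution order (hypothetical free function, not part of the diff):

    def resolve_batch_size(batch_size, train_datamodule, data_config):
        """Resolution order implied by the change: explicit argument,
        then training datamodule, then the stored data config."""
        if batch_size is not None:
            return batch_size  # explicit argument wins
        if train_datamodule is not None:
            return train_datamodule.batch_size  # reuse training batch size
        return data_config.batch_size  # fall back to the configuration
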
@@ -697,29 +701,35 @@ def _create_data_for_bmz(

         # generate images, priority is given to the prediction data module
         if self.pred_datamodule is not None:
-            # unpack a batch, ignore masks or targets
-            input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))
+            # unpack a batch, ignore aux if present
+            if self.pred_datamodule.tiled:
+                input_patch, *_ = next(
+                    iter(self.pred_datamodule.predict_dataloader())
+                )
+            else:
+                input_patch = next(iter(self.pred_datamodule.predict_dataloader()))

             # convert torch.Tensor to numpy
             input_patch = input_patch.numpy()

-            # denormalize
-            denormalize = Denormalize(
-                image_means=self.cfg.data_config.image_mean,
-                image_stds=self.cfg.data_config.image_std,
-            )
-            input_patch = denormalize(input_patch)
+            # # denormalize
+            # denormalize = Denormalize(
+            #     image_means=self.cfg.data_config.image_mean,
+            #     image_stds=self.cfg.data_config.image_std,
+            # )
+            # input_patch = denormalize(input_patch)

         elif self.train_datamodule is not None:
             # unpack a batch, ignore aux if present
             input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
             input_patch = input_patch.numpy()

             # denormalize
-            denormalize = Denormalize(
-                image_means=self.cfg.data_config.image_mean,
-                image_stds=self.cfg.data_config.image_std,
-            )
-            input_patch = denormalize(input_patch)
+            # denormalize = Denormalize(
+            #     image_means=self.cfg.data_config.image_mean,
+            #     image_stds=self.cfg.data_config.image_std,
+            # )
+            # input_patch = denormalize(input_patch)
         else:
             # create a random input array
             input_patch = np.random.normal(
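
Note: the new branch on pred_datamodule.tiled exists because, judging from this diff, tiled prediction dataloaders yield (tile, tile_info) tuples while untiled ones yield bare tensors. The Denormalize calls are commented out rather than deleted, matching the removed import at the top of the file, which suggests the BMZ sample data is now kept in normalized space. A hedged sketch of the unpacking logic in isolation (helper name and loader behavior are assumptions):

    def first_input_batch(pred_datamodule):
        """Fetch one input batch from a prediction datamodule.

        Assumes tiled loaders yield (tensor, tile_info, ...) tuples and
        untiled loaders yield bare tensors, as the branch above implies.
        """
        batch = next(iter(pred_datamodule.predict_dataloader()))
        if pred_datamodule.tiled:
            input_patch, *_ = batch  # discard tile information
        else:
            input_patch = batch
        return input_patch.numpy()  # torch.Tensor -> np.ndarray
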
9 changes: 6 additions & 3 deletions src/careamics/dataset/in_memory_pred_dataset.py
@@ -45,14 +45,17 @@ def __init__(
         self.pred_config = prediction_config
         self.input_array = inputs
         self.axes = self.pred_config.axes
-        self.mean, self.std = self.pred_config.mean, self.pred_config.std
+        self.image_means = self.pred_config.image_mean
+        self.image_stds = self.pred_config.image_std

         # Reshape data
         self.data = reshape_array(self.input_array, self.axes)

         # get transforms
         self.patch_transform = Compose(
-            transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
+            transform_list=[
+                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
+            ],
         )

     def __len__(self) -> int:
@@ -80,6 +83,6 @@ def __getitem__(self, index: int) -> np.ndarray:
         np.ndarray
             Transformed patch.
         """
-        transformed_patch, _ = self.patch_transform(patch=self.data[[index]])
+        transformed_patch, _ = self.patch_transform(patch=self.data[index])

         return transformed_patch
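
Note: the indexing change is subtle but changes array rank: data[[index]] (a one-element fancy index) keeps the leading sample axis, while data[index] drops it, so the transform now receives one fewer dimension. In NumPy terms:

    import numpy as np

    data = np.zeros((10, 1, 64, 64))  # e.g. (S, C, Y, X)

    kept = data[[5]]   # fancy indexing keeps the sample axis
    dropped = data[5]  # integer indexing drops it

    assert kept.shape == (1, 1, 64, 64)
    assert dropped.shape == (1, 64, 64)
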
9 changes: 5 additions & 4 deletions src/careamics/dataset/in_memory_tiled_pred_dataset.py
@@ -60,16 +60,17 @@ def __init__(
         self.axes = self.pred_config.axes
         self.tile_size = prediction_config.tile_size
         self.tile_overlap = prediction_config.tile_overlap
-        self.mean = self.pred_config.mean
-        self.std = self.pred_config.std
+        self.image_means = self.pred_config.image_mean
+        self.image_stds = self.pred_config.image_std

         # Generate patches
         self.data = self._prepare_tiles()
-        self.mean, self.std = self.pred_config.mean, self.pred_config.std

         # get transforms
         self.patch_transform = Compose(
-            transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
+            transform_list=[
+                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
+            ],
         )

     def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
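
Note: both in-memory datasets now pass list-valued, per-channel statistics (image_means/image_stds) instead of the former scalar mean/std, and this hunk also drops a redundant duplicate assignment. A sketch of what a channel-wise normalization with such parameters typically computes — the actual NormalizeModel implementation may differ in details such as epsilon handling or axis layout:

    import numpy as np

    def normalize_channels(patch, image_means, image_stds, eps=1e-8):
        """Apply (x - mean) / std per channel; assumes patch is (C, ...)."""
        out = np.empty_like(patch, dtype=np.float32)
        for c, (mean, std) in enumerate(zip(image_means, image_stds)):
            out[c] = (patch[c] - mean) / (std + eps)
        return out

    patch = np.random.rand(2, 64, 64)  # two channels
    normed = normalize_channels(patch, [0.4, 0.6], [0.1, 0.2])
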
142 changes: 1 addition & 141 deletions src/careamics/dataset/iterable_dataset.py
@@ -130,14 +130,9 @@ def _calculate_mean_and_std(self) -> PatchedOutput:
         image_means, image_stds, target_means, target_stds = 0, 0, 0, 0
         num_samples = 0

-        for sample, target in _iterate_over_files(
+        for sample, target in iterate_over_files(
             self.data_config, self.data_files, self.target_files, self.read_source_func
         ):
-            sample = reshape_array(sample, self.data_config.axes)
-            target = (
-                None if target is None else reshape_array(target, self.data_config.axes)
-            )
-
             sample_mean, sample_std = compute_normalization_stats(sample)
             image_means += sample_mean
             image_stds += sample_std
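
Note: reshaping now presumably happens inside the renamed iterate_over_files helper, leaving _calculate_mean_and_std to accumulate statistics only. For reference, the loop averages per-file statistics — a sketch with an assumed compute_normalization_stats that reduces over every axis except the channel axis:

    import numpy as np

    def compute_normalization_stats(sample):
        """Per-channel mean/std; assumes sample is shaped (S, C, ...)."""
        axes = (0, *range(2, sample.ndim))
        return sample.mean(axis=axes), sample.std(axis=axes)

    image_means, image_stds, num_samples = 0.0, 0.0, 0
    for sample in (np.random.rand(1, 2, 8, 8) for _ in range(3)):
        mean, std = compute_normalization_stats(sample)
        image_means += mean
        image_stds += std
        num_samples += 1
    image_means /= num_samples
    image_stds /= num_samples

Averaging per-file standard deviations only approximates the pooled standard deviation, which is presumably acceptable for normalization purposes.
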
@@ -296,138 +291,3 @@ def split_dataset(
         dataset.target_files = val_target_files

         return dataset
-
-
-class IterablePredictionDataset(IterableDataset):
-    """
-    Prediction dataset.
-
-    Parameters
-    ----------
-    prediction_config : InferenceConfig
-        Inference configuration.
-    src_files : List[Path]
-        List of data files.
-    read_source_func : Callable, optional
-        Read source function for custom types, by default read_tiff.
-    **kwargs : Any
-        Additional keyword arguments, unused.
-
-    Attributes
-    ----------
-    data_path : Union[str, Path]
-        Path to the data, must be a directory.
-    axes : str
-        Description of axes in format STCZYX.
-    mean : Optional[float], optional
-        Expected mean of the dataset, by default None.
-    std : Optional[float], optional
-        Expected standard deviation of the dataset, by default None.
-    patch_transform : Optional[Callable], optional
-        Patch transform callable, by default None.
-    """
-
-    def __init__(
-        self,
-        prediction_config: InferenceConfig,
-        src_files: List[Path],
-        read_source_func: Callable = read_tiff,
-        **kwargs: Any,
-    ) -> None:
-        """Constructor.
-
-        Parameters
-        ----------
-        prediction_config : InferenceConfig
-            Inference configuration.
-        src_files : List[Path]
-            List of data files.
-        read_source_func : Callable, optional
-            Read source function for custom types, by default read_tiff.
-        **kwargs : Any
-            Additional keyword arguments, unused.
-
-        Raises
-        ------
-        ValueError
-            If mean and std are not provided in the inference configuration.
-        """
-        self.prediction_config = prediction_config
-        self.data_files = src_files
-        self.axes = prediction_config.axes
-        self.tile_size = self.prediction_config.tile_size
-        self.tile_overlap = self.prediction_config.tile_overlap
-        self.read_source_func = read_source_func
-        self.image_means = self.prediction_config.image_mean
-        self.image_stds = self.prediction_config.image_std
-
-        # tile only if both tile size and overlaps are provided
-        self.tile = self.tile_size is not None and self.tile_overlap is not None
-
-        # check mean and std and create normalize transform
-        if (
-            self.prediction_config.image_mean is None
-            or self.prediction_config.image_std is None
-        ):
-            raise ValueError("Mean and std must be provided for prediction.")
-        else:
-            self.mean = self.prediction_config.image_mean
-            self.std = self.prediction_config.image_std
-
-        # instantiate normalize transform
-        self.patch_transform = Compose(
-            transform_list=[
-                NormalizeModel(
-                    image_means=prediction_config.image_mean,
-                    image_stds=prediction_config.image_std,
-                )
-            ],
-        )
-
-    def __iter__(
-        self,
-    ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
-        """
-        Iterate over data source and yield single patch.
-
-        Yields
-        ------
-        np.ndarray
-            Single patch.
-        """
-        assert (
-            self.image_means is not None and self.image_stds is not None
-        ), "Mean and std must be provided"
-
-        for sample, _ in _iterate_over_files(
-            self.prediction_config,
-            self.data_files,
-            read_source_func=self.read_source_func,
-        ):
-            # reshape array
-            reshaped_sample = reshape_array(sample, self.axes)
-
-            if (
-                self.tile
-                and self.tile_size is not None
-                and self.tile_overlap is not None
-            ):
-                # generate patches, return a generator
-                patch_gen = extract_tiles(
-                    arr=reshaped_sample,
-                    tile_size=self.tile_size,
-                    overlaps=self.tile_overlap,
-                )
-            else:
-                # just wrap the sample in a generator with default tiling info
-                array_shape = reshaped_sample.squeeze().shape
-                patch_gen = (
-                    (reshaped_sample, TileInformation(array_shape=array_shape))
-                    for _ in range(1)
-                )
-
-            # apply transform to patches
-            for patch_array, tile_info in patch_gen:
-                transformed_patch, _ = self.patch_transform(patch=patch_array)
-
-                yield transformed_patch, tile_info
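
Note: the removed IterablePredictionDataset bundled tiled and untiled prediction into one class; the two dedicated modules below now split that responsibility. Its no-tiling branch used a trick worth keeping in mind: wrapping the full image in a one-element generator so both code paths yield (patch, TileInformation) pairs. A standalone sketch of that pattern, with a stand-in for careamics' TileInformation:

    from dataclasses import dataclass
    from typing import Generator, Tuple

    import numpy as np

    @dataclass
    class TileInfo:  # stand-in for careamics' TileInformation
        array_shape: Tuple[int, ...]

    def single_tile(
        sample: np.ndarray,
    ) -> Generator[Tuple[np.ndarray, TileInfo], None, None]:
        """Yield a whole image as one 'tile', mirroring the no-tiling branch."""
        yield sample, TileInfo(array_shape=sample.squeeze().shape)
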
32 changes: 18 additions & 14 deletions src/careamics/dataset/iterable_pred_dataset.py
@@ -74,20 +74,24 @@ def __init__(
         self.read_source_func = read_source_func

         # check mean and std and create normalize transform
-        if self.prediction_config.mean is None or self.prediction_config.std is None:
+        if (
+            self.prediction_config.image_mean is None
+            or self.prediction_config.image_std is None
+        ):
             raise ValueError("Mean and std must be provided for prediction.")
         else:
-            self.mean = self.prediction_config.mean
-            self.std = self.prediction_config.std
-
-            # instantiate normalize transform
-            self.patch_transform = Compose(
-                transform_list=[
-                    NormalizeModel(
-                        mean=prediction_config.mean, std=prediction_config.std
-                    )
-                ],
-            )
+            self.image_means = self.prediction_config.image_mean
+            self.image_stds = self.prediction_config.image_std
+
+            # instantiate normalize transform
+            self.patch_transform = Compose(
+                transform_list=[
+                    NormalizeModel(
+                        image_means=self.image_means,
+                        image_stds=self.image_stds,
+                    )
+                ],
+            )

     def __iter__(
         self,
@@ -101,7 +105,7 @@ def __iter__(
             Single patch.
         """
         assert (
-            self.mean is not None and self.std is not None
+            self.image_means is not None and self.image_stds is not None
         ), "Mean and std must be provided"

         for sample, _ in iterate_over_files(
@@ -112,6 +116,6 @@ def __iter__(
             # sample has S dimension
             for i in range(sample.shape[0]):

-                transformed_sample, _ = self.patch_transform(patch=sample[[i]])
+                transformed_sample, _ = self.patch_transform(patch=sample[i])

                 yield transformed_sample
14 changes: 9 additions & 5 deletions src/careamics/dataset/iterable_tiled_pred_dataset.py
@@ -86,17 +86,21 @@ def __init__(
         self.read_source_func = read_source_func

         # check mean and std and create normalize transform
-        if self.prediction_config.mean is None or self.prediction_config.std is None:
+        if (
+            self.prediction_config.image_mean is None
+            or self.prediction_config.image_std is None
+        ):
             raise ValueError("Mean and std must be provided for prediction.")
         else:
-            self.mean = self.prediction_config.mean
-            self.std = self.prediction_config.std
+            self.image_means = self.prediction_config.image_mean
+            self.image_stds = self.prediction_config.image_std

         # instantiate normalize transform
         self.patch_transform = Compose(
             transform_list=[
                 NormalizeModel(
-                    mean=prediction_config.mean, std=prediction_config.std
+                    image_means=self.image_means,
+                    image_stds=self.image_stds,
                 )
             ],
         )
@@ -113,7 +117,7 @@ def __iter__(
             Single tile.
         """
         assert (
-            self.mean is not None and self.std is not None
+            self.image_means is not None and self.image_stds is not None
         ), "Mean and std must be provided"

         for sample, _ in iterate_over_files(
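
Note: all four dataset modules repeat the same guard before building their normalize transform. A consolidated sketch of that shared pattern (hypothetical helper, not in the diff — the repetition might be worth factoring out):

    from typing import List, Optional, Tuple

    def require_stats(
        image_mean: Optional[List[float]], image_std: Optional[List[float]]
    ) -> Tuple[List[float], List[float]]:
        """Validate per-channel statistics before prediction."""
        if image_mean is None or image_std is None:
            raise ValueError("Mean and std must be provided for prediction.")
        return list(image_mean), list(image_std)
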