Skip to content

Commit

Permalink
(fix): fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Jun 7, 2024
1 parent f2d8933 commit 0f9e62e
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 15 deletions.
6 changes: 3 additions & 3 deletions src/careamics/dataset/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,14 @@ def __init__(
# get transforms
self.patch_transform = Compose(transform_list=data_config.transforms)

def _calculate_mean_and_std(self) -> Tuple[float, float]:
def _calculate_mean_and_std(self) -> PatchedOutput:
"""
Calculate mean and std of the dataset.
Returns
-------
Tuple[float, float]
Tuple containing mean and standard deviation.
PatchedOutput
Data class containing the image statistics.
"""
image_means, image_stds, target_means, target_stds = 0, 0, 0, 0
num_samples = 0
Expand Down
11 changes: 5 additions & 6 deletions src/careamics/prediction/stitch_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@
from typing import List

import numpy as np
import torch


def stitch_prediction(
tiles: List[torch.Tensor],
stitching_data: List[List[torch.Tensor]],
) -> torch.Tensor:
tiles: List[np.ndarray],
stitching_data: List[List[np.ndarray]],
) -> np.ndarray:
"""
Stitch tiles back together to form a full image.
Parameters
----------
tiles : List[torch.Tensor]
tiles : List[np.ndarray]
Cropped tiles and their respective stitching coordinates.
stitching_data : List
stitching_coords : List[List[np.ndarray]]
List of information and coordinates obtained from
`dataset.tiled_patching.extract_tiles`.
Expand Down
6 changes: 3 additions & 3 deletions tests/prediction/test_stitch_prediction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from torch import from_numpy, tensor
from torch import tensor

from careamics.dataset.patching.tiled_patching import extract_tiles
from careamics.prediction.stitch_prediction import stitch_prediction
Expand All @@ -26,7 +26,7 @@ def test_stitch_prediction(ordered_array, input_shape, tile_size, overlaps):

# Assemble all tiles as it is done during the prediction stage
for tile_data, tile_info in tile_generator:
tiles.append(from_numpy(tile_data)) # need to convert to torch.Tensor
tiles.append(tile_data)
stitching_data.append(
( # this is way too wacky
[tensor(i) for i in input_shape], # need to convert to torch.Tensor
Expand All @@ -38,4 +38,4 @@ def test_stitch_prediction(ordered_array, input_shape, tile_size, overlaps):
# compute stitching coordinates, it returns a torch.Tensor
result = stitch_prediction(tiles, stitching_data)

assert (result.numpy() == arr).all()
assert (result == arr).all()
6 changes: 3 additions & 3 deletions tests/test_careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def random_array(shape: Tuple[int, ...]):
"""Return a random array with values between 0 and 255."""
return (255 * (1 + np.random.rand(*shape)) / 2).astype(np.float32)
return (np.random.randint(0, 255, shape)).astype(np.float32)


def test_no_parameters():
Expand Down Expand Up @@ -653,7 +653,7 @@ def test_data_for_bmz_random(tmp_path, minimum_configuration):
config.data_config.data_type = SupportedData.ARRAY.value
config.data_config.patch_size = (8, 8)
config.data_config.set_mean_and_std(
image_mean=example_data.mean(), image_std=example_data.std()
image_mean=[example_data.mean()], image_std=[example_data.std()]
)

# instantiate CAREamist
Expand Down Expand Up @@ -684,7 +684,7 @@ def test_data_for_bmz_with_array(tmp_path, minimum_configuration):
config.data_config.data_type = SupportedData.ARRAY.value
config.data_config.patch_size = (8, 8)
config.data_config.set_mean_and_std(
image_mean=example_data.mean(), image_std=example_data.std()
image_mean=[example_data.mean()], image_std=[example_data.std()]
)

# instantiate CAREamist
Expand Down

0 comments on commit 0f9e62e

Please sign in to comment.