Skip to content

Commit

Permalink
(feat): add see to random patching
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Jun 7, 2024
1 parent eb6f44f commit 456f4cf
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 66 deletions.
13 changes: 9 additions & 4 deletions src/careamics/dataset/patching/random_patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def extract_patches_random(
arr: np.ndarray,
patch_size: Union[List[int], Tuple[int, ...]],
target: Optional[np.ndarray] = None,
seed: Optional[int] = None,
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
"""
Generate patches from an array in a random manner.
Expand All @@ -34,12 +35,16 @@ def extract_patches_random(
Patch sizes in each dimension.
target : Optional[np.ndarray], optional
Target array, by default None.
seed : Optional[int], optional
Random seed, by default None.
Yields
------
Generator[np.ndarray, None, None]
Generator of patches.
"""
rng = np.random.default_rng(seed=seed)

is_3d_patch = len(patch_size) == 3

# patches sanity check
Expand All @@ -48,9 +53,6 @@ def extract_patches_random(
# Update patch size to encompass S and C dimensions
patch_size = [1, arr.shape[1], *patch_size]

# random generator
rng = np.random.default_rng()

# iterate over the number of samples (S or T)
for sample_idx in range(arr.shape[0]):
# get sample array
Expand Down Expand Up @@ -113,6 +115,7 @@ def extract_patches_random_from_chunks(
patch_size: Union[List[int], Tuple[int, ...]],
chunk_size: Union[List[int], Tuple[int, ...]],
chunk_limit: Optional[int] = None,
seed: Optional[int] = None,
) -> Generator[np.ndarray, None, None]:
"""
Generate patches from an array in a random manner.
Expand All @@ -130,6 +133,8 @@ def extract_patches_random_from_chunks(
Chunk sizes to load from the.
chunk_limit : Optional[int], optional
Number of chunks to load, by default None.
seed : Optional[int], optional
Random seed, by default None.
Yields
------
Expand All @@ -141,7 +146,7 @@ def extract_patches_random_from_chunks(
# Patches sanity check
validate_patch_dimensions(arr, patch_size, is_3d_patch)

rng = np.random.default_rng()
rng = np.random.default_rng(seed=seed)
num_chunks = chunk_limit if chunk_limit else np.prod(arr._cdata_shape)

# Iterate over num chunks in the array
Expand Down
76 changes: 17 additions & 59 deletions tests/dataset/patching/test_random_patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,22 @@ def test_random_patching_unsupervised(ordered_array, shape, patch_size):
Since extract patches is called on already shaped array, dimensions S and C are
present.
"""
np.random.seed(42)

# create array
array = ordered_array(shape)
is_3D = len(patch_size) == 3
top_left = []

for _ in range(3):
patch_generator = extract_patches_random(array, patch_size=patch_size)
# minimum number of unique patches to validate randomness
if is_3D:
min_unique_patches = 0.7 * np.prod(shape[-3:]) / np.prod(patch_size)
else:
min_unique_patches = 0.7 * np.prod(shape[-2:]) / np.prod(patch_size)

top_left = []
seeds = [24, 42, 56]
for i in range(3):
patch_generator = extract_patches_random(
array, patch_size=patch_size, seed=seeds[i]
)

# get all patches and targets
patches = [patch for patch, _ in patch_generator]
Expand All @@ -49,57 +56,8 @@ def test_random_patching_unsupervised(ordered_array, shape, patch_size):

top_left.append(np.array(ind))

# check randomness
coords = np.array(top_left).squeeze()
assert coords.min() == 0
assert coords.max() == max(array.shape) - max(patch_size)
assert len(np.unique(coords, axis=0)) >= 0.7 * np.prod(shape) / np.prod(patch_size)


# @pytest.mark.parametrize(
# "patch_size",
# [
# (2, 2),
# (4, 2),
# (4, 8),
# (8, 8),
# ],
# )
# def test_extract_patches_random_2d(array_2D, patch_size):
# """Test extracting patches randomly in 2D."""
# check_extract_patches_random(array_2D, "SYX", patch_size)


# @pytest.mark.parametrize(
# "patch_size",
# [
# (2, 2),
# (4, 2),
# (4, 8),
# (8, 8),
# ],
# )
# def test_extract_patches_random_supervised_2d(array_2D, patch_size):
# """Test extracting patches randomly in 2D."""
# check_extract_patches_random(
# array_2D,
# "SYX",
# patch_size,
# target=array_2D
# )


# @pytest.mark.parametrize(
# "patch_size",
# [
# (2, 2, 4),
# (4, 2, 2),
# (2, 8, 4),
# (4, 8, 8),
# ],
# )
# def test_extract_patches_random_3d(array_3D, patch_size):
# """Test extracting patches randomly in 3D.

# The 3D array is a fixture of shape (1, 8, 16, 16)."""
# check_extract_patches_random(array_3D, "SZYX", patch_size)
# check randomness
coords = np.array(top_left).squeeze()
assert coords.min() == 0
assert coords.max() == max(array.shape) - max(patch_size)
assert len(np.unique(coords, axis=0)) >= min_unique_patches
18 changes: 15 additions & 3 deletions tests/transforms/test_pixel_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def test_get_stratified_coords(mask_pixel_perc, shape, num_iterations):
Ensure that the array of coordinates is randomly distributed across the
image and that most pixels get selected.
"""
rng = np.random.default_rng(42)

# Define the dummy array
array = np.zeros(shape)

Expand All @@ -28,7 +30,7 @@ def test_get_stratified_coords(mask_pixel_perc, shape, num_iterations):
# biased towards any particular region.
for _ in range(num_iterations):
# Get the coordinates of the pixels to be masked
coords = _get_stratified_coords(mask_pixel_perc, shape)
coords = _get_stratified_coords(mask_pixel_perc, shape, rng)

# Check that there is at least one coordinate choosen
assert len(coords) > 0
Expand Down Expand Up @@ -59,12 +61,14 @@ def test_uniform_manipulate(ordered_array, shape):
Ensures that the mask corresponds to the manipulated pixels, and that the
manipulated pixels have a value taken from a ROI surrounding them.
"""
rng = np.random.default_rng(42)

# create the array
patch = ordered_array(shape)

# manipulate the array
transform_patch, mask = uniform_manipulate(
patch, mask_pixel_percentage=10, subpatch_size=5
patch, mask_pixel_percentage=10, subpatch_size=5, rng=rng
)

# find pixels that have different values between patch and transformed patch
Expand Down Expand Up @@ -104,12 +108,14 @@ def test_median_manipulate(ordered_array, shape):
Ensures that the mask corresponds to the manipulated pixels, and that the
manipulated pixels have a value taken from a ROI surrounding them.
"""
rng = np.random.default_rng(42)

# create the array
patch = ordered_array(shape).astype(np.float32)

# manipulate the array
transform_patch, mask = median_manipulate(
patch, subpatch_size=5, mask_pixel_percentage=10
patch, subpatch_size=5, mask_pixel_percentage=10, rng=rng
)

# find pixels that have different values between patch and transformed patch
Expand Down Expand Up @@ -152,6 +158,8 @@ def test_apply_struct_mask(coords, struct_axis, struct_span):
Ensures that the mask corresponds to the manipulated pixels, and that the
manipulated pixels have a value taken from a ROI surrounding them.
"""
rng = np.random.default_rng(42)

struct_params = StructMaskParameters(axis=struct_axis, span=struct_span)

# create array
Expand All @@ -172,6 +180,7 @@ def test_apply_struct_mask(coords, struct_axis, struct_span):
patch,
coords=coords,
struct_params=struct_params,
rng=rng,
)
changed_values = patch[np.where(original_patch != transform_patch)]

Expand Down Expand Up @@ -215,6 +224,8 @@ def test_apply_struct_mask_3D(coords, struct_axis, struct_span):
Ensures that the mask corresponds to the manipulated pixels, and that the
manipulated pixels have a value taken from a ROI surrounding them.
"""
rng = np.random.default_rng(42)

struct_params = StructMaskParameters(axis=struct_axis, span=struct_span)

# create array
Expand All @@ -235,6 +246,7 @@ def test_apply_struct_mask_3D(coords, struct_axis, struct_span):
patch,
coords=coords,
struct_params=struct_params,
rng=rng,
)
changed_values = patch[np.where(original_patch != transform_patch)]

Expand Down

0 comments on commit 456f4cf

Please sign in to comment.