diff --git a/src/careamics/dataset_ng/patching_strategies/patch_specs_generator.py b/src/careamics/dataset_ng/patching_strategies/patch_specs_generator.py index 4603179f..bd9b727c 100644 --- a/src/careamics/dataset_ng/patching_strategies/patch_specs_generator.py +++ b/src/careamics/dataset_ng/patching_strategies/patch_specs_generator.py @@ -1,20 +1,25 @@ from collections.abc import Sequence -from typing import Protocol +from typing import ParamSpec, Protocol import numpy as np +from numpy.typing import NDArray from ..patch_extractor import PatchSpecs +P = ParamSpec("P") -class PatchSpecsGenerator(Protocol): + +class PatchSpecsGenerator(Protocol[P]): def generate( - self, patch_size: Sequence[int], *args, **kwargs + self, patch_size: Sequence[int], *args: P.args, **kwargs: P.kwargs ) -> list[PatchSpecs]: ... # Should return the number of patches that will be produced for a set of args # Will be for mapped dataset length - def n_patches(self, patch_size: Sequence[int], *args, **kwargs): ... + def n_patches( + self, patch_size: Sequence[int], *args: P.args, **kwargs: P.kwargs + ) -> int: ... class RandomPatchSpecsGenerator: @@ -22,7 +27,7 @@ class RandomPatchSpecsGenerator: def __init__(self, data_shapes: Sequence[Sequence[int]]): self.data_shapes = data_shapes - def generate(self, patch_size: Sequence[int], seed: int): + def generate(self, patch_size: Sequence[int], seed: int) -> list[PatchSpecs]: rng = np.random.default_rng(seed=seed) patch_specs: list[PatchSpecs] = [] for data_idx, data_shape in enumerate(self.data_shapes): @@ -50,8 +55,10 @@ def generate(self, patch_size: Sequence[int], seed: int): patch_specs.extend(data_patch_specs) return patch_specs - def n_patches(self, patch_size: Sequence[int], seed: int): - n_sample_patches = np.array( + # NOTE: enerate and n_patches methods must have matching signatures + # as dictated by protocol + def n_patches(self, patch_size: Sequence[int], seed: int) -> int: + n_sample_patches: NDArray[np.int_] = np.array( [ self._n_patches_in_sample(patch_size, data_shape[-len(patch_size) :]) for data_shape in self.data_shapes @@ -62,10 +69,12 @@ def n_patches(self, patch_size: Sequence[int], seed: int): [data_shape[0] for data_shape in self.data_shapes], dtype=int ) n_data_patches = n_samples * n_sample_patches - return n_data_patches.sum() + return int(n_data_patches.sum()) @staticmethod - def _n_patches_in_sample(patch_size: Sequence[int], spatial_shape: Sequence[int]): + def _n_patches_in_sample( + patch_size: Sequence[int], spatial_shape: Sequence[int] + ) -> int: if len(patch_size) != len(spatial_shape): raise ValueError( "Number of patch dimension do not match the number of spatial "