Skip to content

Commit

Permalink
feat(type hints): add param specs to PatchSpecGenerator protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
melisande-c committed Jan 28, 2025
1 parent 24a0013 commit 79cffe6
Showing 1 changed file with 18 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,28 +1,33 @@
from collections.abc import Sequence
from typing import Protocol
from typing import ParamSpec, Protocol

import numpy as np
from numpy.typing import NDArray

from ..patch_extractor import PatchSpecs

P = ParamSpec("P")


class PatchSpecsGenerator(Protocol[P]):
    """Structural protocol for objects that produce patch specifications.

    The protocol is generic over the ``ParamSpec`` ``P`` so that the extra
    positional/keyword parameters of ``generate`` and ``n_patches`` are tied
    together: any implementation must accept the same additional arguments
    in both methods.
    """

    def generate(
        self, patch_size: Sequence[int], *args: P.args, **kwargs: P.kwargs
    ) -> list[PatchSpecs]:
        """Produce the patch specifications for the given patch size."""
        ...

    # Should return the number of patches that will be produced for a set of args
    # Will be for mapped dataset length
    def n_patches(
        self, patch_size: Sequence[int], *args: P.args, **kwargs: P.kwargs
    ) -> int:
        """Return the number of patches ``generate`` would produce for the same args."""
        ...


class RandomPatchSpecsGenerator:

def __init__(self, data_shapes: Sequence[Sequence[int]]):
self.data_shapes = data_shapes

def generate(self, patch_size: Sequence[int], seed: int):
def generate(self, patch_size: Sequence[int], seed: int) -> list[PatchSpecs]:
rng = np.random.default_rng(seed=seed)
patch_specs: list[PatchSpecs] = []
for data_idx, data_shape in enumerate(self.data_shapes):
Expand Down Expand Up @@ -50,8 +55,10 @@ def generate(self, patch_size: Sequence[int], seed: int):
patch_specs.extend(data_patch_specs)
return patch_specs

def n_patches(self, patch_size: Sequence[int], seed: int):
n_sample_patches = np.array(
# NOTE: generate and n_patches methods must have matching signatures
# as dictated by protocol
def n_patches(self, patch_size: Sequence[int], seed: int) -> int:
n_sample_patches: NDArray[np.int_] = np.array(
[
self._n_patches_in_sample(patch_size, data_shape[-len(patch_size) :])
for data_shape in self.data_shapes
Expand All @@ -62,10 +69,12 @@ def n_patches(self, patch_size: Sequence[int], seed: int):
[data_shape[0] for data_shape in self.data_shapes], dtype=int
)
n_data_patches = n_samples * n_sample_patches
return n_data_patches.sum()
return int(n_data_patches.sum())

@staticmethod
def _n_patches_in_sample(patch_size: Sequence[int], spatial_shape: Sequence[int]):
def _n_patches_in_sample(
patch_size: Sequence[int], spatial_shape: Sequence[int]
) -> int:
if len(patch_size) != len(spatial_shape):
raise ValueError(
"Number of patch dimension do not match the number of spatial "
Expand Down

0 comments on commit 79cffe6

Please sign in to comment.