Skip to content

Commit

Permalink
feat(patch extractor factory): remove unecessary kwarg builder func
Browse files Browse the repository at this point in the history
  • Loading branch information
melisande-c committed Feb 3, 2025
1 parent f6333e6 commit 2c15dde
Showing 1 changed file with 6 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Optional, Union
from typing import Optional, Union

from numpy.typing import NDArray

Expand All @@ -12,19 +12,6 @@
)


def build_patch_extractor_constructor_kwargs(
data_config: GeneralDataConfig, **custom_kwargs: Any
) -> dict: # TODO: return union of TypedDicts?
if data_config.data_type == SupportedData.ARRAY:
return {"axes": data_config.axes}
elif data_config.data_type == SupportedData.TIFF:
return {"axes": data_config.axes}
elif data_config.data_type == SupportedData.CUSTOM:
return {"axes": data_config.axes, **custom_kwargs}
else:
raise ValueError(f"Data type {data_config.data_type} is not supported.")


def get_patch_extractor_constructor(
data_config: GeneralDataConfig,
) -> PatchExtractorConstructor:
Expand All @@ -44,7 +31,7 @@ def create_patch_extractors(
val_data: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None,
train_data_target: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None,
val_data_target: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None,
**custom_kwargs,
**kwargs,
) -> tuple[
PatchExtractor,
Optional[PatchExtractor],
Expand All @@ -56,14 +43,13 @@ def create_patch_extractors(
constructor = get_patch_extractor_constructor(data_config)

# build key word args
constructor_kwargs = build_patch_extractor_constructor_kwargs(
data_config, **custom_kwargs
)
# --- train images
constructor_kwargs = {"axes": data_config.axes, **kwargs}

# --- train data extractor
train_patch_extractor: PatchExtractor = constructor(
source=train_data, **constructor_kwargs
)

# --- additional data extractors
additional_patch_extractors: list[Union[PatchExtractor, None]] = []
additional_data_sources = [val_data, train_data_target, val_data_target]
for data_source in additional_data_sources:
Expand Down

0 comments on commit 2c15dde

Please sign in to comment.