diff --git a/src/careamics/dataset_ng/patch_extractor/patch_extractor_factory.py b/src/careamics/dataset_ng/patch_extractor/patch_extractor_factory.py index 8c8c64cd..fc68c65c 100644 --- a/src/careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +++ b/src/careamics/dataset_ng/patch_extractor/patch_extractor_factory.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from pathlib import Path -from typing import Any, Optional, Union +from typing import Optional, Union from numpy.typing import NDArray @@ -12,19 +12,6 @@ ) -def build_patch_extractor_constructor_kwargs( - data_config: GeneralDataConfig, **custom_kwargs: Any -) -> dict: # TODO: return union of TypedDicts? - if data_config.data_type == SupportedData.ARRAY: - return {"axes": data_config.axes} - elif data_config.data_type == SupportedData.TIFF: - return {"axes": data_config.axes} - elif data_config.data_type == SupportedData.CUSTOM: - return {"axes": data_config.axes, **custom_kwargs} - else: - raise ValueError(f"Data type {data_config.data_type} is not supported.") - - def get_patch_extractor_constructor( data_config: GeneralDataConfig, ) -> PatchExtractorConstructor: @@ -44,7 +31,7 @@ def create_patch_extractors( val_data: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None, train_data_target: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None, val_data_target: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None, - **custom_kwargs, + **kwargs, ) -> tuple[ PatchExtractor, Optional[PatchExtractor], @@ -56,14 +43,13 @@ def create_patch_extractors( constructor = get_patch_extractor_constructor(data_config) # build key word args - constructor_kwargs = build_patch_extractor_constructor_kwargs( - data_config, **custom_kwargs - ) - # --- train images + constructor_kwargs = {"axes": data_config.axes, **kwargs} + + # --- train data extractor train_patch_extractor: PatchExtractor = constructor( source=train_data, **constructor_kwargs ) - + # --- additional data extractors additional_patch_extractors: list[Union[PatchExtractor, None]] = [] additional_data_sources = [val_data, train_data_target, val_data_target] for data_source in additional_data_sources: