diff --git a/src/careamics/config/configuration_factories.py b/src/careamics/config/configuration_factories.py index c28761438..c444d14a9 100644 --- a/src/careamics/config/configuration_factories.py +++ b/src/careamics/config/configuration_factories.py @@ -2,15 +2,17 @@ from typing import Annotated, Any, Literal, Optional, Union -from pydantic import Field, TypeAdapter +from pydantic import Discriminator, Field, Tag, TypeAdapter from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm from careamics.config.architectures import UNetModel from careamics.config.care_configuration import CAREConfiguration +from careamics.config.configuration import Configuration from careamics.config.data import DataConfig from careamics.config.n2n_configuration import N2NConfiguration from careamics.config.n2v_configuration import N2VConfiguration from careamics.config.support import ( + SupportedAlgorithm, SupportedArchitecture, SupportedPixelManipulation, SupportedTransform, @@ -24,6 +26,24 @@ ) +def _algorithm_config_discriminator(value: Union[dict, Configuration]) -> str: + """Discriminate algorithm-specific configurations based on the algorithm. + + Parameters + ---------- + value : Any + Value to discriminate. + + Returns + ------- + str + Discriminator value. + """ + if isinstance(value, dict): + return value["algorithm_config"]["algorithm"] + return value.algorithm_config.algorithm + + def configuration_factory( configuration: dict[str, Any] ) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]: @@ -41,7 +61,14 @@ def configuration_factory( Configuration for training CAREamics. """ adapter: TypeAdapter = TypeAdapter( - Union[N2VConfiguration, N2NConfiguration, CAREConfiguration] + Annotated[ + Union[ + Annotated[N2VConfiguration, Tag(SupportedAlgorithm.N2V.value)], + Annotated[N2NConfiguration, Tag(SupportedAlgorithm.N2N.value)], + Annotated[CAREConfiguration, Tag(SupportedAlgorithm.CARE.value)], + ], + Discriminator(_algorithm_config_discriminator), + ] ) return adapter.validate_python(configuration) diff --git a/src/careamics/config/support/supported_algorithms.py b/src/careamics/config/support/supported_algorithms.py index dc5752bd5..15b30274b 100644 --- a/src/careamics/config/support/supported_algorithms.py +++ b/src/careamics/config/support/supported_algorithms.py @@ -6,7 +6,11 @@ class SupportedAlgorithm(str, BaseEnum): - """Algorithms available in CAREamics.""" + """Algorithms available in CAREamics. + + These definitions are the same as the keyword `name` of the algorithm + configurations. + """ N2V = "n2v" """Noise2Void algorithm, a self-supervised approach based on blind denoising.""" diff --git a/tests/config/test_configuration_factories.py b/tests/config/test_configuration_factories.py index 5ec36bb10..87ed40e19 100644 --- a/tests/config/test_configuration_factories.py +++ b/tests/config/test_configuration_factories.py @@ -13,12 +13,14 @@ create_n2v_configuration, ) from careamics.config.configuration_factories import ( + _algorithm_config_discriminator, _create_supervised_config_dict, _create_unet_configuration, _list_spatial_augmentations, configuration_factory, ) from careamics.config.support import ( + SupportedAlgorithm, SupportedPixelManipulation, SupportedStructAxis, SupportedTransform, @@ -30,13 +32,33 @@ ) +def test_algorithm_discriminator_n2v(minimum_n2v_configuration): + """Test that the N2V configuration is discriminated correctly.""" + tag = _algorithm_config_discriminator(minimum_n2v_configuration) + assert tag == SupportedAlgorithm.N2V.value + + +@pytest.mark.parametrize( + "algorithm", [SupportedAlgorithm.N2N.value, SupportedAlgorithm.CARE.value] +) +def test_algorithm_discriminator_supervised( + minimum_supervised_configuration, algorithm +): + """Test that the supervised configuration is discriminated correctly.""" + minimum_supervised_configuration["algorithm_config"]["algorithm"] = algorithm + tag = _algorithm_config_discriminator(minimum_supervised_configuration) + assert tag == algorithm + + def test_careamics_config_n2v(minimum_n2v_configuration): """Test that the N2V configuration is created correctly.""" configuration = configuration_factory(minimum_n2v_configuration) assert isinstance(configuration, N2VConfiguration) -@pytest.mark.parametrize("algorithm", ["n2n", "care"]) +@pytest.mark.parametrize( + "algorithm", [SupportedAlgorithm.N2N.value, SupportedAlgorithm.CARE.value] +) def test_careamics_config_supervised(minimum_supervised_configuration, algorithm): """Test that the supervised configuration is created correctly.""" min_config = minimum_supervised_configuration