Skip to content

Commit

Permalink
feat: Improve Pydantic configuration discrimination (#366)
Browse files Browse the repository at this point in the history
### Description

Did I mention I love Pydantic?

The discrimination between algorithm-specific configurations by Pydantic
generates a lot of errors when failing
(#356). This can be avoided
using a `Discriminator` in the `Union`.

### Changes Made

- **Modified**: Configuration factory module.


### Related Issues

- Resolves #356


### Breaking changes

No change.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
  • Loading branch information
jdeschamps authored Jan 22, 2025
1 parent cf3a094 commit 9a027a5
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 5 deletions.
32 changes: 29 additions & 3 deletions src/careamics/config/configuration_factories.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Convenience functions to create configurations for training and inference."""

from typing import Any, Literal, Optional, Union
from typing import Annotated, Any, Literal, Optional, Union

from pydantic import TypeAdapter
from pydantic import Discriminator, Tag, TypeAdapter

from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
from careamics.config.architectures import UNetModel
Expand All @@ -12,6 +12,7 @@
from careamics.config.n2n_configuration import N2NConfiguration
from careamics.config.n2v_configuration import N2VConfiguration
from careamics.config.support import (
SupportedAlgorithm,
SupportedArchitecture,
SupportedPixelManipulation,
SupportedTransform,
Expand All @@ -26,6 +27,24 @@
)


def _algorithm_config_discriminator(value: Union[dict, Configuration]) -> str:
"""Discriminate algorithm-specific configurations based on the algorithm.
Parameters
----------
value : Any
Value to discriminate.
Returns
-------
str
Discriminator value.
"""
if isinstance(value, dict):
return value["algorithm_config"]["algorithm"]
return value.algorithm_config.algorithm


def configuration_factory(
configuration: dict[str, Any]
) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
Expand All @@ -43,7 +62,14 @@ def configuration_factory(
Configuration for training CAREamics.
"""
adapter: TypeAdapter = TypeAdapter(
Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]
Annotated[
Union[
Annotated[N2VConfiguration, Tag(SupportedAlgorithm.N2V.value)],
Annotated[N2NConfiguration, Tag(SupportedAlgorithm.N2N.value)],
Annotated[CAREConfiguration, Tag(SupportedAlgorithm.CARE.value)],
],
Discriminator(_algorithm_config_discriminator),
]
)
return adapter.validate_python(configuration)

Expand Down
6 changes: 5 additions & 1 deletion src/careamics/config/support/supported_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@


class SupportedAlgorithm(str, BaseEnum):
"""Algorithms available in CAREamics."""
"""Algorithms available in CAREamics.
These definitions are the same as the keyword `name` of the algorithm
configurations.
"""

N2V = "n2v"
"""Noise2Void algorithm, a self-supervised approach based on blind denoising."""
Expand Down
24 changes: 23 additions & 1 deletion tests/config/test_configuration_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
create_n2v_configuration,
)
from careamics.config.configuration_factories import (
_algorithm_config_discriminator,
_create_configuration,
_create_supervised_configuration,
_create_unet_configuration,
Expand All @@ -23,6 +24,7 @@
)
from careamics.config.data import N2VDataConfig
from careamics.config.support import (
SupportedAlgorithm,
SupportedPixelManipulation,
SupportedStructAxis,
SupportedTransform,
Expand All @@ -34,13 +36,33 @@
)


def test_algorithm_discriminator_n2v(minimum_n2v_configuration):
"""Test that the N2V configuration is discriminated correctly."""
tag = _algorithm_config_discriminator(minimum_n2v_configuration)
assert tag == SupportedAlgorithm.N2V.value


@pytest.mark.parametrize(
"algorithm", [SupportedAlgorithm.N2N.value, SupportedAlgorithm.CARE.value]
)
def test_algorithm_discriminator_supervised(
minimum_supervised_configuration, algorithm
):
"""Test that the supervised configuration is discriminated correctly."""
minimum_supervised_configuration["algorithm_config"]["algorithm"] = algorithm
tag = _algorithm_config_discriminator(minimum_supervised_configuration)
assert tag == algorithm


def test_careamics_config_n2v(minimum_n2v_configuration):
"""Test that the N2V configuration is created correctly."""
configuration = configuration_factory(minimum_n2v_configuration)
assert isinstance(configuration, N2VConfiguration)


@pytest.mark.parametrize("algorithm", ["n2n", "care"])
@pytest.mark.parametrize(
"algorithm", [SupportedAlgorithm.N2N.value, SupportedAlgorithm.CARE.value]
)
def test_careamics_config_supervised(minimum_supervised_configuration, algorithm):
"""Test that the supervised configuration is created correctly."""
min_config = minimum_supervised_configuration
Expand Down

0 comments on commit 9a027a5

Please sign in to comment.