Skip to content

Commit

Permalink
feat: Add algorithm model validators (#361)
Browse files Browse the repository at this point in the history
### Description

Following #354:
`final_activation` was allowed to be set although it would break CARE,
N2N and N2V.

This PR adds a Pydantic validator preventing final activation for these
algorithms. It also takes the opportunity to improve the Pydantic models
by moving the validators to a different module and using annotations for
cleaner code.

- **What**: CARE, N2N and N2V do not allow final activation in the UNet.
- **Why**: This would otherwise break the algorithm training.
- **How**: Add model validators in the respective Pydantic models.

### Changes Made

- **Added**: `model_validators.py`.
- **Modified**: Algorithms model.

### Related Issues

- Fixes #354

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
  • Loading branch information
jdeschamps authored Jan 20, 2025
1 parent b5b40a6 commit aad0c19
Show file tree
Hide file tree
Showing 9 changed files with 266 additions and 73 deletions.
36 changes: 12 additions & 24 deletions src/careamics/config/algorithms/care_algorithm_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""CARE algorithm configuration."""

from typing import Literal
from typing import Annotated, Literal

from pydantic import field_validator
from pydantic import AfterValidator

from careamics.config.architectures import UNetModel
from careamics.config.validators import (
model_without_final_activation,
model_without_n2v2,
)

from .unet_algorithm_model import UNetBasedAlgorithm

Expand All @@ -26,25 +30,9 @@ class CAREAlgorithm(UNetBasedAlgorithm):
loss: Literal["mae", "mse"] = "mae"
"""CARE-compatible loss function."""

@classmethod
@field_validator("model")
def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
"""Validate that the model does not have the n2v2 attribute.
Parameters
----------
value : UNetModel
Model to validate.
Returns
-------
UNetModel
The validated model.
"""
if value.n2v2:
raise ValueError(
"The N2N algorithm does not support the `n2v2` attribute. "
"Set it to `False`."
)

return value
model: Annotated[
UNetModel,
AfterValidator(model_without_n2v2),
AfterValidator(model_without_final_activation),
]
"""UNet without a final activation function and without the `n2v2` modifications."""
38 changes: 13 additions & 25 deletions src/careamics/config/algorithms/n2n_algorithm_model.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,30 @@
"""N2N Algorithm configuration."""

from typing import Literal
from typing import Annotated, Literal

from pydantic import field_validator
from pydantic import AfterValidator

from careamics.config.architectures import UNetModel
from careamics.config.validators import (
model_without_final_activation,
model_without_n2v2,
)

from .unet_algorithm_model import UNetBasedAlgorithm


class N2NAlgorithm(UNetBasedAlgorithm):
"""N2N Algorithm configuration."""
"""Noise2Noise Algorithm configuration."""

algorithm: Literal["n2n"] = "n2n"
"""N2N Algorithm name."""

loss: Literal["mae", "mse"] = "mae"
"""N2N-compatible loss function."""

@classmethod
@field_validator("model")
def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
"""Validate that the model does not have the n2v2 attribute.
Parameters
----------
value : UNetModel
Model to validate.
Returns
-------
UNetModel
The validated model.
"""
if value.n2v2:
raise ValueError(
"The N2N algorithm does not support the `n2v2` attribute. "
"Set it to `False`."
)

return value
model: Annotated[
UNetModel,
AfterValidator(model_without_n2v2),
AfterValidator(model_without_final_activation),
]
"""UNet without a final activation function and without the `n2v2` modifications."""
32 changes: 13 additions & 19 deletions src/careamics/config/algorithms/n2v_algorithm_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
""""N2V Algorithm configuration."""

from typing import Literal
from typing import Annotated, Literal

from pydantic import model_validator
from typing_extensions import Self
from pydantic import AfterValidator

from careamics.config.architectures import UNetModel
from careamics.config.validators import (
model_matching_in_out_channels,
model_without_final_activation,
)

from .unet_algorithm_model import UNetBasedAlgorithm

Expand All @@ -17,19 +22,8 @@ class N2VAlgorithm(UNetBasedAlgorithm):
loss: Literal["n2v"] = "n2v"
"""N2V loss function."""

@model_validator(mode="after")
def algorithm_cross_validation(self: Self) -> Self:
"""Validate the algorithm model for N2V.
Returns
-------
Self
The validated model.
"""
if self.model.in_channels != self.model.num_classes:
raise ValueError(
"N2V requires the same number of input and output channels. Make "
"sure that `in_channels` and `num_classes` are the same."
)

return self
model: Annotated[
UNetModel,
AfterValidator(model_matching_in_out_channels),
AfterValidator(model_without_final_activation),
]
13 changes: 12 additions & 1 deletion src/careamics/config/validators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
"""Validator utilities."""

__all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"]
__all__ = [
"check_axes_validity",
"model_matching_in_out_channels",
"model_without_final_activation",
"model_without_n2v2",
"patch_size_ge_than_8_power_of_2",
]

from .model_validators import (
model_matching_in_out_channels,
model_without_final_activation,
model_without_n2v2,
)
from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2
84 changes: 84 additions & 0 deletions src/careamics/config/validators/model_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Architecture model validators."""

from careamics.config.architectures import UNetModel


def model_without_n2v2(model: UNetModel) -> UNetModel:
"""Validate that the Unet model does not have the n2v2 attribute.
Parameters
----------
model : UNetModel
Model to validate.
Returns
-------
UNetModel
The validated model.
Raises
------
ValueError
If the model has the `n2v2` attribute set to `True`.
"""
if model.n2v2:
raise ValueError(
"The algorithm does not support the `n2v2` attribute in the model. "
"Set it to `False`."
)

return model


def model_without_final_activation(model: UNetModel) -> UNetModel:
"""Validate that the UNet model does not have the final_activation.
Parameters
----------
model : UNetModel
Model to validate.
Returns
-------
UNetModel
The validated model.
Raises
------
ValueError
If the model has the final_activation attribute set.
"""
if model.final_activation != "None":
raise ValueError(
"The algorithm does not support a `final_activation` in the model. "
'Set it to `"None"`.'
)

return model


def model_matching_in_out_channels(model: UNetModel) -> UNetModel:
"""Validate that the UNet model has the same number of channel inputs and outputs.
Parameters
----------
model : UNetModel
Model to validate.
Returns
-------
UNetModel
Validated model.
Raises
------
ValueError
If the model has different number of input and output channels.
"""
if model.num_classes != model.in_channels:
raise ValueError(
"The algorithm requires the same number of input and output channels. "
"Make sure that `in_channels` and `num_classes` are equal."
)

return model
32 changes: 32 additions & 0 deletions tests/config/algorithms/test_care_algorithm_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from careamics.config.algorithms import CAREAlgorithm


def test_instantiation():
"""Test the instantiation of the CAREAlgorithm class."""
model = {
"architecture": "UNet",
}
CAREAlgorithm(model=model)


def test_no_n2v2():
"""Check that an error is raised if the model is set for n2v2."""
model = {
"architecture": "UNet",
"n2v2": True,
}

with pytest.raises(ValueError):
CAREAlgorithm(model=model)


def test_no_final_activation(minimum_algorithm_supervised: dict):
"""Check that an error is raised if the model has a final activation."""
minimum_algorithm_supervised["model"] = {
"architecture": "UNet",
"final_activation": "ReLU",
}
with pytest.raises(ValueError):
CAREAlgorithm(**minimum_algorithm_supervised)
32 changes: 32 additions & 0 deletions tests/config/algorithms/test_n2n_algorithm_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from careamics.config.algorithms import N2NAlgorithm


def test_instantiation():
"""Test the instantiation of the N2NAlgorithm class."""
model = {
"architecture": "UNet",
}
N2NAlgorithm(model=model)


def test_no_n2v2():
"""Check that an error is raised if the model is set for n2v2."""
model = {
"architecture": "UNet",
"n2v2": True,
}

with pytest.raises(ValueError):
N2NAlgorithm(model=model)


def test_no_final_activation(minimum_algorithm_supervised: dict):
"""Check that an error is raised if the model has a final activation."""
minimum_algorithm_supervised["model"] = {
"architecture": "UNet",
"final_activation": "ReLU",
}
with pytest.raises(ValueError):
N2NAlgorithm(**minimum_algorithm_supervised)
20 changes: 16 additions & 4 deletions tests/config/algorithms/test_n2v_algorithm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,26 @@ def test_n_channels_n2v():
N2VAlgorithm(algorithm="n2v", loss=loss, model=model)


def test_comaptiblity_of_number_of_channels(minimum_algorithm_n2v: dict):
"""Check that no error is thrown when instantiating the algorithm with a valid
number of in and out channels."""
def test_channels(minimum_algorithm_n2v: dict):
"""Check that error is thrown if the number of channels are different."""
minimum_algorithm_n2v["model"] = {
"architecture": "UNet",
"in_channels": 2,
"num_classes": 2,
"n2v2": False,
}

N2VAlgorithm(**minimum_algorithm_n2v)

minimum_algorithm_n2v["model"]["num_classes"] = 3
with pytest.raises(ValueError):
N2VAlgorithm(**minimum_algorithm_n2v)


def test_no_final_activation(minimum_algorithm_n2v: dict):
"""Check that an error is raised if the model has a final activation."""
minimum_algorithm_n2v["model"] = {
"architecture": "UNet",
"final_activation": "ReLU",
}
with pytest.raises(ValueError):
N2VAlgorithm(**minimum_algorithm_n2v)
Loading

0 comments on commit aad0c19

Please sign in to comment.