-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add algorithm model validators (#361)
### Description Following #354: `final_activation` was allowed to be set although it would break CARE, N2N and N2V. This PR adds a Pydantic validator preventing final activation for these algorithms. It also takes the opportunity to improve the Pydantic models by moving the validators to a different module and using annotations for cleaner code. - **What**: CARE, N2N and N2V do not allow final activation in the UNet. - **Why**: This would otherwise break the algorithm training. - **How**: Add model validators in the respective Pydantic models. ### Changes Made - **Added**: `model_validators.py`. - **Modified**: Algorithms model. ### Related Issues - Fixes #354 --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [x] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features)
- Loading branch information
1 parent
b5b40a6
commit aad0c19
Showing
9 changed files
with
266 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,30 @@ | ||
"""N2N Algorithm configuration.""" | ||
|
||
from typing import Literal | ||
from typing import Annotated, Literal | ||
|
||
from pydantic import field_validator | ||
from pydantic import AfterValidator | ||
|
||
from careamics.config.architectures import UNetModel | ||
from careamics.config.validators import ( | ||
model_without_final_activation, | ||
model_without_n2v2, | ||
) | ||
|
||
from .unet_algorithm_model import UNetBasedAlgorithm | ||
|
||
|
||
class N2NAlgorithm(UNetBasedAlgorithm): | ||
"""N2N Algorithm configuration.""" | ||
"""Noise2Noise Algorithm configuration.""" | ||
|
||
algorithm: Literal["n2n"] = "n2n" | ||
"""N2N Algorithm name.""" | ||
|
||
loss: Literal["mae", "mse"] = "mae" | ||
"""N2N-compatible loss function.""" | ||
|
||
@classmethod | ||
@field_validator("model") | ||
def model_without_n2v2(cls, value: UNetModel) -> UNetModel: | ||
"""Validate that the model does not have the n2v2 attribute. | ||
Parameters | ||
---------- | ||
value : UNetModel | ||
Model to validate. | ||
Returns | ||
------- | ||
UNetModel | ||
The validated model. | ||
""" | ||
if value.n2v2: | ||
raise ValueError( | ||
"The N2N algorithm does not support the `n2v2` attribute. " | ||
"Set it to `False`." | ||
) | ||
|
||
return value | ||
model: Annotated[ | ||
UNetModel, | ||
AfterValidator(model_without_n2v2), | ||
AfterValidator(model_without_final_activation), | ||
] | ||
"""UNet without a final activation function and without the `n2v2` modifications.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,16 @@ | ||
"""Validator utilities.""" | ||
|
||
__all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"] | ||
__all__ = [ | ||
"check_axes_validity", | ||
"model_matching_in_out_channels", | ||
"model_without_final_activation", | ||
"model_without_n2v2", | ||
"patch_size_ge_than_8_power_of_2", | ||
] | ||
|
||
from .model_validators import ( | ||
model_matching_in_out_channels, | ||
model_without_final_activation, | ||
model_without_n2v2, | ||
) | ||
from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
"""Architecture model validators.""" | ||
|
||
from careamics.config.architectures import UNetModel | ||
|
||
|
||
def model_without_n2v2(model: UNetModel) -> UNetModel: | ||
"""Validate that the Unet model does not have the n2v2 attribute. | ||
Parameters | ||
---------- | ||
model : UNetModel | ||
Model to validate. | ||
Returns | ||
------- | ||
UNetModel | ||
The validated model. | ||
Raises | ||
------ | ||
ValueError | ||
If the model has the `n2v2` attribute set to `True`. | ||
""" | ||
if model.n2v2: | ||
raise ValueError( | ||
"The algorithm does not support the `n2v2` attribute in the model. " | ||
"Set it to `False`." | ||
) | ||
|
||
return model | ||
|
||
|
||
def model_without_final_activation(model: UNetModel) -> UNetModel: | ||
"""Validate that the UNet model does not have the final_activation. | ||
Parameters | ||
---------- | ||
model : UNetModel | ||
Model to validate. | ||
Returns | ||
------- | ||
UNetModel | ||
The validated model. | ||
Raises | ||
------ | ||
ValueError | ||
If the model has the final_activation attribute set. | ||
""" | ||
if model.final_activation != "None": | ||
raise ValueError( | ||
"The algorithm does not support a `final_activation` in the model. " | ||
'Set it to `"None"`.' | ||
) | ||
|
||
return model | ||
|
||
|
||
def model_matching_in_out_channels(model: UNetModel) -> UNetModel: | ||
"""Validate that the UNet model has the same number of channel inputs and outputs. | ||
Parameters | ||
---------- | ||
model : UNetModel | ||
Model to validate. | ||
Returns | ||
------- | ||
UNetModel | ||
Validated model. | ||
Raises | ||
------ | ||
ValueError | ||
If the model has different number of input and output channels. | ||
""" | ||
if model.num_classes != model.in_channels: | ||
raise ValueError( | ||
"The algorithm requires the same number of input and output channels. " | ||
"Make sure that `in_channels` and `num_classes` are equal." | ||
) | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import pytest | ||
|
||
from careamics.config.algorithms import CAREAlgorithm | ||
|
||
|
||
def test_instantiation(): | ||
"""Test the instantiation of the CAREAlgorithm class.""" | ||
model = { | ||
"architecture": "UNet", | ||
} | ||
CAREAlgorithm(model=model) | ||
|
||
|
||
def test_no_n2v2(): | ||
"""Check that an error is raised if the model is set for n2v2.""" | ||
model = { | ||
"architecture": "UNet", | ||
"n2v2": True, | ||
} | ||
|
||
with pytest.raises(ValueError): | ||
CAREAlgorithm(model=model) | ||
|
||
|
||
def test_no_final_activation(minimum_algorithm_supervised: dict): | ||
"""Check that an error is raised if the model has a final activation.""" | ||
minimum_algorithm_supervised["model"] = { | ||
"architecture": "UNet", | ||
"final_activation": "ReLU", | ||
} | ||
with pytest.raises(ValueError): | ||
CAREAlgorithm(**minimum_algorithm_supervised) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import pytest | ||
|
||
from careamics.config.algorithms import N2NAlgorithm | ||
|
||
|
||
def test_instantiation(): | ||
"""Test the instantiation of the N2NAlgorithm class.""" | ||
model = { | ||
"architecture": "UNet", | ||
} | ||
N2NAlgorithm(model=model) | ||
|
||
|
||
def test_no_n2v2(): | ||
"""Check that an error is raised if the model is set for n2v2.""" | ||
model = { | ||
"architecture": "UNet", | ||
"n2v2": True, | ||
} | ||
|
||
with pytest.raises(ValueError): | ||
N2NAlgorithm(model=model) | ||
|
||
|
||
def test_no_final_activation(minimum_algorithm_supervised: dict): | ||
"""Check that an error is raised if the model has a final activation.""" | ||
minimum_algorithm_supervised["model"] = { | ||
"architecture": "UNet", | ||
"final_activation": "ReLU", | ||
} | ||
with pytest.raises(ValueError): | ||
N2NAlgorithm(**minimum_algorithm_supervised) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.