diff --git a/src/careamics/config/algorithms/care_algorithm_model.py b/src/careamics/config/algorithms/care_algorithm_model.py index 1d2613d93..57435112a 100644 --- a/src/careamics/config/algorithms/care_algorithm_model.py +++ b/src/careamics/config/algorithms/care_algorithm_model.py @@ -1,10 +1,14 @@ """CARE algorithm configuration.""" -from typing import Literal +from typing import Annotated, Literal -from pydantic import field_validator +from pydantic import AfterValidator from careamics.config.architectures import UNetModel +from careamics.config.validators import ( + model_without_final_activation, + model_without_n2v2, +) from .unet_algorithm_model import UNetBasedAlgorithm @@ -26,25 +30,9 @@ class CAREAlgorithm(UNetBasedAlgorithm): loss: Literal["mae", "mse"] = "mae" """CARE-compatible loss function.""" - @classmethod - @field_validator("model") - def model_without_n2v2(cls, value: UNetModel) -> UNetModel: - """Validate that the model does not have the n2v2 attribute. - - Parameters - ---------- - value : UNetModel - Model to validate. - - Returns - ------- - UNetModel - The validated model. - """ - if value.n2v2: - raise ValueError( - "The N2N algorithm does not support the `n2v2` attribute. " - "Set it to `False`." - ) - - return value + model: Annotated[ + UNetModel, + AfterValidator(model_without_n2v2), + AfterValidator(model_without_final_activation), + ] + """UNet without a final activation function and without the `n2v2` modifications.""" diff --git a/src/careamics/config/algorithms/n2n_algorithm_model.py b/src/careamics/config/algorithms/n2n_algorithm_model.py index 384e0ef7f..08bdbaa72 100644 --- a/src/careamics/config/algorithms/n2n_algorithm_model.py +++ b/src/careamics/config/algorithms/n2n_algorithm_model.py @@ -1,16 +1,20 @@ """N2N Algorithm configuration.""" -from typing import Literal +from typing import Annotated, Literal -from pydantic import field_validator +from pydantic import AfterValidator from careamics.config.architectures import UNetModel +from careamics.config.validators import ( + model_without_final_activation, + model_without_n2v2, +) from .unet_algorithm_model import UNetBasedAlgorithm class N2NAlgorithm(UNetBasedAlgorithm): - """N2N Algorithm configuration.""" + """Noise2Noise Algorithm configuration.""" algorithm: Literal["n2n"] = "n2n" """N2N Algorithm name.""" @@ -18,25 +22,9 @@ class N2NAlgorithm(UNetBasedAlgorithm): loss: Literal["mae", "mse"] = "mae" """N2N-compatible loss function.""" - @classmethod - @field_validator("model") - def model_without_n2v2(cls, value: UNetModel) -> UNetModel: - """Validate that the model does not have the n2v2 attribute. - - Parameters - ---------- - value : UNetModel - Model to validate. - - Returns - ------- - UNetModel - The validated model. - """ - if value.n2v2: - raise ValueError( - "The N2N algorithm does not support the `n2v2` attribute. " - "Set it to `False`." - ) - - return value + model: Annotated[ + UNetModel, + AfterValidator(model_without_n2v2), + AfterValidator(model_without_final_activation), + ] + """UNet without a final activation function and without the `n2v2` modifications.""" diff --git a/src/careamics/config/algorithms/n2v_algorithm_model.py b/src/careamics/config/algorithms/n2v_algorithm_model.py index a9defbfbc..235b7b2a8 100644 --- a/src/careamics/config/algorithms/n2v_algorithm_model.py +++ b/src/careamics/config/algorithms/n2v_algorithm_model.py @@ -1,9 +1,14 @@ """"N2V Algorithm configuration.""" -from typing import Literal +from typing import Annotated, Literal -from pydantic import model_validator -from typing_extensions import Self +from pydantic import AfterValidator + +from careamics.config.architectures import UNetModel +from careamics.config.validators import ( + model_matching_in_out_channels, + model_without_final_activation, +) from .unet_algorithm_model import UNetBasedAlgorithm @@ -17,19 +22,8 @@ class N2VAlgorithm(UNetBasedAlgorithm): loss: Literal["n2v"] = "n2v" """N2V loss function.""" - @model_validator(mode="after") - def algorithm_cross_validation(self: Self) -> Self: - """Validate the algorithm model for N2V. - - Returns - ------- - Self - The validated model. - """ - if self.model.in_channels != self.model.num_classes: - raise ValueError( - "N2V requires the same number of input and output channels. Make " - "sure that `in_channels` and `num_classes` are the same." - ) - - return self + model: Annotated[ + UNetModel, + AfterValidator(model_matching_in_out_channels), + AfterValidator(model_without_final_activation), + ] diff --git a/src/careamics/config/validators/__init__.py b/src/careamics/config/validators/__init__.py index 53ddbf8db..06a0f103f 100644 --- a/src/careamics/config/validators/__init__.py +++ b/src/careamics/config/validators/__init__.py @@ -1,5 +1,16 @@ """Validator utilities.""" -__all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"] +__all__ = [ + "check_axes_validity", + "model_matching_in_out_channels", + "model_without_final_activation", + "model_without_n2v2", + "patch_size_ge_than_8_power_of_2", +] +from .model_validators import ( + model_matching_in_out_channels, + model_without_final_activation, + model_without_n2v2, +) from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2 diff --git a/src/careamics/config/validators/model_validators.py b/src/careamics/config/validators/model_validators.py new file mode 100644 index 000000000..66f627855 --- /dev/null +++ b/src/careamics/config/validators/model_validators.py @@ -0,0 +1,84 @@ +"""Architecture model validators.""" + +from careamics.config.architectures import UNetModel + + +def model_without_n2v2(model: UNetModel) -> UNetModel: + """Validate that the Unet model does not have the n2v2 attribute. + + Parameters + ---------- + model : UNetModel + Model to validate. + + Returns + ------- + UNetModel + The validated model. + + Raises + ------ + ValueError + If the model has the `n2v2` attribute set to `True`. + """ + if model.n2v2: + raise ValueError( + "The algorithm does not support the `n2v2` attribute in the model. " + "Set it to `False`." + ) + + return model + + +def model_without_final_activation(model: UNetModel) -> UNetModel: + """Validate that the UNet model does not have the final_activation. + + Parameters + ---------- + model : UNetModel + Model to validate. + + Returns + ------- + UNetModel + The validated model. + + Raises + ------ + ValueError + If the model has the final_activation attribute set. + """ + if model.final_activation != "None": + raise ValueError( + "The algorithm does not support a `final_activation` in the model. " + 'Set it to `"None"`.' + ) + + return model + + +def model_matching_in_out_channels(model: UNetModel) -> UNetModel: + """Validate that the UNet model has the same number of channel inputs and outputs. + + Parameters + ---------- + model : UNetModel + Model to validate. + + Returns + ------- + UNetModel + Validated model. + + Raises + ------ + ValueError + If the model has different number of input and output channels. + """ + if model.num_classes != model.in_channels: + raise ValueError( + "The algorithm requires the same number of input and output channels. " + "Make sure that `in_channels` and `num_classes` are equal." + ) + + return model diff --git a/tests/config/algorithms/test_care_algorithm_model.py b/tests/config/algorithms/test_care_algorithm_model.py new file mode 100644 index 000000000..aab357c3b --- /dev/null +++ b/tests/config/algorithms/test_care_algorithm_model.py @@ -0,0 +1,32 @@ +import pytest + +from careamics.config.algorithms import CAREAlgorithm + + +def test_instantiation(): + """Test the instantiation of the CAREAlgorithm class.""" + model = { + "architecture": "UNet", + } + CAREAlgorithm(model=model) + + +def test_no_n2v2(): + """Check that an error is raised if the model is set for n2v2.""" + model = { + "architecture": "UNet", + "n2v2": True, + } + + with pytest.raises(ValueError): + CAREAlgorithm(model=model) + + +def test_no_final_activation(minimum_algorithm_supervised: dict): + """Check that an error is raised if the model has a final activation.""" + minimum_algorithm_supervised["model"] = { + "architecture": "UNet", + "final_activation": "ReLU", + } + with pytest.raises(ValueError): + CAREAlgorithm(**minimum_algorithm_supervised) diff --git a/tests/config/algorithms/test_n2n_algorithm_model.py b/tests/config/algorithms/test_n2n_algorithm_model.py new file mode 100644 index 000000000..ce439f4c5 --- /dev/null +++ b/tests/config/algorithms/test_n2n_algorithm_model.py @@ -0,0 +1,32 @@ +import pytest + +from careamics.config.algorithms import N2NAlgorithm + + +def test_instantiation(): + """Test the instantiation of the N2NAlgorithm class.""" + model = { + "architecture": "UNet", + } + N2NAlgorithm(model=model) + + +def test_no_n2v2(): + """Check that an error is raised if the model is set for n2v2.""" + model = { + "architecture": "UNet", + "n2v2": True, + } + + with pytest.raises(ValueError): + N2NAlgorithm(model=model) + + +def test_no_final_activation(minimum_algorithm_supervised: dict): + """Check that an error is raised if the model has a final activation.""" + minimum_algorithm_supervised["model"] = { + "architecture": "UNet", + "final_activation": "ReLU", + } + with pytest.raises(ValueError): + N2NAlgorithm(**minimum_algorithm_supervised) diff --git a/tests/config/algorithms/test_n2v_algorithm_model.py b/tests/config/algorithms/test_n2v_algorithm_model.py index a52337f66..ad971896c 100644 --- a/tests/config/algorithms/test_n2v_algorithm_model.py +++ b/tests/config/algorithms/test_n2v_algorithm_model.py @@ -18,14 +18,26 @@ def test_n_channels_n2v(): N2VAlgorithm(algorithm="n2v", loss=loss, model=model) -def test_comaptiblity_of_number_of_channels(minimum_algorithm_n2v: dict): - """Check that no error is thrown when instantiating the algorithm with a valid - number of in and out channels.""" +def test_channels(minimum_algorithm_n2v: dict): + """Check that error is thrown if the number of channels are different.""" minimum_algorithm_n2v["model"] = { "architecture": "UNet", "in_channels": 2, "num_classes": 2, "n2v2": False, } - N2VAlgorithm(**minimum_algorithm_n2v) + + minimum_algorithm_n2v["model"]["num_classes"] = 3 + with pytest.raises(ValueError): + N2VAlgorithm(**minimum_algorithm_n2v) + + +def test_no_final_activation(minimum_algorithm_n2v: dict): + """Check that an error is raised if the model has a final activation.""" + minimum_algorithm_n2v["model"] = { + "architecture": "UNet", + "final_activation": "ReLU", + } + with pytest.raises(ValueError): + N2VAlgorithm(**minimum_algorithm_n2v) diff --git a/tests/config/validators/test_model_validators.py b/tests/config/validators/test_model_validators.py new file mode 100644 index 000000000..5589d7ba6 --- /dev/null +++ b/tests/config/validators/test_model_validators.py @@ -0,0 +1,52 @@ +import pytest + +from careamics.config.architectures import UNetModel +from careamics.config.validators import ( + model_matching_in_out_channels, + model_without_final_activation, + model_without_n2v2, +) + + +def test_model_without_n2v2(): + """Test the validation of the model without the `n2v2` attribute.""" + model = UNetModel(architecture="UNet", final_activation="None", n2v2=False) + assert model_without_n2v2(model) == model + + model = UNetModel(architecture="UNet", final_activation="None", n2v2=True) + with pytest.raises(ValueError): + model_without_n2v2(model) + + +def test_model_without_final_activation(): + """Test the validation of the model without the `final_activation` attribute.""" + model = UNetModel( + architecture="UNet", + final_activation="None", + ) + assert model_without_final_activation(model) == model + + model = UNetModel( + architecture="UNet", + final_activation="ReLU", + ) + with pytest.raises(ValueError): + model_without_final_activation(model) + + +def test_model_matching_in_out_channels(): + """Test the validation of the model with matching in and out channels.""" + model = UNetModel( + architecture="UNet", + in_channels=1, + num_classes=1, + ) + assert model_matching_in_out_channels(model) == model + + model = UNetModel( + architecture="UNet", + in_channels=1, + num_classes=2, + ) + with pytest.raises(ValueError): + model_matching_in_out_channels(model)