diff --git a/src/careamics/config/algorithm_model.py b/src/careamics/config/algorithm_model.py index ddfe81e4a..66dcf5be4 100644 --- a/src/careamics/config/algorithm_model.py +++ b/src/careamics/config/algorithm_model.py @@ -134,21 +134,6 @@ def algorithm_cross_validation(self: Self) -> Self: "sure that `in_channels` and `num_classes` are the same." ) - # N2N - if self.algorithm == "n2n": - # n2n is only compatible with the UNet model - if not isinstance(self.model, UNetModel): - raise ValueError( - f"Model for algorithm {self.algorithm} must be a `UNetModel`." - ) - - # n2n requires the number of input and output channels to be the same - if self.model.in_channels != self.model.num_classes: - raise ValueError( - "N2N requires the same number of input and output channels. Make " - "sure that `in_channels` and `num_classes` are the same." - ) - if self.algorithm == "care" or self.algorithm == "n2n": if self.loss == "n2v": raise ValueError("Supervised algorithms do not support loss `n2v`.") diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factory.py index e28972b0d..6f3d81f9c 100644 --- a/src/careamics/config/configuration_factory.py +++ b/src/careamics/config/configuration_factory.py @@ -243,7 +243,8 @@ def create_n2n_configuration( use_augmentations: bool = True, independent_channels: bool = False, loss: Literal["mae", "mse"] = "mae", - n_channels: int = 1, + n_channels_in: int = 1, + n_channels_out: int = -1, logger: Literal["wandb", "tensorboard", "none"] = "none", model_kwargs: Optional[dict] = None, ) -> Configuration: @@ -253,10 +254,13 @@ def create_n2n_configuration( If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise 2. - If "C" is present in `axes`, then you need to set `n_channels` to the number of + If "C" is present in `axes`, then you need to set `n_channels_in` to the number of channels. Likewise, if you set the number of channels, then "C" must be present in `axes`. + To set the number of output channels, use the `n_channels_out` parameter. If it is + not specified, it will be assumed to be equal to `n_channels_in`. + By default, all channels are trained together. To train all channels independently, set `independent_channels` to True. @@ -283,8 +287,10 @@ def create_n2n_configuration( Whether to train all channels independently, by default False. loss : Literal["mae", "mse"], optional Loss function to use, by default "mae". - n_channels : int, optional - Number of channels (in and out), by default 1. + n_channels_in : int, optional + Number of channels in, by default 1. + n_channels_out : int, optional + Number of channels out, by default -1. logger : Literal["wandb", "tensorboard", "none"], optional Logger to use, by default "none". model_kwargs : dict, optional @@ -295,6 +301,9 @@ def create_n2n_configuration( Configuration Configuration for training Noise2Noise. """ + if n_channels_out == -1: + n_channels_out = n_channels_in + return _create_supervised_configuration( algorithm="n2n", experiment_name=experiment_name, @@ -306,8 +315,8 @@ def create_n2n_configuration( use_augmentations=use_augmentations, independent_channels=independent_channels, loss=loss, - n_channels_in=n_channels, - n_channels_out=n_channels, + n_channels_in=n_channels_in, + n_channels_out=n_channels_out, logger=logger, model_kwargs=model_kwargs, ) diff --git a/tests/config/test_algorithm_model.py b/tests/config/test_algorithm_model.py index d8ce3b996..e6ba3c398 100644 --- a/tests/config/test_algorithm_model.py +++ b/tests/config/test_algorithm_model.py @@ -64,9 +64,8 @@ def test_algorithm_constraints(algorithm: str, loss: str, model: dict): AlgorithmConfig(algorithm=algorithm, loss=loss, model=model) -@pytest.mark.parametrize("algorithm", ["n2v", "n2n"]) -def test_n_channels_n2v_and_n2n(algorithm): - """Check that an error is raised if n2v and n2n have different number of channels in +def test_n_channels_n2v(): + """Check that an error is raised if n2v has different number of channels in input and output.""" model = { "architecture": "UNet", @@ -74,10 +73,10 @@ def test_n_channels_n2v_and_n2n(algorithm): "num_classes": 2, "n2v2": False, } - loss = "mae" if algorithm == "n2n" else "n2v" + loss = "n2v" with pytest.raises(ValueError): - AlgorithmConfig(algorithm=algorithm, loss=loss, model=model) + AlgorithmConfig(algorithm="n2v", loss=loss, model=model) @pytest.mark.parametrize( diff --git a/tests/config/test_configuration_factory.py b/tests/config/test_configuration_factory.py index 52b7c3cb3..4dee1a697 100644 --- a/tests/config/test_configuration_factory.py +++ b/tests/config/test_configuration_factory.py @@ -94,7 +94,7 @@ def test_n2n_channels_errors(): patch_size=[64, 64], batch_size=8, num_epochs=100, - n_channels=5, + n_channels_in=5, ) @@ -122,14 +122,14 @@ def test_n2n_independent_channels(ind_channels): patch_size=[64, 64], batch_size=8, num_epochs=100, - n_channels=4, + n_channels_in=4, independent_channels=ind_channels, ) assert config.algorithm_config.model.independent_channels == ind_channels -def test_n2n_chanels_equal(): - """Test that channels in and out are equal.""" +def test_n2n_channels_equal(): + """Test that channels in and out are equal if only channels_in is set.""" config = create_n2n_configuration( experiment_name="test", data_type="tiff", @@ -137,12 +137,28 @@ def test_n2n_chanels_equal(): patch_size=[64, 64], batch_size=8, num_epochs=100, - n_channels=4, + n_channels_in=4, ) assert config.algorithm_config.model.in_channels == 4 assert config.algorithm_config.model.num_classes == 4 +def test_n2n_channels_different(): + """Test that channels in and out can be different.""" + config = create_n2n_configuration( + experiment_name="test", + data_type="tiff", + axes="YXC", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + n_channels_in=4, + n_channels_out=5, + ) + assert config.algorithm_config.model.in_channels == 4 + assert config.algorithm_config.model.num_classes == 5 + + def test_care_configuration(): """Test that CARE configuration can be created.""" config = create_care_configuration(