Feature: Allow Noise2Noise with different in and out channels #152

Merged (6 commits), Jun 21, 2024
Changes from 3 commits
15 changes: 0 additions & 15 deletions src/careamics/config/algorithm_model.py
@@ -134,21 +134,6 @@ def algorithm_cross_validation(self: Self) -> Self:
"sure that `in_channels` and `num_classes` are the same."
)

-        # N2N
-        if self.algorithm == "n2n":
-            # n2n is only compatible with the UNet model
-            if not isinstance(self.model, UNetModel):
-                raise ValueError(
-                    f"Model for algorithm {self.algorithm} must be a `UNetModel`."
-                )
-
-            # n2n requires the number of input and output channels to be the same
-            if self.model.in_channels != self.model.num_classes:
-                raise ValueError(
-                    "N2N requires the same number of input and output channels. Make "
-                    "sure that `in_channels` and `num_classes` are the same."
-                )
-
if self.algorithm == "care" or self.algorithm == "n2n":
if self.loss == "n2v":
raise ValueError("Supervised algorithms do not support loss `n2v`.")
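With the N2N-specific channel check deleted above, an `n2n` algorithm configuration whose `in_channels` and `num_classes` differ is expected to validate; the equality constraint remains only for `n2v`. A minimal sketch, assuming `AlgorithmConfig` is importable from `careamics.config` (import path and the dict-based model spec are taken from the tests further down, not from this PR's description):

# Illustrative sketch only; import path assumed.
from careamics.config import AlgorithmConfig

model = {
    "architecture": "UNet",
    "in_channels": 1,   # one noisy input channel
    "num_classes": 2,   # two output channels
    "n2v2": False,
}

# Before this PR, algorithm="n2n" with mismatched channels raised a ValueError;
# after the removal above, only "n2v" still enforces in_channels == num_classes.
config = AlgorithmConfig(algorithm="n2n", loss="mae", model=model)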
25 changes: 17 additions & 8 deletions src/careamics/config/configuration_factory.py
@@ -243,20 +243,24 @@ def create_n2n_configuration(
use_augmentations: bool = True,
independent_channels: bool = False,
loss: Literal["mae", "mse"] = "mae",
-    n_channels: int = 1,
+    n_channels_in: int = 1,
+    n_channels_out: int = -1,
logger: Literal["wandb", "tensorboard", "none"] = "none",
model_kwargs: Optional[dict] = None,
) -> Configuration:
"""
-    Create a configuration for training Noise2Noise.
+    Create a configuration for training CARE.

If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
2.

If "C" is present in `axes`, then you need to set `n_channels` to the number of
If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
channels. Likewise, if you set the number of channels, then "C" must be present in
`axes`.

+    To set the number of output channels, use the `n_channels_out` parameter. If it is
+    not specified, it will be assumed to be equal to `n_channels_in`.
+
By default, all channels are trained together. To train all channels independently,
set `independent_channels` to True.

@@ -283,8 +287,10 @@
Whether to train all channels independently, by default False.
loss : Literal["mae", "mse"], optional
Loss function to use, by default "mae".
-    n_channels : int, optional
-        Number of channels (in and out), by default 1.
+    n_channels_in : int, optional
+        Number of channels in, by default 1.
+    n_channels_out : int, optional
+        Number of channels out, by default -1.
logger : Literal["wandb", "tensorboard", "none"], optional
Logger to use, by default "none".
model_kwargs : dict, optional
@@ -293,8 +299,11 @@
Returns
-------
Configuration
-        Configuration for training Noise2Noise.
+        Configuration for training CARE.
"""
+    if n_channels_out == -1:
+        n_channels_out = n_channels_in
+
return _create_supervised_configuration(
algorithm="n2n",
experiment_name=experiment_name,
@@ -306,8 +315,8 @@
use_augmentations=use_augmentations,
independent_channels=independent_channels,
loss=loss,
-        n_channels_in=n_channels,
-        n_channels_out=n_channels,
+        n_channels_in=n_channels_in,
+        n_channels_out=n_channels_out,
logger=logger,
model_kwargs=model_kwargs,
)
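For reference, a usage sketch of the updated factory signature; the argument values mirror the new tests further down, and the import path is assumed rather than taken from this diff:

# Sketch of the updated create_n2n_configuration call; import path assumed.
from careamics.config import create_n2n_configuration

# Different numbers of input and output channels are now allowed.
config = create_n2n_configuration(
    experiment_name="n2n_example",
    data_type="tiff",
    axes="YXC",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=100,
    n_channels_in=4,
    n_channels_out=5,
)

# Leaving n_channels_out at its default of -1 keeps it equal to n_channels_in.
config_same = create_n2n_configuration(
    experiment_name="n2n_example",
    data_type="tiff",
    axes="YXC",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=100,
    n_channels_in=4,
)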
9 changes: 4 additions & 5 deletions tests/config/test_algorithm_model.py
@@ -64,20 +64,19 @@ def test_algorithm_constraints(algorithm: str, loss: str, model: dict):
AlgorithmConfig(algorithm=algorithm, loss=loss, model=model)


-@pytest.mark.parametrize("algorithm", ["n2v", "n2n"])
-def test_n_channels_n2v_and_n2n(algorithm):
-    """Check that an error is raised if n2v and n2n have different number of channels in
+def test_n_channels_n2v():
+    """Check that an error is raised if n2v has different number of channels in
input and output."""
model = {
"architecture": "UNet",
"in_channels": 1,
"num_classes": 2,
"n2v2": False,
}
loss = "mae" if algorithm == "n2n" else "n2v"
loss = "n2v"

with pytest.raises(ValueError):
-        AlgorithmConfig(algorithm=algorithm, loss=loss, model=model)
+        AlgorithmConfig(algorithm="n2v", loss=loss, model=model)


@pytest.mark.parametrize(
26 changes: 21 additions & 5 deletions tests/config/test_configuration_factory.py
@@ -94,7 +94,7 @@ def test_n2n_channels_errors():
patch_size=[64, 64],
batch_size=8,
num_epochs=100,
-        n_channels=5,
+        n_channels_in=5,
)


@@ -122,27 +122,43 @@ def test_n2n_independent_channels(ind_channels):
patch_size=[64, 64],
batch_size=8,
num_epochs=100,
-        n_channels=4,
+        n_channels_in=4,
independent_channels=ind_channels,
)
assert config.algorithm_config.model.independent_channels == ind_channels


-def test_n2n_chanels_equal():
-    """Test that channels in and out are equal."""
+def test_n2n_channels_equal():
+    """Test that channels in and out are equal if only channels_in is set."""
config = create_n2n_configuration(
experiment_name="test",
data_type="tiff",
axes="YXC",
patch_size=[64, 64],
batch_size=8,
num_epochs=100,
-        n_channels=4,
+        n_channels_in=4,
)
assert config.algorithm_config.model.in_channels == 4
assert config.algorithm_config.model.num_classes == 4


+def test_n2n_channels_different():
+    """Test that channels in and out can be different."""
+    config = create_n2n_configuration(
+        experiment_name="test",
+        data_type="tiff",
+        axes="YXC",
+        patch_size=[64, 64],
+        batch_size=8,
+        num_epochs=100,
+        n_channels_in=4,
+        n_channels_out=5,
+    )
+    assert config.algorithm_config.model.in_channels == 4
+    assert config.algorithm_config.model.num_classes == 5


def test_care_configuration():
"""Test that CARE configuration can be created."""
config = create_care_configuration(