refac: Algorithm specific configurations #344

Merged
merged 17 commits on Jan 14, 2025
Changes from 9 commits
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -34,6 +34,7 @@ repos:
args: ["--config-file", "mypy.ini"]
additional_dependencies:
- numpy<2.0.0
- pydantic
- types-PyYAML
- types-setuptools

1 change: 1 addition & 0 deletions mypy.ini
@@ -1,5 +1,6 @@
[mypy]
ignore_missing_imports = True
plugins = pydantic.mypy

[mypy-careamics.lvae_training.*]
follow_imports = skip
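As context for the two additions above: registering `pydantic.mypy` (and adding `pydantic` to the pre-commit hook's dependencies) lets mypy type-check the `__init__` synthesized for pydantic models. A minimal, generic sketch of what the plugin catches — the `Algo` model below is illustrative, not from this repository:

```python
# Generic illustration: with `plugins = pydantic.mypy` in mypy.ini,
# mypy understands the __init__ synthesized by BaseModel.
from pydantic import BaseModel


class Algo(BaseModel):
    loss: str


Algo(loss=3)  # flagged by mypy: argument "loss" has incompatible type "int"
```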
15 changes: 13 additions & 2 deletions src/careamics/__init__.py
@@ -7,7 +7,18 @@
except PackageNotFoundError:
__version__ = "uninstalled"

__all__ = ["CAREamist", "Configuration", "load_configuration", "save_configuration"]
__all__ = [
"CAREamist",
"Configuration",
"ConfigurationFactory",
"load_configuration",
"save_configuration",
]

from .careamist import CAREamist
from .config import Configuration, load_configuration, save_configuration
from .config import (
Configuration,
ConfigurationFactory,
load_configuration,
save_configuration,
)
7 changes: 4 additions & 3 deletions src/careamics/careamist.py
@@ -13,7 +13,7 @@
)
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger

from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration
from careamics.config.support import (
SupportedAlgorithm,
SupportedArchitecture,
@@ -137,7 +137,7 @@ def __init__(
self.cfg = source

# instantiate model
if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
self.model = FCNModule(
algorithm_config=self.cfg.algorithm_config,
)
@@ -157,7 +157,8 @@ def __init__(
self.cfg = load_configuration(source)

# instantiate model
if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
# TODO call model factory here
if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
self.model = FCNModule(
algorithm_config=self.cfg.algorithm_config,
) # type: ignore
52 changes: 37 additions & 15 deletions src/careamics/config/__init__.py
@@ -1,41 +1,63 @@
"""Configuration module."""
"""CAREamics Pydantic configuration models.

To maintain clarity at the module level, we follow these naming conventions:
`*_model` is used for sub-configurations (e.g. architecture, data, algorithm),
while `*_configuration` is reserved for the main configuration models, including the
`Configuration` base class and its algorithm-specific child classes.
"""

__all__ = [
"AlgorithmFactory",
"CAREAlgorithm",
"CAREConfiguration",
"CheckpointModel",
"Configuration",
"CustomModel",
"ConfigurationFactory",
"DataConfig",
"FCNAlgorithmConfig",
"DataFactory",
"GaussianMixtureNMConfig",
"GeneralDataConfig",
"InferenceConfig",
"LVAELossConfig",
"MultiChannelNMConfig",
"N2NAlgorithm",
"N2NConfiguration",
"N2VAlgorithm",
"N2VConfiguration",
"N2VDataConfig",
"TrainingConfig",
"VAEAlgorithmConfig",
"clear_custom_models",
"UNetBasedAlgorithm",
"VAEBasedAlgorithm",
"create_care_configuration",
"create_n2n_configuration",
"create_n2v_configuration",
"load_configuration",
"register_model",
"save_configuration",
]
from .architectures import CustomModel, clear_custom_models, register_model

from .algorithms import (
CAREAlgorithm,
N2NAlgorithm,
N2VAlgorithm,
UNetBasedAlgorithm,
VAEBasedAlgorithm,
)
from .callback_model import CheckpointModel
from .care_configuration import CAREConfiguration
from .configuration import Configuration
from .configuration_factory import (
AlgorithmFactory,
ConfigurationFactory,
DataFactory,
create_care_configuration,
create_n2n_configuration,
create_n2v_configuration,
)
from .configuration_model import (
Configuration,
load_configuration,
save_configuration,
)
from .data_model import DataConfig
from .fcn_algorithm_model import FCNAlgorithmConfig
from .configuration_io import load_configuration, save_configuration
from .data import DataConfig, GeneralDataConfig, N2VDataConfig
from .inference_model import InferenceConfig
from .loss_model import LVAELossConfig
from .n2n_configuration import N2NConfiguration
from .n2v_configuration import N2VConfiguration
from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
from .training_model import TrainingConfig
from .vae_algorithm_model import VAEAlgorithmConfig
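As a usage note, the reorganized module exposes the new models at several levels. A hedged import sketch, with names taken only from the `__all__` lists in this PR:

```python
# Import sketch only; names come from the __all__ lists in this diff.
from careamics import Configuration, ConfigurationFactory
from careamics.config import CAREConfiguration, N2NConfiguration, N2VConfiguration
from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
```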
15 changes: 15 additions & 0 deletions src/careamics/config/algorithms/__init__.py
@@ -0,0 +1,15 @@
"""Algorithm configurations."""

__all__ = [
"CAREAlgorithm",
"N2NAlgorithm",
"N2VAlgorithm",
"UNetBasedAlgorithm",
"VAEBasedAlgorithm",
]

from .care_algorithm_model import CAREAlgorithm
from .n2n_algorithm_model import N2NAlgorithm
from .n2v_algorithm_model import N2VAlgorithm
from .unet_algorithm_model import UNetBasedAlgorithm
from .vae_algorithm_model import VAEBasedAlgorithm
50 changes: 50 additions & 0 deletions src/careamics/config/algorithms/care_algorithm_model.py
@@ -0,0 +1,50 @@
"""CARE algorithm configuration."""

from typing import Literal

from pydantic import field_validator

from careamics.config.architectures import UNetModel

from .unet_algorithm_model import UNetBasedAlgorithm


class CAREAlgorithm(UNetBasedAlgorithm):
"""CARE algorithm configuration.

Attributes
----------
algorithm : "care"
CARE Algorithm name.
loss : {"mae", "mse"}
CARE-compatible loss function.
"""

algorithm: Literal["care"] = "care"
"""CARE Algorithm name."""

loss: Literal["mae", "mse"] = "mae"
"""CARE-compatible loss function."""

@field_validator("model")
@classmethod
def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
"""Validate that the model does not have `n2v2` enabled.

Parameters
----------
value : UNetModel
Model to validate.

Returns
-------
UNetModel
The validated model.
"""
if value.n2v2:
raise ValueError(
"The N2N algorithm does not support the `n2v2` attribute. "
"Set it to `False`."
)

return value
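A short sketch of the validator in action, assuming `UNetModel` accepts `architecture` and an `n2v2` flag (as the check above implies); `algorithm` and `loss` fall back to their defaults:

```python
# Hedged sketch: the field validator above rejects n2v2-enabled UNets for CARE.
from pydantic import ValidationError

from careamics.config.algorithms import CAREAlgorithm

try:
    CAREAlgorithm(model={"architecture": "UNet", "n2v2": True})
except ValidationError as err:
    print(err)  # the n2v2 check above raises inside pydantic validation
```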
42 changes: 42 additions & 0 deletions src/careamics/config/algorithms/n2n_algorithm_model.py
@@ -0,0 +1,42 @@
"""N2N Algorithm configuration."""

from typing import Literal

from pydantic import field_validator

from careamics.config.architectures import UNetModel

from .unet_algorithm_model import UNetBasedAlgorithm


class N2NAlgorithm(UNetBasedAlgorithm):
"""N2N Algorithm configuration."""

algorithm: Literal["n2n"] = "n2n"
"""N2N Algorithm name."""

loss: Literal["mae", "mse"] = "mae"
"""N2N-compatible loss function."""

@field_validator("model")
@classmethod
def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
"""Validate that the model does not have `n2v2` enabled.

Parameters
----------
value : UNetModel
Model to validate.

Returns
-------
UNetModel
The validated model.
"""
if value.n2v2:
raise ValueError(
"The N2N algorithm does not support the `n2v2` attribute. "
"Set it to `False`."
)

return value
35 changes: 35 additions & 0 deletions src/careamics/config/algorithms/n2v_algorithm_model.py
@@ -0,0 +1,35 @@
""""N2V Algorithm configuration."""

from typing import Literal

from pydantic import model_validator
from typing_extensions import Self

from .unet_algorithm_model import UNetBasedAlgorithm


class N2VAlgorithm(UNetBasedAlgorithm):
"""N2V Algorithm configuration."""

algorithm: Literal["n2v"] = "n2v"
"""N2V Algorithm name."""

loss: Literal["n2v"] = "n2v"
"""N2V loss function."""

@model_validator(mode="after")
def algorithm_cross_validation(self: Self) -> Self:
"""Validate the algorithm model for N2V.

Returns
-------
Self
The validated model.
"""
if self.model.in_channels != self.model.num_classes:
raise ValueError(
"N2V requires the same number of input and output channels. Make "
"sure that `in_channels` and `num_classes` are the same."
)

return self
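A minimal sketch of the cross-validation above, assuming `UNetModel` exposes `in_channels` and `num_classes` (both are referenced by the validator):

```python
# Hedged sketch: N2V requires matching input and output channel counts.
from careamics.config.algorithms import N2VAlgorithm

# Matching channels validate; a mismatch would raise a pydantic ValidationError.
config = N2VAlgorithm(model={"architecture": "UNet", "in_channels": 2, "num_classes": 2})
```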
101 changes: 101 additions & 0 deletions src/careamics/config/algorithms/unet_algorithm_model.py
@@ -0,0 +1,101 @@
"""UNet-based algorithm Pydantic model."""

from pprint import pformat
from typing import Literal

from pydantic import BaseModel, ConfigDict

from careamics.config.architectures import UNetModel
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel


class UNetBasedAlgorithm(BaseModel):
"""UNet-based algorithm configuration.

This Pydantic model validates the parameters governing the components of the
training algorithm: which algorithm, loss function, model architecture, optimizer,
and learning rate scheduler to use.

This base class supports the N2V, CARE and N2N algorithms. The `n2v` algorithm
is only compatible with the `n2v` loss and the `UNet` architecture.

Attributes
----------
algorithm : {"n2v", "care", "n2n", "custom"}
Algorithm to use.
loss : {"n2v", "mae", "mse"}
Loss function to use.
model : UNetModel
Model architecture to use.
optimizer : OptimizerModel, optional
Optimizer to use.
lr_scheduler : LrSchedulerModel, optional
Learning rate scheduler to use.

Raises
------
ValueError
Algorithm parameter type validation errors.
ValueError
If the algorithm, loss and model are not compatible.

Examples
--------
Minimum example:
>>> from careamics.config import AlgorithmFactory
>>> config_dict = {
... "algorithm": "n2v",
... "loss": "n2v",
... "model": {
... "architecture": "UNet",
... }
... }
>>> config = AlgorithmFactory(algorithm=config_dict).algorithm
"""

# Pydantic class configuration
model_config = ConfigDict(
protected_namespaces=(), # allows to use model_* as a field name
validate_assignment=True,
extra="allow",
)

# Mandatory fields
algorithm: Literal["n2v", "care", "n2n"]
"""Algorithm name, as defined in SupportedAlgorithm."""

loss: Literal["n2v", "mae", "mse"]
"""Loss function to use, as defined in SupportedLoss."""

model: UNetModel
"""UNet model configuration."""

# Optional fields
optimizer: OptimizerModel = OptimizerModel()
"""Optimizer to use, defined in SupportedOptimizer."""

lr_scheduler: LrSchedulerModel = LrSchedulerModel()
"""Learning rate scheduler to use, defined in SupportedLrScheduler."""

def __str__(self) -> str:
"""Pretty string representing the configuration.

Returns
-------
str
Pretty string.
"""
return pformat(self.model_dump())

@classmethod
def get_compatible_algorithms(cls) -> list[str]:
"""Get the list of compatible algorithms.

Returns
-------
list of str
List of compatible algorithms.
"""
return ["n2v", "care", "n2n"]