Uniformize modules (#683)
* for losses, optimizers, metrics, early_stopping, lr_scheduler, networks
thibaultdvx authored Nov 13, 2024
1 parent ad16f83 commit 3cd23e4
Showing 128 changed files with 3,751 additions and 3,203 deletions.
4 changes: 2 additions & 2 deletions clinicadl/API_test.py
@@ -24,8 +24,8 @@
     create_network_config,
     get_network_from_config,
 )
-from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig
-from clinicadl.optimization.optimizer.factory import get_optimizer
+from clinicadl.optim.optimizers.config import AdamConfig, OptimizerConfig
+from clinicadl.optim.optimizers.factory import get_optimizer
 from clinicadl.predictor.predictor import Predictor
 from clinicadl.splitter.kfold import KFolder
 from clinicadl.splitter.split import get_single_split, split_tsv
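
The clinicadl.optimization.optimizer package is renamed to clinicadl.optim.optimizers, so callers only need to update import paths. A minimal migration sketch (the bare AdamConfig() instantiation and the get_optimizer call are assumptions shown for illustration; only the import paths come from this diff):

# New import paths introduced by this commit:
from clinicadl.optim.optimizers.config import AdamConfig, OptimizerConfig
from clinicadl.optim.optimizers.factory import get_optimizer

# Assumed usage: config fields default to library values, so a bare
# AdamConfig() should be valid; get_optimizer's exact signature is not
# shown in this diff, hence the commented call.
config: OptimizerConfig = AdamConfig()
# optimizer, config = get_optimizer(network, config)
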
6 changes: 3 additions & 3 deletions clinicadl/commandline/modules_options/early_stopping.py
@@ -2,7 +2,7 @@

 from clinicadl.config.config_utils import get_default_from_config_class as get_default
 from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.utils.early_stopping.config import EarlyStoppingConfig
+from clinicadl.optim.early_stopping import EarlyStoppingConfig

 # Early Stopping
 patience = click.option(
@@ -14,8 +14,8 @@
 )
 tolerance = click.option(
     "--tolerance",
-    type=get_type("tolerance", EarlyStoppingConfig),
-    default=get_default("tolerance", EarlyStoppingConfig),
+    type=get_type("min_delta", EarlyStoppingConfig),
+    default=get_default("min_delta", EarlyStoppingConfig),
     help="Value for early stopping tolerance.",
     show_default=True,
 )
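
The --tolerance flag keeps its CLI name but is now backed by the renamed config field min_delta. A hedged sketch of how the option resolves its type and default (the resulting type and default values are assumptions; only the field name min_delta and the imports appear in this diff):

from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.optim.early_stopping import EarlyStoppingConfig

# The click option reads its type/default from the pydantic model:
tolerance_type = get_type("min_delta", EarlyStoppingConfig)        # e.g. float (assumed)
tolerance_default = get_default("min_delta", EarlyStoppingConfig)  # e.g. 0.0 (assumed)
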
2 changes: 1 addition & 1 deletion clinicadl/commandline/modules_options/optimization.py
@@ -2,7 +2,7 @@

 from clinicadl.config.config_utils import get_default_from_config_class as get_default
 from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.optimization.config import OptimizationConfig
+from clinicadl.optim.config import OptimizationConfig

 # Optimization
 accumulation_steps = click.option(
2 changes: 1 addition & 1 deletion clinicadl/commandline/modules_options/optimizer.py
@@ -2,7 +2,7 @@

 from clinicadl.config.config_utils import get_default_from_config_class as get_default
 from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.optimization.optimizer import OptimizerConfig
+from clinicadl.optim.optimizers import OptimizerConfig

 # Optimizer
 learning_rate = click.option(
4 changes: 2 additions & 2 deletions clinicadl/losses/__init__.py
@@ -1,3 +1,3 @@
-from .config import create_loss_config
+from .config import LossConfig
 from .enum import ClassificationLoss, ImplementedLoss
-from .factory import get_loss_function
+from .factory import get_loss_function_config, get_loss_function_from_config
52 changes: 26 additions & 26 deletions clinicadl/losses/config.py
@@ -2,14 +2,13 @@
 from typing import Any, List, Optional, Type, Union

 from pydantic import (
-    BaseModel,
-    ConfigDict,
     NonNegativeFloat,
     PositiveFloat,
     computed_field,
     field_validator,
 )

+from clinicadl.utils.config import ClinicaDLConfig
 from clinicadl.utils.factories import DefaultFromLibrary

 from .enum import ImplementedLoss, Order, Reduction
@@ -26,37 +25,38 @@
     "SmoothL1LossConfig",
     "L1LossConfig",
     "MSELossConfig",
-    "create_loss_config",
+    "create_loss_function_config",
 ]


-class LossConfig(BaseModel, ABC):
+class LossConfig(ClinicaDLConfig, ABC):
     """Base config class for the loss function."""

     reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES
-    weight: Union[
-        Optional[List[NonNegativeFloat]], DefaultFromLibrary
-    ] = DefaultFromLibrary.YES
-    # pydantic config
-    model_config = ConfigDict(
-        validate_assignment=True, use_enum_values=True, validate_default=True
-    )

     @computed_field
     @property
     @abstractmethod
-    def loss(self) -> ImplementedLoss:
-        """The name of the loss."""
+    def name(self) -> ImplementedLoss:
+        """The name of the loss."""


+class _WeightConfig(ClinicaDLConfig):
+    """Base config class for 'weight' argument."""
+
+    weight: Union[
+        Optional[List[NonNegativeFloat]], DefaultFromLibrary
+    ] = DefaultFromLibrary.YES
+
+
-class NLLLossConfig(LossConfig):
+class NLLLossConfig(LossConfig, _WeightConfig):
     """Config class for Negative Log Likelihood loss."""

     ignore_index: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES

     @computed_field
     @property
-    def loss(self) -> ImplementedLoss:
+    def name(self) -> ImplementedLoss:
         """The name of the loss."""
         return ImplementedLoss.NLL

@@ -79,7 +79,7 @@ class CrossEntropyLossConfig(NLLLossConfig):

     @computed_field
     @property
-    def loss(self) -> ImplementedLoss:
+    def name(self) -> ImplementedLoss:
         """The name of the loss."""
         return ImplementedLoss.CROSS_ENTROPY

@@ -100,7 +100,7 @@ class BCELossConfig(LossConfig):

     @computed_field
     @property
-    def loss(self) -> ImplementedLoss:
+    def name(self) -> ImplementedLoss:
         """The name of the loss."""
         return ImplementedLoss.BCE

@@ -121,7 +121,7 @@ class BCEWithLogitsLossConfig(BCELossConfig):

     @computed_field
     @property
-    def loss(self) -> ImplementedLoss:
+    def name(self) -> ImplementedLoss:
         """The name of the loss."""
         return ImplementedLoss.BCE_LOGITS

@@ -144,15 +144,15 @@ def _recursive_float_check(cls, item):
         return (isinstance(item, float) or isinstance(item, int)) and item >= 0


-class MultiMarginLossConfig(LossConfig):
+class MultiMarginLossConfig(LossConfig, _WeightConfig):
     """Config class for Multi Margin loss."""

     p: Union[Order, DefaultFromLibrary] = DefaultFromLibrary.YES
     margin: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES

     @computed_field
     @property
-    def loss(self) -> ImplementedLoss:
+    def name(self) -> ImplementedLoss:
         """The name of the loss."""
         return ImplementedLoss.MULTI_MARGIN

@@ -164,7 +164,7 @@ class KLDivLossConfig(LossConfig):

     @computed_field
     @property
-    def loss(self) -> ImplementedLoss:
+    def name(self) -> ImplementedLoss:
         """The name of the loss."""
         return ImplementedLoss.KLDIV

@@ -176,7 +176,7 @@ class HuberLossConfig(LossConfig):

     @computed_field
     @property
-    def loss(self) -> ImplementedLoss:
+    def name(self) -> ImplementedLoss:
         """The name of the loss."""
         return ImplementedLoss.HUBER

@@ -188,7 +188,7 @@ class SmoothL1LossConfig(LossConfig):

     @computed_field
     @property
-    def loss(self) -> ImplementedLoss:
+    def name(self) -> ImplementedLoss:
         """The name of the loss."""
         return ImplementedLoss.SMOOTH_L1

@@ -198,7 +198,7 @@ class L1LossConfig(LossConfig):

     @computed_field
     @property
-    def loss(self) -> ImplementedLoss:
+    def name(self) -> ImplementedLoss:
         """The name of the loss."""
         return ImplementedLoss.L1

@@ -208,12 +208,12 @@ class MSELossConfig(LossConfig):

     @computed_field
     @property
-    def loss(self) -> ImplementedLoss:
+    def name(self) -> ImplementedLoss:
         """The name of the loss."""
         return ImplementedLoss.MSE


-def create_loss_config(
+def create_loss_function_config(
     loss: Union[str, ImplementedLoss],
 ) -> Type[LossConfig]:
     """
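
With the _WeightConfig mixin, only losses that actually accept a weight argument expose the field, and the renamed name computed field identifies the torch.nn class. A minimal sketch of the resulting API (the string comparisons assume ImplementedLoss is a str enum whose values match torch.nn class names, as implied by getattr(torch.nn, config.name) in the factory; the weight values are illustrative):

from clinicadl.losses.config import MSELossConfig, NLLLossConfig, create_loss_function_config

nll = NLLLossConfig(weight=[0.3, 0.7])  # `weight` comes from the _WeightConfig mixin
mse = MSELossConfig()                   # MSELossConfig no longer carries `weight`

assert "weight" not in MSELossConfig.model_fields
assert nll.name == "NLLLoss"            # computed field, serialized with the config (assumed value)

config_class = create_loss_function_config("MSELoss")  # returns the config *class*, per Type[LossConfig]
assert config_class is MSELossConfig
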
75 changes: 54 additions & 21 deletions clinicadl/losses/factory.py
@@ -1,44 +1,77 @@
 from copy import deepcopy
-from typing import Tuple
+from typing import Any, Tuple, Union

 import torch

-from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults
+from clinicadl.utils.factories import update_config_with_defaults

-from .config import LossConfig
+from .config import LossConfig, create_loss_function_config
+from .enum import ImplementedLoss
+from .utils import Loss


-def get_loss_function(config: LossConfig) -> Tuple[torch.nn.Module, LossConfig]:
+def get_loss_function_config(
+    name: Union[str, ImplementedLoss], **kwargs: Any
+) -> LossConfig:
     """
-    Factory function to get a loss function from PyTorch.
+    Factory function to get a loss function configuration object from its name
+    and parameters.
     Parameters
     ----------
-    loss : LossConfig
-        The config class with the parameters of the loss function.
+    name : Union[str, ImplementedLoss]
+        the name of the loss function. Check our documentation to know
+        available losses.
+    **kwargs : Any
+        any parameter of the loss function. Check our documentation on losses to
+        know these parameters.
     Returns
     -------
-    nn.Module
-        The loss function.
+    LossConfig
+        The updated config class: the arguments set to default will be updated
+        in the config object. Default values will be returned for the parameters
+        not passed by the user.
     """
+    config = create_loss_function_config(name)(**kwargs)
+    loss_class = getattr(torch.nn, config.name)
+
+    update_config_with_defaults(config, function=loss_class.__init__)
+
+    return config
+
+
+def get_loss_function_from_config(
+    config: LossConfig,
+) -> Tuple[Loss, LossConfig]:
+    """
+    Factory function to get a PyTorch loss function from a LossConfig instance.
+    Parameters
+    ----------
+    config : LossConfig
+        the configuration object.
+    Returns
+    -------
+    Loss
+        the loss function.
+    LossConfig
+        the updated config object: the arguments set to default will be updated
+        with their effective values (the default values from the library).
+        Useful for reproducibility.
+    """
-    loss_class = getattr(torch.nn, config.loss)
-    expected_args, config_dict = get_args_and_defaults(loss_class.__init__)
-    for arg, value in config.model_dump().items():
-        if arg in expected_args and value != DefaultFromLibrary.YES:
-            config_dict[arg] = value
+    config = deepcopy(config)
+    loss_class = getattr(torch.nn, config.name)
+
+    update_config_with_defaults(config, function=loss_class.__init__)
+    config_dict = config.model_dump(exclude={"name"})

-    config_dict_ = deepcopy(config_dict)
     # change list to tensors
     if "weight" in config_dict and config_dict["weight"] is not None:
-        config_dict_["weight"] = torch.Tensor(config_dict_["weight"])
+        config_dict["weight"] = torch.Tensor(config_dict["weight"])
     if "pos_weight" in config_dict and config_dict["pos_weight"] is not None:
-        config_dict_["pos_weight"] = torch.Tensor(config_dict_["pos_weight"])
-    loss = loss_class(**config_dict_)
+        config_dict["pos_weight"] = torch.Tensor(config_dict["pos_weight"])

-    updated_config = config.model_copy(update=config_dict)
+    loss = loss_class(**config_dict)

-    return loss, updated_config
+    return loss, config
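
A minimal sketch of the new two-step flow this factory introduces: build a fully-populated config by name, then instantiate the torch.nn loss from it. The two signatures are taken verbatim from the diff; the loss name string and parameter values are illustrative:

import torch

from clinicadl.losses import get_loss_function_config, get_loss_function_from_config

# Step 1: name + kwargs -> config, with unspecified parameters filled
# in from PyTorch's own defaults.
config = get_loss_function_config("CrossEntropyLoss", weight=[0.2, 0.8])

# Step 2: config -> torch.nn loss; the weight list is converted to a
# torch.Tensor here, and the updated config is returned for reproducibility.
loss_fn, config = get_loss_function_from_config(config)

value = loss_fn(torch.randn(8, 2), torch.randint(0, 2, (8,)))
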
9 changes: 2 additions & 7 deletions clinicadl/losses/utils.py
@@ -1,10 +1,5 @@
-from typing import Callable, Union
+from typing import Callable

 from torch import Tensor
-from torch.nn.modules.loss import _Loss

-Loss = Union[
-    Callable[[Tensor], Tensor],
-    Callable[[Tensor, Tensor], Tensor],
-    _Loss,
-]
+Loss = Callable[[Tensor, Tensor], Tensor]
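
The alias is narrowed from a three-way Union to the single two-argument form actually used for (input, target) losses. A small sketch of what still satisfies it (the custom function is a hypothetical example):

import torch
from torch import Tensor

from clinicadl.losses.utils import Loss

def zero_one_error(input: Tensor, target: Tensor) -> Tensor:
    # any plain (input, target) -> Tensor callable fits the alias
    return (input.round() != target).float().mean()

mse: Loss = torch.nn.MSELoss()  # nn loss modules still fit, via __call__
custom: Loss = zero_one_error
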
4 changes: 2 additions & 2 deletions clinicadl/metrics/__init__.py
@@ -1,2 +1,2 @@
-from .config import ImplementedMetrics, MetricConfig, create_metric_config
-from .factory import get_metric, loss_to_metric
+from .config import ImplementedMetric, MetricConfig
+from .factory import get_metric_config, get_metric_from_config
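
Hypothetical usage of the renamed metric factories, assuming they mirror the losses module's config/from-config split; the metric name and both signatures are assumptions, not shown in this diff:

from clinicadl.metrics import get_metric_config, get_metric_from_config

config = get_metric_config("MAE")                # assumed: build a config from a name
metric, config = get_metric_from_config(config)  # assumed: instantiate from the config
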
2 changes: 1 addition & 1 deletion clinicadl/metrics/config/__init__.py
@@ -1,3 +1,3 @@
 from .base import MetricConfig
-from .enum import ImplementedMetrics
+from .enum import ImplementedMetric
 from .factory import create_metric_config
(118 more changed files not shown)