Commit
* for losses, optimizers, metrics, early_stopping, lr_scheduler, networks
1 parent ad16f83, commit 3cd23e4
Showing 128 changed files with 3,751 additions and 3,203 deletions.
@@ -1,3 +1,3 @@
-from .config import create_loss_config
+from .config import LossConfig
 from .enum import ClassificationLoss, ImplementedLoss
-from .factory import get_loss_function
+from .factory import get_loss_function_config, get_loss_function_from_config
@@ -1,44 +1,77 @@
from copy import deepcopy
from typing import Tuple
from typing import Any, Tuple, Union

import torch

from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults
from clinicadl.utils.factories import update_config_with_defaults

from .config import LossConfig
from .config import LossConfig, create_loss_function_config
from .enum import ImplementedLoss
from .utils import Loss


def get_loss_function(config: LossConfig) -> Tuple[torch.nn.Module, LossConfig]:
def get_loss_function_config(
    name: Union[str, ImplementedLoss], **kwargs: Any
) -> LossConfig:
    """
    Factory function to get a loss function from PyTorch.
    Factory function to get a loss function configuration object from its name
    and parameters.
    Parameters
    ----------
    loss : LossConfig
        The config class with the parameters of the loss function.
    name : Union[str, ImplementedLoss]
        the name of the loss function. Check our documentation to know
        available losses.
    **kwargs : Any
        any parameter of the loss function. Check our documentation on losses to
        know these parameters.
    Returns
    -------
    nn.Module
        The loss function.
    LossConfig
        The updated config class: the arguments set to default will be updated
        the config object. Default values will be returned for the parameters
        not passed by the user.
    """
    config = create_loss_function_config(name)(**kwargs)
    loss_class = getattr(torch.nn, config.name)

    update_config_with_defaults(config, function=loss_class.__init__)

    return config


def get_loss_function_from_config(
    config: LossConfig,
) -> Tuple[Loss, LossConfig]:
    """
    Factory function to get a PyTorch loss function from a LossConfig instance.
    Parameters
    ----------
    config : LossConfig
        the configuration object.
    Returns
    -------
    Loss
        the loss function.
    LossConfig
        the updated config object: the arguments set to default will be updated
        with their effective values (the default values from the library).
        Useful for reproducibility.
    """
    loss_class = getattr(torch.nn, config.loss)
    expected_args, config_dict = get_args_and_defaults(loss_class.__init__)
    for arg, value in config.model_dump().items():
        if arg in expected_args and value != DefaultFromLibrary.YES:
            config_dict[arg] = value
    config = deepcopy(config)
    loss_class = getattr(torch.nn, config.name)

    update_config_with_defaults(config, function=loss_class.__init__)
    config_dict = config.model_dump(exclude={"name"})

    config_dict_ = deepcopy(config_dict)
    # change list to tensors
    if "weight" in config_dict and config_dict["weight"] is not None:
        config_dict_["weight"] = torch.Tensor(config_dict_["weight"])
        config_dict["weight"] = torch.Tensor(config_dict["weight"])
    if "pos_weight" in config_dict and config_dict["pos_weight"] is not None:
        config_dict_["pos_weight"] = torch.Tensor(config_dict_["pos_weight"])
    loss = loss_class(**config_dict_)
        config_dict["pos_weight"] = torch.Tensor(config_dict["pos_weight"])

    updated_config = config.model_copy(update=config_dict)
    loss = loss_class(**config_dict)

    return loss, updated_config
    return loss, config
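
Taken together, the two factories above split loss creation into two steps: build a LossConfig from a name plus keyword arguments, then instantiate the torch.nn loss from that config. A minimal usage sketch, assuming the functions are re-exported at the package level (the exact import path is not shown in this diff) and that "CrossEntropyLoss" is among the implemented losses:

import torch

# Hypothetical import path; the diff only shows the package-internal re-exports.
from clinicadl.losses import get_loss_function_config, get_loss_function_from_config

# Step 1: build a config object; unset arguments are filled with the library
# defaults by update_config_with_defaults, so the effective values are recorded.
config = get_loss_function_config("CrossEntropyLoss", weight=[1.0, 2.0])

# Step 2: instantiate the torch.nn loss; list-valued "weight"/"pos_weight"
# entries are converted to tensors before the class is called.
loss_fn, config = get_loss_function_from_config(config)

output = torch.randn(4, 2)           # raw scores for 2 classes
target = torch.randint(0, 2, (4,))   # class indices
print(loss_fn(output, target))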
@@ -1,10 +1,5 @@
-from typing import Callable, Union
+from typing import Callable
 
 from torch import Tensor
-from torch.nn.modules.loss import _Loss
 
-Loss = Union[
-    Callable[[Tensor], Tensor],
-    Callable[[Tensor, Tensor], Tensor],
-    _Loss,
-]
+Loss = Callable[[Tensor, Tensor], Tensor]
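
The Loss alias is thus narrowed to a plain two-argument callable returning a tensor, which still covers both torch.nn loss modules and ordinary functions. A small illustrative sketch (the alias is restated here only so the example is self-contained):

from typing import Callable

import torch
from torch import Tensor

Loss = Callable[[Tensor, Tensor], Tensor]  # as defined in the new utils module

def l1_loss(output: Tensor, target: Tensor) -> Tensor:
    # a plain function satisfies the alias...
    return (output - target).abs().mean()

module_loss: Loss = torch.nn.MSELoss()  # ...and so does a torch.nn loss module
function_loss: Loss = l1_loss

x, y = torch.randn(3), torch.randn(3)
print(module_loss(x, y), function_loss(x, y))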
@@ -1,2 +1,2 @@
-from .config import ImplementedMetrics, MetricConfig, create_metric_config
-from .factory import get_metric, loss_to_metric
+from .config import ImplementedMetric, MetricConfig
+from .factory import get_metric_config, get_metric_from_config
@@ -1,3 +1,3 @@
 from .base import MetricConfig
-from .enum import ImplementedMetrics
+from .enum import ImplementedMetric
 from .factory import create_metric_config