Minor changes in optim and networks (#652)
* correction of special cases in optimizer

* add Implemented Networks in init
thibaultdvx authored Sep 23, 2024
1 parent 574f7a0 commit 8e7bcd8
Showing 8 changed files with 264 additions and 149 deletions.
2 changes: 1 addition & 1 deletion clinicadl/monai_networks/__init__.py
@@ -1,2 +1,2 @@
from .config import create_network_config
from .config import ImplementedNetworks, NetworkConfig, create_network_config
from .factory import get_network
1 change: 1 addition & 0 deletions clinicadl/monai_networks/config/__init__.py
@@ -1,2 +1,3 @@
from .base import NetworkConfig
from .factory import create_network_config
from .utils.enum import ImplementedNetworks
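Together, the two __init__.py updates re-export ImplementedNetworks, NetworkConfig, and create_network_config alongside get_network. A minimal usage sketch of what the new imports enable, assuming ImplementedNetworks is a standard Enum (its utils.enum module name suggests so); the factory calls themselves are not shown in the diff and are left out:

from clinicadl.monai_networks import (
    ImplementedNetworks,
    NetworkConfig,
    create_network_config,
    get_network,
)

# List the supported network names (assumes ImplementedNetworks is an Enum).
print([member.value for member in ImplementedNetworks])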
5 changes: 5 additions & 0 deletions clinicadl/optim/lr_scheduler/factory.py
@@ -49,6 +49,11 @@ def get_lr_scheduler(
config_dict_["min_lr"].append(
config_dict["min_lr"]["ELSE"]
) # ELSE must be the last group
else:
default_min_lr = get_args_and_defaults(scheduler_class.__init__)[1][
"min_lr"
]
config_dict_["min_lr"].append(default_min_lr)
scheduler = scheduler_class(optimizer, **config_dict_)

updated_config = LRSchedulerConfig(scheduler=config.scheduler, **config_dict)
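The new else branch covers param groups that have neither an explicit "min_lr" entry nor an "ELSE" fallback: the default declared by the scheduler class itself is used instead (the diff suggests get_args_and_defaults returns the defaults as a dict in position 1). A rough sketch of that fallback using plain inspect, outside clinicadl:

import inspect

from torch.optim.lr_scheduler import ReduceLROnPlateau


def default_of(cls, arg_name):
    # Default value that cls.__init__ declares for arg_name.
    return inspect.signature(cls.__init__).parameters[arg_name].default


# ReduceLROnPlateau declares min_lr=0, which is why the updated test below
# expects min_lrs == [0.1, 0.01, 0.0] for the unconfigured third group.
print(default_of(ReduceLROnPlateau, "min_lr"))  # 0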
98 changes: 7 additions & 91 deletions clinicadl/optim/optimizer/factory.py
@@ -1,12 +1,12 @@
from typing import Any, Dict, Iterable, Iterator, List, Tuple
from typing import Any, Dict, Tuple

import torch
import torch.nn as nn
import torch.optim as optim

from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults

from .config import OptimizerConfig
from .utils import get_params_in_groups, get_params_not_in_groups


def get_optimizer(
@@ -45,23 +45,16 @@ def get_optimizer(
list_args_groups = network.parameters()
else:
list_args_groups = []
union_groups = set()
args_groups = sorted(args_groups.items()) # order in the list is important
for group, args in args_groups:
params, params_names = _get_params_in_group(network, group)
params, _ = get_params_in_groups(network, group)
args.update({"params": params})
list_args_groups.append(args)
union_groups.update(set(params_names))

other_params = _get_params_not_in_group(network, union_groups)
try:
next(other_params)
except StopIteration: # there is no other param in the network
pass
else:
other_params = _get_params_not_in_group(
network, union_groups
) # reset the generator
other_params, params_names = get_params_not_in_groups(
network, [group for group, _ in args_groups]
)
if len(params_names) > 0:
list_args_groups.append({"params": other_params})

optimizer = optimizer_class(list_args_groups, **args_global)
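The rewritten loop builds the standard PyTorch per-parameter-group structure: one dict per configured group carrying its specific arguments, plus a trailing catch-all group that is appended only when some parameters are actually left over. A hedged sketch of what list_args_groups ends up looking like, with made-up layer names and without clinicadl's helpers:

from collections import OrderedDict

import torch.nn as nn
import torch.optim as optim

net = nn.Sequential(
    OrderedDict(
        [
            ("conv1", nn.Conv2d(1, 1, kernel_size=3)),
            ("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
        ]
    )
)

# One dict per configured group, with its group-specific arguments...
list_args_groups = [{"params": net.final.dense1.parameters(), "lr": 1e-3}]
# ...and a catch-all group for the leftovers (here: conv1), appended only
# because there actually are leftover parameters.
list_args_groups.append({"params": net.conv1.parameters()})

optimizer = optim.SGD(list_args_groups, lr=1e-2)  # lr=1e-2 acts as the global default
print([len(g["params"]) for g in optimizer.param_groups])  # [2, 2]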
@@ -126,80 +119,3 @@ def _regroup_args(
args_global[arg] = value

return args_groups, args_global


def _get_params_in_group(
network: nn.Module, group: str
) -> Tuple[Iterator[torch.Tensor], List[str]]:
"""
Gets the parameters of a specific group of a neural network.
Parameters
----------
network : nn.Module
The neural network.
group : str
The name of the group, e.g. a layer or a block.
If it is a sub-block, the hierarchy should be
specified with "." (see examples).
Will work even if the group is reduced to a base layer
(e.g. group = "dense.weight" or "dense.bias").
Returns
-------
Iterator[torch.Tensor]
A generator that contains the parameters of the group.
List[str]
The name of all the parameters in the group.
Examples
--------
>>> net = nn.Sequential(
OrderedDict(
[
("conv1", nn.Conv2d(1, 1, kernel_size=3)),
("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
]
)
)
>>> generator, params_names = _get_params_in_group(network, "final.dense1")
>>> params_names
["final.dense1.weight", "final.dense1.bias"]
"""
group_hierarchy = group.split(".")
for name in group_hierarchy:
network = getattr(network, name)

try:
params = network.parameters()
params_names = [
".".join([group, name]) for name, _ in network.named_parameters()
]
except AttributeError: # we already reached params
params = (param for param in [network])
params_names = [group]

return params, params_names


def _get_params_not_in_group(
network: nn.Module, group: Iterable[str]
) -> Iterator[torch.Tensor]:
"""
Finds the parameters of a neural networks that
are not in a group.
Parameters
----------
network : nn.Module
The neural network.
group : List[str]
The group of parameters.
Returns
-------
Iterator[torch.Tensor]
A generator of all the parameters that are not in the input
group.
"""
return (param[1] for param in network.named_parameters() if param[0] not in group)
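For context on the try/next/StopIteration dance removed above: a generator cannot be tested for emptiness without consuming its first element, which is why the old code had to rebuild other_params afterwards. The new get_params_not_in_groups sidesteps this by also returning the parameter names, so emptiness becomes len(params_names) > 0. A minimal illustration of the pitfall:

def parameters():  # stand-in for a network's parameter generator
    yield "weight"
    yield "bias"


other_params = parameters()
try:
    next(other_params)  # emptiness check; this consumes the first element
except StopIteration:
    print("no leftover parameters")
else:
    other_params = parameters()  # must be rebuilt, exactly as the old code did
    print(list(other_params))  # ['weight', 'bias']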
120 changes: 120 additions & 0 deletions clinicadl/optim/optimizer/utils.py
@@ -0,0 +1,120 @@
from itertools import chain
from typing import Iterator, List, Tuple, Union

import torch
import torch.nn as nn


def get_params_in_groups(
network: nn.Module, groups: Union[str, List[str]]
) -> Tuple[Iterator[torch.Tensor], List[str]]:
"""
Gets the parameters of specific groups of a neural network.
Parameters
----------
network : nn.Module
The neural network.
groups : Union[str, List[str]]
The name of the group(s), e.g. a layer or a block.
If the user refers to a sub-block, the hierarchy should be
specified with "." (see examples).
If a list is passed, the function will output the parameters
of all groups mentioned together.
Returns
-------
Iterator[torch.Tensor]
An iterator that contains the parameters of the group(s).
List[str]
The name of all the parameters in the group(s).
Examples
--------
>>> net = nn.Sequential(
OrderedDict(
[
("conv1", nn.Conv2d(1, 1, kernel_size=3)),
("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
]
)
)
>>> params, params_names = get_params_in_groups(net, "final.dense1")
>>> params_names
["final.dense1.weight", "final.dense1.bias"]
>>> params, params_names = get_params_in_groups(net, ["conv1.weight", "final"])
>>> params_names
["conv1.weight", "final.dense1.weight", "final.dense1.bias"]
"""
if isinstance(groups, str):
groups = [groups]

params = iter(())
params_names = []
for group in groups:
network_ = network
group_hierarchy = group.split(".")
for name in group_hierarchy:
network_ = getattr(network_, name)

try:
params = chain(params, network_.parameters())
params_names += [
".".join([group, name]) for name, _ in network_.named_parameters()
]
except AttributeError: # we already reached params
params = chain(params, (param for param in [network_]))
params_names += [group]

return params, params_names


def get_params_not_in_groups(
network: nn.Module, groups: Union[str, List[str]]
) -> Tuple[Iterator[torch.Tensor], List[str]]:
"""
Gets the parameters not in specific groups of a neural network.
Parameters
----------
network : nn.Module
The neural network.
groups : Union[str, List[str]]
The name of the group(s), e.g. a layer or a block.
If the user refers to a sub-block, the hierarchy should be
specified with "." (see examples).
If a list is passed, the function will output the parameters
that are not in any group of that list.
Returns
-------
Iterator[torch.Tensor]
An iterator that contains the parameters not in the group(s).
List[str]
The name of all the parameters not in the group(s).
Examples
--------
>>> net = nn.Sequential(
OrderedDict(
[
("conv1", nn.Conv2d(1, 1, kernel_size=3)),
("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
]
)
)
>>> params, params_names = get_params_not_in_groups(net, "final")
>>> params_names
["conv1.weight", "conv1.bias"]
>>> params, params_names = get_params_not_in_groups(net, ["conv1.bias", "final"])
>>> params_names
["conv1.weight"]
"""
_, in_groups = get_params_in_groups(network, groups)
params = (
param[1] for param in network.named_parameters() if param[0] not in in_groups
)
params_names = list(
param[0] for param in network.named_parameters() if param[0] not in in_groups
)
return params, params_names
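A runnable usage sketch for the two new helpers, feeding their output straight into a PyTorch optimizer; the network mirrors the docstring examples and the import path follows the new module location:

from collections import OrderedDict

import torch.nn as nn
import torch.optim as optim

from clinicadl.optim.optimizer.utils import (
    get_params_in_groups,
    get_params_not_in_groups,
)

net = nn.Sequential(
    OrderedDict(
        [
            ("conv1", nn.Conv2d(1, 1, kernel_size=3)),
            ("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
        ]
    )
)

# Parameters of the "final" block get their own learning rate...
final_params, final_names = get_params_in_groups(net, "final")
# ...everything else falls into a default group.
rest_params, rest_names = get_params_not_in_groups(net, "final")

optimizer = optim.Adam(
    [{"params": final_params, "lr": 1e-3}, {"params": rest_params}],
    lr=1e-4,
)
print(final_names)  # ['final.dense1.weight', 'final.dense1.bias']
print(rest_names)   # ['conv1.weight', 'conv1.bias']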
12 changes: 9 additions & 3 deletions tests/unittests/optim/lr_scheduler/test_factory.py
@@ -17,6 +17,7 @@ def test_get_lr_scheduler():
[
("linear1", nn.Linear(4, 3)),
("linear2", nn.Linear(3, 2)),
("linear3", nn.Linear(2, 1)),
]
)
)
@@ -29,6 +30,9 @@
{
"params": network.linear2.parameters(),
},
{
"params": network.linear3.parameters(),
},
],
lr=10.0,
)
@@ -58,7 +62,7 @@ def test_get_lr_scheduler():
assert scheduler.threshold == 1e-1
assert scheduler.threshold_mode == "rel"
assert scheduler.cooldown == 3
assert scheduler.min_lrs == [0.1, 0.01]
assert scheduler.min_lrs == [0.1, 0.01, 0.0]
assert scheduler.eps == 1e-8

assert updated_config.scheduler == "ReduceLROnPlateau"
@@ -71,12 +75,14 @@
assert updated_config.min_lr == {"linear2": 0.01, "linear1": 0.1}
assert updated_config.eps == 1e-8

network.add_module("linear3", nn.Linear(3, 2))
optimizer.add_param_group({"params": network.linear3.parameters()})
config.min_lr = {"ELSE": 1, "linear2": 0.01, "linear1": 0.1}
scheduler, updated_config = get_lr_scheduler(optimizer, config)
assert scheduler.min_lrs == [0.1, 0.01, 1]

config.min_lr = 1
scheduler, updated_config = get_lr_scheduler(optimizer, config)
assert scheduler.min_lrs == [1.0, 1.0, 1.0]

config = LRSchedulerConfig()
scheduler, updated_config = get_lr_scheduler(optimizer, config)
assert isinstance(scheduler, LambdaLR)
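For reference, the per-group min_lr behaviour this test relies on comes straight from PyTorch: ReduceLROnPlateau accepts either a scalar or a list with one entry per param group and stores the result as min_lrs. A stripped-down sketch without clinicadl's config layer:

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

layers = [nn.Linear(4, 3), nn.Linear(3, 2), nn.Linear(2, 1)]
optimizer = optim.SGD([{"params": layer.parameters()} for layer in layers], lr=10.0)

# One min_lr per param group (the shape the factory produces)...
scheduler = ReduceLROnPlateau(optimizer, min_lr=[0.1, 0.01, 0.0])
print(scheduler.min_lrs)  # [0.1, 0.01, 0.0]

# ...or a single value broadcast to every group.
scheduler = ReduceLROnPlateau(optimizer, min_lr=1.0)
print(scheduler.min_lrs)  # [1.0, 1.0, 1.0]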