"""Helper functions:
- get_optimizer: Optimizer
- get_scheduler: Learning rate scheduler
- compare_configs: Compare configurations
- plot_loss_accuracy_over_epochs: Plot training and validation loss and
accuracy over epochs
"""
from typing import Dict, Iterator, List
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.optim.lr_scheduler import _LRScheduler
def get_optimizer(
optimizer_config: Dict, parameters: Iterator[Parameter]
) -> optim.Optimizer:
"""Optimizer.
Args:
optimizer_config (Dict): Dictionary with optimizer configurations.
optimizer:
name (str): 'sgd', 'nesterov_sgd', 'rmsprop', 'adagrad', or
'adam'
learning_rate (float): Initial learning rate
momentum (float): Momentum factor
weight_decay (float): Weight decay (L2 penalty)
parameters (Iterator[Parameter]): Model parameters
Returns:
optim.Optimizer: Optimizer
"""
# Optimizer name, learning rate, momentum, and weight decay
optimizer_name = optimizer_config["name"]
lr = optimizer_config["learning_rate"]
momentum = optimizer_config["momentum"]
wd = optimizer_config["weight_decay"]
# Optimizer
if optimizer_name == "sgd":
return optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=wd)
if optimizer_name == "nesterov_sgd":
return optim.SGD(
parameters, lr=lr, momentum=momentum, weight_decay=wd, nesterov=True
)
if optimizer_name == "rmsprop":
return optim.RMSprop(parameters, lr=lr, momentum=momentum, weight_decay=wd)
if optimizer_name == "adagrad":
return optim.Adagrad(parameters, lr=lr, weight_decay=wd)
if optimizer_name == "adam":
        return optim.Adam(parameters, lr=lr, weight_decay=wd)
    raise ValueError(f"Unknown optimizer name: {optimizer_name}")
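

# Example usage of ``get_optimizer`` (illustrative sketch; the config values
# and ``model`` below are hypothetical and not defined in this module):
#
#   optimizer_config = {
#       "name": "sgd",
#       "learning_rate": 0.1,
#       "momentum": 0.9,
#       "weight_decay": 5e-4,
#   }
#   optimizer = get_optimizer(optimizer_config, model.parameters())

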
def get_scheduler(
scheduler_config: Dict, optimizer: optim.Optimizer, num_epochs: int
) -> _LRScheduler:
"""Learning rate scheduler.
Args:
scheduler_config (Dict): Dictionary with scheduler configurations.
scheduler:
name (str): 'constant', 'step', 'multistep', 'exponential', or
'cosine'
kwargs (Dict): Scheduler specific key word arguments
optimizer (optim.Optimizer): Optimizer. Ex. SGD
num_epochs (int): Number of epochs
Returns:
_LRScheduler: Learning rate scheduler for optimizer
"""
# Scheduler name and kwargs
scheduler_name = scheduler_config["name"]
kwargs = scheduler_config["kwargs"]
# Scheduler
if scheduler_name == "constant":
return optim.lr_scheduler.StepLR(optimizer, num_epochs, gamma=1, **kwargs)
if scheduler_name == "step":
return optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.1, **kwargs)
if scheduler_name == "multistep":
return optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[50, 120, 200], gamma=0.1
)
if scheduler_name == "exponential":
return optim.lr_scheduler.ExponentialLR(
optimizer, (1e-3) ** (1 / num_epochs), **kwargs
)
if scheduler_name == "cosine":
return optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=num_epochs, eta_min=0, last_epoch=-1
        )
    raise ValueError(f"Unknown scheduler name: {scheduler_name}")
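

# Example usage of ``get_scheduler`` (illustrative sketch; the config values,
# ``optimizer``, and the epoch count below are hypothetical):
#
#   scheduler_config = {"name": "cosine", "kwargs": {}}
#   scheduler = get_scheduler(scheduler_config, optimizer, num_epochs=200)
#
# Note that ``kwargs`` must be present in the config even when a branch does
# not use it, since it is read unconditionally above.

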
def compare_configs(config: Dict, config_checkpoint: Dict) -> bool:
"""Compare configurations.
Args:
config (Dict): Configuration for this run
config_checkpoint (Dict): Configuration for checkpoint run
Returns:
bool: Check if basic configurations are same with the checkpoint and can
resume training
"""
return (
config["model"] == config_checkpoint["model"]
and config["dataset"] == config_checkpoint["dataset"]
and config["training"]["optimizer"]["name"]
== config_checkpoint["training"]["optimizer"]["name"]
and config["training"]["scheduler"]
== config_checkpoint["training"]["scheduler"]
)
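

# Example usage of ``compare_configs`` (illustrative sketch; the config layout
# is inferred from the keys accessed above, and ``config_checkpoint`` is
# assumed to come from a saved checkpoint):
#
#   config = {
#       "model": "resnet18",
#       "dataset": "cifar10",
#       "training": {
#           "optimizer": {"name": "sgd"},
#           "scheduler": {"name": "cosine", "kwargs": {}},
#       },
#   }
#   can_resume = compare_configs(config, config_checkpoint)

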
def plot_loss_accuracy_over_epochs(
epochs: List[int],
train_loss: List[float],
train_acc: List[float],
val_loss: List[float],
val_acc: List[float],
fpath: str,
) -> None:
"""Plot training and validation loss and accuracy over epochs.
Args:
epochs (List[int]): Training epochs, from start_epoch to start_epoch +
num_epochs
train_loss (List[float]): Training losses over epochs
train_acc (List[float]): Training accuracies over epochs
val_loss (List[float]): Validation losses over epochs
val_acc (List[float]): Validation accuracies over epochs
fpath (str): Png file path to save the plot
"""
fig, (ax0, ax1) = plt.subplots(2, 1, sharex=True)
ax0.plot(epochs, train_loss, label="Train")
ax0.plot(epochs, val_loss, label="Validation")
ax0.grid(True)
ax0.set_ylabel("Loss")
ax1.plot(epochs, train_acc, label="Train")
ax1.plot(epochs, val_acc, label="Validation")
ax1.grid(True)
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Accuracy")
lines, labels = ax0.get_legend_handles_labels()
fig.legend(lines, labels, loc="upper right", bbox_to_anchor=(0.7, 0.45, 0.5, 0.5))
fig.tight_layout()
    plt.savefig(fpath, bbox_inches="tight")
    # Close the figure so repeated calls during training do not accumulate
    # open figures in memory
    plt.close(fig)
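

# Minimal self-contained sketch of the plotting helper. The synthetic loss and
# accuracy curves and the output file name below are illustrative only.
if __name__ == "__main__":
    demo_epochs = list(range(1, 11))
    demo_train_loss = [1.0 / e for e in demo_epochs]
    demo_val_loss = [1.2 / e for e in demo_epochs]
    demo_train_acc = [1.0 - loss for loss in demo_train_loss]
    demo_val_acc = [1.0 - loss for loss in demo_val_loss]
    plot_loss_accuracy_over_epochs(
        demo_epochs,
        demo_train_loss,
        demo_train_acc,
        demo_val_loss,
        demo_val_acc,
        fpath="loss_accuracy_demo.png",
    )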