-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathtrain.py
116 lines (88 loc) · 4.38 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import math, os, torch, yaml
torch.multiprocessing.set_sharing_strategy('file_system')
import numpy as np
from rdkit import RDLogger
from utils.dataset import construct_loader
from utils.parsing import parse_train_args
from utils.training import train_epoch, test_epoch
from utils.utils import get_model, get_optimizer_and_scheduler, save_yaml_file
from utils.boltzmann import BoltzmannResampler
from argparse import Namespace
RDLogger.DisableLog('rdApp.*')
"""
Training procedures for both conformer generation and Botzmann generators
The hyperparameters are taken from utils/parsing.py and can be given as arguments
"""
def train(args, model, optimizer, scheduler, train_loader, val_loader):
best_val_loss = math.inf
best_epoch = 0
print("Starting training...")
for epoch in range(args.n_epochs):
train_loss, base_train_loss = train_epoch(model, train_loader, optimizer, device)
print("Epoch {}: Training Loss {} base loss {}".format(epoch, train_loss, base_train_loss))
val_loss, base_val_loss = test_epoch(model, val_loader, device)
print("Epoch {}: Validation Loss {} base loss {}".format(epoch, val_loss, base_val_loss))
if scheduler:
scheduler.step(val_loss)
if val_loss <= best_val_loss:
best_val_loss = val_loss
best_epoch = epoch
torch.save(model.state_dict(), os.path.join(args.log_dir, 'best_model.pt'))
torch.save({
'epoch': epoch,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
}, os.path.join(args.log_dir, 'last_model.pt'))
print("Best Validation Loss {} on Epoch {}".format(best_val_loss, best_epoch))
def boltzmann_train(args, model, optimizer, train_loader, val_loader, resampler):
print("Starting training...")
val_ess = val_loader.dataset.resample_all(resampler, temperature=args.temp)
print(f"Initial val ESS: Mean {np.mean(val_ess):.4f} Median {np.median(val_ess):.4f}")
best_val = val_ess
for epoch in range(args.n_epochs):
if args.adjust_temp:
train_loader.dataset.boltzmann_resampler.temp = (3000 - args.temp) / (epoch + 1) + args.temp
train_loss, base_train_loss = train_epoch(model, train_loader, optimizer, device)
print("Epoch {}: Training Loss {} base loss {}".format(epoch, train_loss, base_train_loss))
if epoch % 5 == 0:
val_ess = val_loader.dataset.resample_all(resampler, temperature=args.temp)
print(f"Epoch {epoch} val ESS: Mean {np.mean(val_ess):.4f} Median {np.median(val_ess):.4f}")
if best_val > val_ess:
best_val = val_ess
torch.save(model.state_dict(), os.path.join(args.log_dir, 'best_model.pt'))
torch.save({
'epoch': epoch,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
}, os.path.join(args.log_dir, 'last_model.pt'))
if __name__ == '__main__':
args = parse_train_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# build model
if args.restart_dir:
with open(f'{args.restart_dir}/model_parameters.yml') as f:
args_old = Namespace(**yaml.full_load(f))
model = get_model(args_old).to(device)
state_dict = torch.load(f'{args.restart_dir}/best_model.pt', map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=True)
else:
model = get_model(args).to(device)
numel = sum([p.numel() for p in model.parameters()])
# construct loader and set device
if args.boltzmann_training:
boltzmann_resampler = BoltzmannResampler(args, model)
else:
boltzmann_resampler = None
train_loader, val_loader = construct_loader(args, boltzmann_resampler=boltzmann_resampler)
# get optimizer and scheduler
optimizer, scheduler = get_optimizer_and_scheduler(args, model)
# record parameters
yaml_file_name = os.path.join(args.log_dir, 'model_parameters.yml')
save_yaml_file(yaml_file_name, args.__dict__)
args.device = device
if args.boltzmann_training:
boltzmann_train(args, model, optimizer, train_loader, val_loader, boltzmann_resampler)
else:
train(args, model, optimizer, scheduler, train_loader, val_loader)