forked from seaniezhao/torch_npss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_harmonoc.py
32 lines (24 loc) · 1.02 KB
/
train_harmonoc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import hparams
from model.wavenet_model import *
from model.timbre_training import *
# Training entry point for the harmonic WaveNet model. Building the model and
# starting training happen only under the ``__main__`` guard, so importing this
# module has no side effects.


def exit_handler():
    """Save the current model and report that training was stopped from the keyboard.

    Reads the module-level ``trainer`` created by the ``__main__`` block below,
    so it is only meaningful after that block has constructed it.
    """
    trainer.save_model()
    print("exit from keyboard")


if __name__ == '__main__':
    # Use the GPU when one is available; otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = WaveNetModel(hparams.create_harmonic_hparams(), device).to(device)
    print('model: ', model)
    print('receptive field: ', model.receptive_field)
    print('parameter count: ', model.parameter_count())
    trainer = ModelTrainer(model=model,
                           data_folder='data/timbre_model',
                           lr=0.0005,
                           weight_decay=0.0001,
                           snapshot_path='./snapshots/harmonic',
                           snapshot_name='harm',
                           snapshot_interval=2000,
                           device=device)
    # NOTE(review): to resume from a saved snapshot, call
    #   epoch = trainer.load_checkpoint('<path to a snapshot file>')
    # here. Presumably this returns the epoch the snapshot was taken at;
    # confirm against ModelTrainer.load_checkpoint.
    print('start training...')
    try:
        trainer.train(batch_size=32,
                      epochs=1650)
    except KeyboardInterrupt:
        # exit_handler exists for exactly this case but was never wired up
        # (its atexit registration was commented out), so Ctrl-C previously
        # exited without saving the model.
        exit_handler()