nn_solver.py
import os
import time
import logging

import torch
import yaml
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from sampler import Sampler
class BSDESolver:
    def __init__(self, equation, net, configs):
        '''
        equation: BSDE equation object
        net: neural network approximating the solution
        configs: path to a YAML file whose 'train' section holds the
                 training hyperparameters (batch_size, lr, epochs, optimizer)
        '''
        with open(configs, 'r') as f:
            config_data = yaml.safe_load(f)
        # expose every entry of the 'train' section as an attribute
        for key, value in config_data['train'].items():
            setattr(self, key, value)
        self.equation = equation
        self.net = net
        # sampler that draws fresh Monte Carlo paths for each training step
        self.train_iter = Sampler(self.equation, batch_size=self.batch_size, train=True)
        # optimizer (self.optimizer holds the config string until replaced here)
        if self.optimizer == 'adam':
            self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        elif self.optimizer == 'sgd':
            self.optimizer = torch.optim.SGD(self.net.parameters(), lr=self.lr)
        else:
            raise ValueError('Optimizer not supported! Please check the configs file.')
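
    # A minimal sketch of the expected YAML layout, inferred from the four
    # attributes read above (batch_size, lr, epochs, optimizer); the concrete
    # values are assumptions, not taken from this file:
    #
    #   train:
    #     batch_size: 64
    #     lr: 0.001
    #     epochs: 2000
    #     optimizer: adam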

    def loss(self, inputs):
        # L2 distance between the network prediction and the terminal
        # condition evaluated on the sampled paths
        y_pred = self.net(inputs)
        y_mc = self.equation.terminal_condition(inputs[1])
        loss = torch.mean((y_pred - y_mc) ** 2)
        return loss
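
    # In deep-BSDE style solvers the training signal is typically
    #   L(theta) = E[ |Y_T^theta - g(X_T)|^2 ],
    # where g is the terminal condition; the mean squared error above is a
    # direct Monte Carlo estimate of that expectation. (Reading inputs[1]
    # as the terminal state X_T is an assumption about the Sampler output.)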

    def train(self):
        timestamp = time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime())
        log_dir = f'logs/Experiment_{timestamp}'
        os.makedirs(log_dir, exist_ok=True)  # also creates 'logs/' if missing
        # tensorboard writer
        writer = SummaryWriter(log_dir)
        # log file named after the run, with the configs recorded up front
        logging.basicConfig(filename=f'{log_dir}/{timestamp}.log', level=logging.INFO)
        logging.info(f'Configs: {self.__dict__}')
        start_time = time.time()
        for epoch in tqdm(range(self.epochs), desc='Training'):
            self.optimizer.zero_grad()
            inputs = next(iter(self.train_iter))  # fresh batch of sampled paths
            loss = self.loss(inputs)
            loss.backward()
            self.optimizer.step()
            logging.info(f'Epoch {epoch+1}/{self.epochs}, Loss: {loss.item()}, Y0: {self.net.y_init.item()}')
            writer.add_scalar('Loss/train', loss.item(), epoch + 1)
            writer.add_scalar('Y0/train', self.net.y_init.item(), epoch + 1)
        end_time = time.time()
        logging.info(f'Training time: {end_time - start_time} seconds')
        writer.close()
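
# A hedged usage sketch (not part of the original file): SDE and Net4Y come
# from this repo's equation.py and network.py, but their constructor arguments
# are unknown here, so the '...' placeholders are assumptions. Only
# BSDESolver's own signature above is taken from this module.
#
#   from equation import SDE
#   from network import Net4Y
#
#   equation = SDE(...)    # some concrete BSDE from equation.py
#   net = Net4Y(...)       # network exposing a learnable y_init attribute
#   solver = BSDESolver(equation, net, 'configs.yaml')
#   solver.train()         # logs to logs/Experiment_<timestamp>/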