-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
80 lines (65 loc) · 2.34 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import tensorflow as tf
import numpy as np
class Model:
    """Base class for neural-network models trained by minimising a PDE-based loss.

    Wraps three user-supplied pieces: a loss callable (``pde``), the network it
    is evaluated on (``net``), and the optimiser that updates the network's
    trainable variables during :meth:`train`.
    """

    def __init__(self,
                 pde: callable,
                 net: tf.keras.Model,
                 optimiser: tf.keras.optimizers.Optimizer,
                 ):
        """
        Base class for neural network models.

        Args:
            pde (callable): The loss function with partial differential equation to solve.
                It is called as ``pde(x, net)`` and its return value is used as the loss.
            net (tf.keras.Model): The neural network model.
            optimiser (tf.keras.optimizers.Optimizer): The optimiser to use for training.
        """
        self.pde = pde
        self.net = net
        self.optimiser = optimiser

    def __call__(self, x: tf.Tensor):
        """
        Call the neural network model.

        Args:
            x (tf.Tensor): The input tensor.

        Returns:
            The output of the neural network model, i.e. ``self.net(x)``.
        """
        return self.net(x)

    def train(self,
              x: tf.Tensor,
              num_epochs: int = 10000,
              verbose: bool = False,
              print_every: int = 1000,
              ):
        """
        Train the model with full-batch gradient steps on ``x``.

        Each epoch performs one optimiser step on the loss ``pde(x, net)``.

        Args:
            x (tf.Tensor): The training input tensor (the same tensor is used every epoch).
            num_epochs (int, optional): The number of epochs to train for. Defaults to 10000.
            verbose (bool, optional): Whether to print training progress every
                ``print_every`` epochs. Defaults to False. The loss of the final
                epoch is printed regardless of this flag.
            print_every (int, optional): Interval, in epochs, between progress
                messages when ``verbose`` is True. Defaults to 1000.

        Returns:
            losses (np.array): The training losses for each epoch, shape ``(num_epochs,)``.

        Raises:
            ValueError: If ``print_every`` is smaller than 1.
        """
        if print_every < 1:
            raise ValueError(f'print_every must be >= 1, got {print_every}')

        # A fresh tf.function is built per call to `train`, so it is traced once
        # per training run and then executed as a graph for every epoch.
        @tf.function
        def train_step(x):
            # Only one gradient() call is made, so a default (non-persistent) tape
            # is sufficient; a persistent tape would keep its recorded
            # intermediates alive until the tape object is garbage-collected.
            with tf.GradientTape() as tape:
                loss = self._loss(x, self.net)
            grads = tape.gradient(loss, self.net.trainable_variables)
            self.optimiser.apply_gradients(zip(grads, self.net.trainable_variables))
            return loss

        losses = np.zeros(num_epochs)
        for epoch in range(num_epochs):
            losses[epoch] = train_step(x)
            is_report_epoch = verbose and (epoch + 1) % print_every == 0
            is_last_epoch = epoch == num_epochs - 1
            # The final loss is always reported, even when verbose is False.
            if is_report_epoch or is_last_epoch:
                print(f'Epoch {epoch+1} -- Training loss: {losses[epoch]}')
        return losses

    def test(self, x: tf.Tensor):
        """
        Evaluate the model.

        Args:
            x (tf.Tensor): The input tensor.

        Returns:
            The loss for the input tensor, as returned by ``pde(x, net)``.
        """
        return self._loss(x, self.net)

    def _loss(self, x, nn):
        """Return the loss by delegating to the user-supplied ``pde(x, nn)``."""
        return self.pde(x, nn)