-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathunivariate_forecast.py
41 lines (28 loc) · 1.23 KB
/
univariate_forecast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from market import EquityData
from models.lstm import split, show_plot
import tensorflow as tf
# Univariate forecast script: fit a small LSTM on SPY data and plot a few
# validation predictions.

# Seed TensorFlow's global RNG before any random op (shuffle, weight init)
# is created.
tf.random.set_seed(42)

# Input-pipeline and training-loop settings.
BATCH_SIZE = 256
BUFFER_SIZE = 10000
EVALUATION_INTERVAL = 200  # passed to fit() as steps_per_epoch
EPOCHS = 10
VALIDATION_STEPS = 50      # passed to fit() as validation_steps

# Load the equity series and split it into train/validation arrays.
# NOTE(review): the structure of what split() returns is defined in
# models.lstm; only the four-way unpacking is visible here.
e = EquityData('data/SPY.csv', 'SPY')
x_train_uni, y_train_uni, x_val_uni, y_val_uni = split(e)

# Training pipeline: cache -> shuffle -> batch -> repeat. Both datasets repeat
# indefinitely, so fit() below is bounded by explicit step counts.
train_univariate = tf.data.Dataset.from_tensor_slices(
    (x_train_uni, y_train_uni)
)
train_univariate = train_univariate.cache()
train_univariate = train_univariate.shuffle(BUFFER_SIZE)
train_univariate = train_univariate.batch(BATCH_SIZE)
train_univariate = train_univariate.repeat()

# Validation pipeline: batch -> repeat, without shuffling.
val_univariate = tf.data.Dataset.from_tensor_slices((x_val_uni, y_val_uni))
val_univariate = val_univariate.batch(BATCH_SIZE)
val_univariate = val_univariate.repeat()

# Model: one 8-unit LSTM layer followed by a single-unit Dense output.
# The input shape is taken from the last two axes of the training inputs.
lstm_layer = tf.keras.layers.LSTM(8, input_shape=x_train_uni.shape[-2:])
output_layer = tf.keras.layers.Dense(1)
simple_lstm_model = tf.keras.models.Sequential([lstm_layer, output_layer])
simple_lstm_model.compile(optimizer='adam', loss='mae')

# Sanity check: print the prediction shape for one validation batch.
for sample_inputs, _ in val_univariate.take(1):
    print(simple_lstm_model.predict(sample_inputs).shape)

simple_lstm_model.fit(
    train_univariate,
    epochs=EPOCHS,
    steps_per_epoch=EVALUATION_INTERVAL,
    validation_data=val_univariate,
    validation_steps=VALIDATION_STEPS,
)

# For three validation batches, plot the first example's input, its target and
# the model's prediction for it.
# NOTE(review): the original indentation was lost; show() is assumed to be
# called once per iteration -- confirm.
for batch_inputs, batch_targets in val_univariate.take(3):
    first_input = batch_inputs[0].numpy()
    first_target = batch_targets[0].numpy()
    first_prediction = simple_lstm_model.predict(batch_inputs)[0]
    # The 0 and the title are passed through to show_plot unchanged; their
    # meaning is defined in models.lstm.
    plot = show_plot(
        [first_input, first_target, first_prediction],
        0,
        'Simple LSTM model',
    )
    plot.show()