-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp_api.py
68 lines (53 loc) · 2.24 KB
/
app_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import datetime
import re

import numpy as np
import torch
from flask import Flask, request, jsonify

from market_insight.models.lstm_model import LSTM
from market_insight.prediction.predictor import Predictor

# Flask application serving the /predict endpoint defined in this module.
# Per-symbol artifacts are loaded on demand by
# load_model_scaler_and_last_sequence() rather than once at import time.
app = Flask(__name__)


def load_model_scaler_and_last_sequence(symbol):
    """Load the trained LSTM, its scaler and the stored last sequence for *symbol*.

    Artifacts are read from ``model/{symbol}_lstm_model.pth``,
    ``model/{symbol}_scaler.pth`` and ``model/{symbol}_last_sequence.pth``
    (relative to the current working directory).

    Args:
        symbol: Ticker-like identifier used to build the artifact filenames.

    Returns:
        A ``(model, scaler, last_sequence)`` tuple: ``model`` is an ``LSTM``
        with the saved state dict loaded and switched to eval mode;
        ``scaler`` and ``last_sequence`` are the objects stored with
        ``torch.save``.

    Raises:
        ValueError: if *symbol* is not a short string of letters, digits and
            ``. _ ^ = -`` characters. This keeps a request-supplied symbol
            from injecting path separators into the filenames.
        FileNotFoundError: if any of the three artifact files is missing.
    """
    # The symbol is interpolated into file paths, so reject anything that
    # could contain a path separator (``/`` or ``\``) or is empty/oversized.
    if not isinstance(symbol, str) or re.fullmatch(r"[A-Za-z0-9._^=-]{1,32}", symbol) is None:
        raise ValueError(f"Invalid symbol: {symbol!r}")

    model_filename = f'model/{symbol}_lstm_model.pth'
    scaler_filename = f'model/{symbol}_scaler.pth'
    last_sequence_filename = f'model/{symbol}_last_sequence.pth'

    model = LSTM()
    # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
    # hosts; load_state_dict copies the tensors into the CPU-resident LSTM.
    model.load_state_dict(torch.load(model_filename, map_location="cpu"))
    model.eval()

    # NOTE(review): torch.load unpickles arbitrary Python objects. These files
    # must only ever come from a trusted training pipeline. On torch >= 2.6,
    # loading a non-tensor scaler may also require weights_only=False;
    # confirm against the deployed torch version.
    scaler = torch.load(scaler_filename)
    last_sequence = torch.load(last_sequence_filename, map_location="cpu")
    return model, scaler, last_sequence


@app.route('/predict', methods=['POST'])
def predict():
    """Return LSTM high/low price predictions for a symbol as JSON.

    Expects a JSON object body ``{"symbol": str, "days": int}``. ``days``
    defaults to 7 and must be an integer between 1 and 30 inclusive.

    Returns:
        200 with a list of ``{"date", "predicted_high", "predicted_low"}``
        objects, one per day, dated from today onwards.
        400 with ``{"error": ...}`` for a malformed request.
        404 with ``{"error": ...}`` when no trained artifacts exist for the symbol.
    """
    # silent=True: a missing or unparsable JSON body yields None instead of
    # raising, so it can be reported as a JSON 400 like other bad input.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    # The symbol ends up in file paths, so validate it at the API boundary.
    symbol = data.get('symbol')
    if not isinstance(symbol, str) or re.fullmatch(r"[A-Za-z0-9._^=-]{1,32}", symbol) is None:
        return jsonify({"error": "A valid 'symbol' string is required"}), 400

    days = data.get('days', 7)  # Default to 7 days if not specified
    # Validate before the (comparatively expensive) model load. bool is a
    # subclass of int, so exclude it explicitly.
    if isinstance(days, bool) or not isinstance(days, int) or not 1 <= days <= 30:
        return jsonify({"error": "Number of days must be between 1 and 30"}), 400

    try:
        model, scaler, last_sequence = load_model_scaler_and_last_sequence(symbol)
    except FileNotFoundError:
        return jsonify({"error": f"No trained model found for symbol '{symbol}'"}), 404

    # Seed the predictor with the last 60 stored steps.
    # NOTE(review): presumably this matches the training window length; confirm.
    start_sequence = last_sequence[-60:].tolist()

    predictor = Predictor(model)
    predicted_highs, predicted_lows = predictor.predict_next_days(start_sequence, days)

    # Map predictions back to the original price scale. The scaler is fed a
    # single column (reshape(-1, 1)) and the result is flattened again.
    # NOTE(review): assumes a single-feature sklearn-style scaler; confirm.
    predicted_highs_scaled = scaler.inverse_transform(np.array(predicted_highs).reshape(-1, 1)).reshape(-1)
    predicted_lows_scaled = scaler.inverse_transform(np.array(predicted_lows).reshape(-1, 1)).reshape(-1)

    today = datetime.date.today()
    # Cast to built-in float: NumPy scalars such as float32 are not JSON
    # serializable and would make jsonify raise.
    response = [
        {
            "date": (today + datetime.timedelta(days=i)).isoformat(),
            "predicted_high": float(predicted_highs_scaled[i]),
            "predicted_low": float(predicted_lows_scaled[i]),
        }
        for i in range(days)
    ]
    return jsonify(response)


if __name__ == "__main__":
    # Do not hard-code debug=True: the Werkzeug interactive debugger lets
    # anyone who can reach the server execute arbitrary code. For local
    # development, opt in explicitly with the FLASK_DEBUG=1 environment
    # variable, which app.run() honours.
    app.run()