-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathpredict-api-server.py
92 lines (68 loc) · 2.87 KB
/
predict-api-server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# run: gunicorn --bind=0.0.0.0:9696 --chdir=server --log-level=debug predict:app
# predict: curl -X POST -H 'Content-Type: application/json' localhost:9696/predict -d '{"customer_age":100,"gender":"F","dependent_count":2,"education_level":2,"marital_status":"married","income_category":2,"card_category":"blue","months_on_book":6,"total_relationship_count":3,"credit_limit":4000,"total_revolving_bal":2500}'
import os
from model_service import ModelService
from multiprocessing.pool import RUN
import logging
from datetime import datetime
from flask import Flask, request, jsonify
from pymongo import MongoClient
import requests
from model_loader import ModelLoader
# Service configuration, read from environment variables with local-development
# fallbacks.
# Base URL of the Evidently service that predictions are forwarded to.
# Note the environment variable is named EVIDENTLY_SERVICE (no _ADDRESS suffix).
EVIDENTLY_SERVICE_ADDRESS = os.getenv('EVIDENTLY_SERVICE', 'http://127.0.0.1:8085')
# MongoDB connection string, and the database/collection that prediction
# records are written to.
MONGODB_ADDRESS = os.getenv("MONGODB_ADDRESS", "mongodb://127.0.0.1:27018")
MONGODB_DB = os.getenv("MONGODB_DB", "prediction_service")
MONGODB_COLLECTION = os.getenv("MONGODB_COLLECTION", "data")
# Run identifier handed to the model loader and attached to every response and
# stored record.
# NOTE(review): when RUN_ID is unset the placeholder text below is still passed
# to the model loader; presumably loading fails in that case - confirm.
RUN_ID = os.getenv('RUN_ID', 'No RUN_ID provided')
# Logger named 'gunicorn.error'; its handlers are also attached to the Flask
# app logger further down.
log = logging.getLogger('gunicorn.error')
def get_mongo_collection(database_name, collection_name):
    """Return the collection ``collection_name`` from ``database_name``.

    A new ``MongoClient`` connected to ``MONGODB_ADDRESS`` is created on
    every call.
    """
    client = MongoClient(MONGODB_ADDRESS)
    return client.get_database(database_name).get_collection(collection_name)
# Module-level setup: everything below runs once, when this module is imported.
# Collection that save_to_db() inserts prediction records into.
mongo_collection = get_mongo_collection(MONGODB_DB, MONGODB_COLLECTION)
# Load the model and its companion object `dv` for RUN_ID and wrap both in the
# service used by the endpoint.
# NOTE(review): `dv` is presumably a fitted DictVectorizer - confirm in ModelLoader.
model, dv = ModelLoader().load_model_from_mlflow(RUN_ID)
model_service = ModelService(model, dv)
# NOTE(review): the app is named 'duration-prediction' although the endpoint
# returns a 'churn chance' - looks like a leftover name; confirm before renaming.
app = Flask('duration-prediction')
# Give Flask's application logger the same handlers as `log`.
app.logger.handlers = log.handlers
# POST payload example:
# {
# "customer_age": 100,
# "gender": "F",
# "dependent_count": 2,
# "education_level": 2,
# "marital_status": "married",
# "income_category": 2,
# "card_category": "blue",
# "months_on_book": 6,
# "total_relationship_count": 3,
# "credit_limit": 4000,
# "total_revolving_bal": 2500
# }
@app.route('/predict', methods=['POST'])
def predict_endpoint():
    """Score one record posted as JSON and return the prediction.

    The request body must be a JSON object. It is converted to features and
    scored through ``model_service``; the record and its prediction are then
    written to MongoDB and forwarded to the Evidently service before the
    response is returned.

    Returns:
        A JSON response with the keys ``'churn chance'`` and
        ``'model_run_id'``, or a 400 response with an ``'error'`` key when
        the body is not a JSON object.
    """
    log.info(f'Request payload: {request.get_data()}')
    payload = request.get_json()

    # save_to_db() and send_to_evidently_service() copy the record and add
    # keys to it, which only works for a JSON object. Reject anything else
    # up front instead of failing later with an unrelated error.
    if not isinstance(payload, dict):
        log.warning('Rejected request: body is not a JSON object')
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    features = model_service.prepare_features(payload)
    pred = model_service.predict(features)
    # Converted once and shared by both downstream calls.
    prediction = float(pred)

    result = {
        # Conversion kept exactly as before so the response value is unchanged.
        # NOTE(review): going through str() can yield a different float than
        # float(pred) for some numeric types - confirm which one is intended.
        'churn chance': float(str(pred)),
        'model_run_id': RUN_ID,
    }

    save_to_db(payload, prediction)
    send_to_evidently_service(payload, prediction)

    log.info(f'Response payload: {result}')
    return jsonify(result)
def save_to_db(record, prediction):
    """Insert ``record`` plus prediction metadata into ``mongo_collection``.

    Args:
        record: The request payload (a dict). It is copied, so the caller's
            object does not receive the extra keys.
        prediction: The model output to store under ``'prediction'``.

    The stored document also gets ``'created_at'`` (timezone-aware UTC
    timestamp) and ``'model_run_id'`` (``RUN_ID``).
    """
    # Local import: the file header only imports `datetime` from this module.
    from datetime import timezone

    rec = record.copy()
    rec['prediction'] = prediction
    # Use an aware UTC timestamp. PyMongo treats naive datetimes as UTC, so a
    # naive local-time value would be stored shifted by the host's UTC offset.
    rec['created_at'] = datetime.now(tz=timezone.utc)
    rec['model_run_id'] = RUN_ID
    mongo_collection.insert_one(rec)
def send_to_evidently_service(record, prediction, timeout=10):
    """POST ``record`` plus its prediction to the Evidently service.

    Args:
        record: The request payload (a dict). It is not modified.
        prediction: The model output, sent under the ``'prediction'`` key.
        timeout: Seconds to wait for the service before ``requests`` raises
            a timeout error. Without a timeout a stalled service would block
            the calling worker indefinitely.

    The body sent is a one-element JSON list containing the merged record.
    Errors raised by ``requests`` propagate to the caller.
    """
    rec = {**record, 'prediction': prediction}
    requests.post(
        f"{EVIDENTLY_SERVICE_ADDRESS}/iterate/capstone",
        json=[rec],
        timeout=timeout,
    )
if __name__ == "__main__":
    # Direct-run entry point using Flask's built-in server; the gunicorn
    # command in the header comment does not execute this branch.
    log.info('App starting ...')
    # The server listens on all interfaces, so the interactive debugger must
    # not be on by default: it allows arbitrary code execution from the
    # browser. Opt in explicitly with FLASK_DEBUG=1 when developing locally.
    debug_enabled = os.getenv('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(debug=debug_enabled, host='0.0.0.0', port=9696)