-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ml: create ml module; add scikit spam model
- Loading branch information
Showing
7 changed files
with
217 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2024 CERN. | ||
# | ||
# Zenodo-RDM is free software; you can redistribute it and/or modify | ||
# it under the terms of the MIT License; see LICENSE file for more details. | ||
"""Machine learning module.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2024 CERN. | ||
# | ||
# Zenodo-RDM is free software; you can redistribute it and/or modify | ||
# it under the terms of the MIT License; see LICENSE file for more details. | ||
"""Base class for ML models.""" | ||
|
||
|
||
class MLModel: | ||
"""Base class for ML models.""" | ||
|
||
def __init__(self, version=None, **kwargs): | ||
"""Constructor.""" | ||
self.version = version | ||
|
||
def process(self, data, preprocess=None, postprocess=None, raise_exc=True): | ||
"""Pipeline function to call pre/post process with predict.""" | ||
try: | ||
preprocessor = preprocess or self.preprocess | ||
postprocessor = postprocess or self.postprocess | ||
|
||
preprocessed = preprocessor(data) | ||
prediction = self.predict(preprocessed) | ||
return postprocessor(prediction) | ||
except Exception as e: | ||
if raise_exc: | ||
raise e | ||
return None | ||
|
||
def predict(self, data): | ||
"""Predict method to be implemented by subclass.""" | ||
raise NotImplementedError() | ||
|
||
def preprocess(self, data): | ||
"""Preprocess data.""" | ||
return data | ||
|
||
def postprocess(self, data): | ||
"""Postprocess data.""" | ||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2024 CERN. | ||
# | ||
# ZenodoRDM is free software; you can redistribute it and/or modify | ||
# it under the terms of the MIT License; see LICENSE file for more details. | ||
|
||
"""Machine learning config.""" | ||
|
||
from .models import SpamDetectorScikit | ||
|
||
ML_MODELS = { | ||
"spam_scikit": SpamDetectorScikit, | ||
} | ||
"""Machine learning models.""" | ||
|
||
# NOTE Model URL and model host need to be formattable strings for the model name. | ||
ML_KUBEFLOW_MODEL_URL = "CHANGE-{0}-ME" | ||
ML_KUBEFLOW_MODEL_HOST = "{0}-CHANGE" | ||
ML_KUBEFLOW_TOKEN = "CHANGE SECRET" | ||
"""Kubeflow connection config.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2024 CERN. | ||
# | ||
# ZenodoRDM is free software; you can redistribute it and/or modify | ||
# it under the terms of the MIT License; see LICENSE file for more details. | ||
|
||
"""ZenodoRDM machine learning module.""" | ||
|
||
from flask import current_app | ||
|
||
from . import config | ||
|
||
|
||
class ZenodoML: | ||
"""Zenodo machine learning extension.""" | ||
|
||
def __init__(self, app=None): | ||
"""Extension initialization.""" | ||
if app: | ||
self.init_app(app) | ||
|
||
@staticmethod | ||
def init_config(app): | ||
"""Initialize configuration.""" | ||
for k in dir(config): | ||
if k.startswith("ML_"): | ||
app.config.setdefault(k, getattr(config, k)) | ||
|
||
def init_app(self, app): | ||
"""Flask application initialization.""" | ||
self.init_config(app) | ||
app.extensions["zenodo-ml"] = self | ||
|
||
def _parse_model_name_version(self, model): | ||
"""Parse model name and version.""" | ||
vals = model.rsplit(":") | ||
version = vals[1] if len(vals) > 1 else None | ||
return vals[0], version | ||
|
||
def models(self, model, **kwargs): | ||
"""Return model based on model name.""" | ||
models = current_app.config.get("ML_MODELS", {}) | ||
model_name, version = self._parse_model_name_version(model) | ||
|
||
if model_name not in models: | ||
raise ValueError("Model not found/registered.") | ||
|
||
return models[model_name](version=version, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2024 CERN. | ||
# | ||
# ZenodoRDM is free software; you can redistribute it and/or modify | ||
# it under the terms of the MIT License; see LICENSE file for more details. | ||
"""Model definitions.""" | ||
|
||
|
||
import json | ||
import string | ||
|
||
import requests | ||
from bs4 import BeautifulSoup | ||
from flask import current_app | ||
|
||
from .base import MLModel | ||
|
||
|
||
class SpamDetectorScikit(MLModel): | ||
"""Spam detection model based on Sklearn.""" | ||
|
||
MODEL_NAME = "sklearn-spam" | ||
MAX_WORDS = 4000 | ||
|
||
def __init__(self, version, **kwargs): | ||
"""Constructor. Makes version required.""" | ||
super().__init__(version, **kwargs) | ||
|
||
def preprocess(self, data): | ||
"""Preprocess data. | ||
Parse HTML, remove punctuation and truncate to max chars. | ||
""" | ||
text = BeautifulSoup(data, "html.parser").get_text() | ||
trans_table = str.maketrans(string.punctuation, " " * len(string.punctuation)) | ||
parts = text.translate(trans_table).lower().strip().split(" ") | ||
if len(parts) >= self.MAX_WORDS: | ||
parts = parts[: self.MAX_WORDS] | ||
return " ".join(parts) | ||
|
||
def postprocess(self, data): | ||
"""Postprocess data. | ||
Gives spam and ham probability. | ||
""" | ||
result = { | ||
"spam": data["outputs"][0]["data"][0], | ||
"ham": data["outputs"][0]["data"][1], | ||
} | ||
return result | ||
|
||
def _send_request_kubeflow(self, data): | ||
"""Send predict request to Kubeflow.""" | ||
payload = { | ||
"inputs": [ | ||
{ | ||
"name": "input-0", | ||
"shape": [1], | ||
"datatype": "BYTES", | ||
"data": [f"{data}"], | ||
} | ||
] | ||
} | ||
model_ref = self.MODEL_NAME + "-" + self.version | ||
url = current_app.config.get("ML_KUBEFLOW_MODEL_URL").format(model_ref) | ||
host = current_app.config.get("ML_KUBEFLOW_MODEL_HOST").format(model_ref) | ||
access_token = current_app.config.get("ML_KUBEFLOW_TOKEN") | ||
r = requests.post( | ||
url, | ||
headers={ | ||
"Authorization": f"Bearer {access_token}", | ||
"Content-Type": "application/json", | ||
"Host": host, | ||
}, | ||
json=payload, | ||
) | ||
if r.status_code != 200: | ||
raise requests.RequestException("Prediction was not successful.", request=r) | ||
return json.loads(r.text) | ||
|
||
def predict(self, data): | ||
"""Get prediction from model.""" | ||
prediction = self._send_request_kubeflow(data) | ||
return prediction |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2024 CERN. | ||
# | ||
# ZenodoRDM is free software; you can redistribute it and/or modify | ||
# it under the terms of the MIT License; see LICENSE file for more details. | ||
"""Proxy objects for easier access to application objects.""" | ||
|
||
from flask import current_app | ||
from werkzeug.local import LocalProxy | ||
|
||
current_ml_models = LocalProxy(lambda: current_app.extensions["zenodo-ml"]) |