Skip to content

Commit

Permalink
ml: create ml module; add scikit spam model
Browse files Browse the repository at this point in the history
  • Loading branch information
yashlamba committed Nov 14, 2024
1 parent a470c63 commit 6b2e150
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 0 deletions.
2 changes: 2 additions & 0 deletions site/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ invenio_base.apps =
zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration
invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE
zenodo_rdm_stats = zenodo_rdm.stats.ext:ZenodoStats
zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML
invenio_base.api_apps =
zenodo_rdm_legacy = zenodo_rdm.legacy.ext:ZenodoLegacy
profiler = zenodo_rdm.profiler:Profiler
zenodo_rdm_metrics = zenodo_rdm.metrics.ext:ZenodoMetrics
zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration
invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE
zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML
invenio_base.api_blueprints =
zenodo_rdm_legacy = zenodo_rdm.legacy.views:blueprint
zenodo_rdm_legacy_records = zenodo_rdm.legacy.views:create_legacy_records_bp
Expand Down
7 changes: 7 additions & 0 deletions site/zenodo_rdm/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Zenodo-RDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Machine learning module."""
41 changes: 41 additions & 0 deletions site/zenodo_rdm/ml/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Zenodo-RDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Base class for ML models."""


class MLModel:
    """Base class for ML models.

    Subclasses implement :meth:`predict` and may override :meth:`preprocess`
    and :meth:`postprocess`. :meth:`process` chains the three steps.
    """

    def __init__(self, version=None, **kwargs):
        """Constructor.

        :param version: Model version, stored as ``self.version``.
        :param kwargs: Accepted so subclasses can forward extra arguments;
            not used by the base class.
        """
        self.version = version

    def process(self, data, preprocess=None, postprocess=None, raise_exc=True):
        """Run the preprocess -> predict -> postprocess pipeline.

        :param data: Raw input handed to the preprocessor.
        :param preprocess: Optional callable used instead of
            ``self.preprocess``.
        :param postprocess: Optional callable used instead of
            ``self.postprocess``.
        :param raise_exc: When true (default), any exception raised by a
            pipeline step propagates to the caller. When false, the exception
            is swallowed and ``None`` is returned.
        :returns: The postprocessed prediction, or ``None`` on failure when
            ``raise_exc`` is false.
        """
        try:
            preprocessor = preprocess or self.preprocess
            postprocessor = postprocess or self.postprocess

            preprocessed = preprocessor(data)
            prediction = self.predict(preprocessed)
            return postprocessor(prediction)
        except Exception:
            if raise_exc:
                # Bare ``raise`` re-raises the active exception unchanged.
                raise
            # Deliberate best-effort mode: callers opted out of exceptions.
            return None

    def predict(self, data):
        """Predict method to be implemented by subclass.

        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def preprocess(self, data):
        """Preprocess data. The base implementation returns it unchanged."""
        return data

    def postprocess(self, data):
        """Postprocess data. The base implementation returns it unchanged."""
        return data
21 changes: 21 additions & 0 deletions site/zenodo_rdm/ml/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# ZenodoRDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.

"""Machine learning config."""

from .models import SpamDetectorScikit

# Registry mapping a model name to the class implementing that model.
ML_MODELS = {
    "spam_scikit": SpamDetectorScikit,
}
"""Machine learning models."""

# NOTE Model URL and model host need to be formattable strings for the model name.
# ``{0}`` is the positional ``str.format`` placeholder. The values below look
# like placeholders that are presumably overridden per deployment - confirm.
ML_KUBEFLOW_MODEL_URL = "CHANGE-{0}-ME"
ML_KUBEFLOW_MODEL_HOST = "{0}-CHANGE"
ML_KUBEFLOW_TOKEN = "CHANGE SECRET"
"""Kubeflow connection config."""
49 changes: 49 additions & 0 deletions site/zenodo_rdm/ml/ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# ZenodoRDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.

"""ZenodoRDM machine learning module."""

from flask import current_app

from . import config


class ZenodoML:
    """Zenodo machine learning extension."""

    def __init__(self, app=None):
        """Extension initialization.

        :param app: Optional Flask application. When given, the extension is
            initialized on it immediately via :meth:`init_app`.
        """
        if app:
            self.init_app(app)

    @staticmethod
    def init_config(app):
        """Initialize configuration.

        Copies every ``ML_``-prefixed attribute of the ``config`` module into
        ``app.config`` unless the key is already set there.
        """
        for k in dir(config):
            if k.startswith("ML_"):
                app.config.setdefault(k, getattr(config, k))

    def init_app(self, app):
        """Flask application initialization.

        Loads the default configuration and registers this instance under
        ``app.extensions["zenodo-ml"]``.
        """
        self.init_config(app)
        app.extensions["zenodo-ml"] = self

    def _parse_model_name_version(self, model):
        """Parse model name and version.

        The reference is split on its *last* colon, so ``"name:version"``
        yields ``("name", "version")`` and a reference without a colon yields
        ``(model, None)``.

        :param model: Model reference string, ``"name"`` or ``"name:version"``.
        :returns: Tuple ``(name, version)``; ``version`` is ``None`` when
            absent.
        """
        # ``maxsplit=1`` is required for ``rsplit`` to split from the right;
        # without it everything after a second colon was silently discarded.
        name, *rest = model.rsplit(":", 1)
        version = rest[0] if rest else None
        return name, version

    def models(self, model, **kwargs):
        """Return model based on model name.

        :param model: Model reference, ``"name"`` or ``"name:version"``.
        :param kwargs: Extra keyword arguments forwarded to the model class.
        :returns: A new instance of the class registered under the name in the
            ``ML_MODELS`` config.
        :raises ValueError: If the name is not registered in ``ML_MODELS``.
        """
        models = current_app.config.get("ML_MODELS", {})
        model_name, version = self._parse_model_name_version(model)

        if model_name not in models:
            raise ValueError("Model not found/registered.")

        return models[model_name](version=version, **kwargs)
85 changes: 85 additions & 0 deletions site/zenodo_rdm/ml/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# ZenodoRDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Model definitions."""


import json
import string

import requests
from bs4 import BeautifulSoup
from flask import current_app

from .base import MLModel


class SpamDetectorScikit(MLModel):
    """Spam detection model based on Sklearn."""

    MODEL_NAME = "sklearn-spam"
    MAX_WORDS = 4000
    # Seconds to wait for the prediction service before giving up, so that a
    # stalled service cannot block the caller indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, version, **kwargs):
        """Constructor. Makes version required.

        :param version: Model version; appended to ``MODEL_NAME`` to build the
            model reference used in the request URL and host.
        :param kwargs: Forwarded to the base class constructor.
        """
        super().__init__(version, **kwargs)

    def preprocess(self, data):
        """Preprocess data.

        Parse HTML, replace punctuation with spaces, lowercase, and truncate
        to at most ``MAX_WORDS`` space-separated parts.

        :param data: HTML or plain text to classify.
        :returns: The cleaned text as a single string.
        """
        text = BeautifulSoup(data, "html.parser").get_text()
        trans_table = str.maketrans(string.punctuation, " " * len(string.punctuation))
        # Splitting on a single space keeps empty parts for consecutive
        # spaces; those count toward MAX_WORDS and are restored by the join.
        parts = text.translate(trans_table).lower().strip().split(" ")
        # Slicing is a no-op when there are fewer than MAX_WORDS parts.
        return " ".join(parts[: self.MAX_WORDS])

    def postprocess(self, data):
        """Postprocess data.

        Gives spam and ham probability.

        :param data: Decoded prediction response; the first two entries of
            ``data["outputs"][0]["data"]`` are read.
        :returns: Dict with keys ``"spam"`` and ``"ham"``.
        """
        scores = data["outputs"][0]["data"]
        return {
            "spam": scores[0],
            "ham": scores[1],
        }

    def _send_request_kubeflow(self, data):
        """Send predict request to Kubeflow.

        :param data: Preprocessed text to classify.
        :returns: The decoded JSON body of the response.
        :raises requests.RequestException: If the request fails, times out, or
            the response status is not 200.
        """
        payload = {
            "inputs": [
                {
                    "name": "input-0",
                    "shape": [1],
                    "datatype": "BYTES",
                    "data": [f"{data}"],
                }
            ]
        }
        model_ref = self.MODEL_NAME + "-" + self.version
        url = current_app.config.get("ML_KUBEFLOW_MODEL_URL").format(model_ref)
        host = current_app.config.get("ML_KUBEFLOW_MODEL_HOST").format(model_ref)
        access_token = current_app.config.get("ML_KUBEFLOW_TOKEN")
        r = requests.post(
            url,
            headers={
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
                "Host": host,
            },
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        if r.status_code != 200:
            # Attach the response (not as ``request=``) so handlers can
            # inspect the status and body via ``exc.response``.
            raise requests.RequestException(
                "Prediction was not successful.", response=r
            )
        return json.loads(r.text)

    def predict(self, data):
        """Get prediction from model.

        :param data: Preprocessed text.
        :returns: The decoded response of the prediction service.
        """
        return self._send_request_kubeflow(data)
12 changes: 12 additions & 0 deletions site/zenodo_rdm/ml/proxies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# ZenodoRDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Proxy objects for easier access to application objects."""

from flask import current_app
from werkzeug.local import LocalProxy

# Proxy to the object stored under the "zenodo-ml" key of the current Flask
# application's ``extensions``; the lookup is deferred to the lambda.
current_ml_models = LocalProxy(lambda: current_app.extensions["zenodo-ml"])

0 comments on commit 6b2e150

Please sign in to comment.