Skip to content

Commit

Permalink
latest
Browse files Browse the repository at this point in the history
  • Loading branch information
abarolo authored and abarolo committed Feb 13, 2024
1 parent a15f405 commit a3814b6
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 1 deletion.
6 changes: 6 additions & 0 deletions api/related_barriers/apps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from django.apps import AppConfig
from api.related_barriers import model


class RelatedBarriersConfig(AppConfig):
    """Django application config for the ``api.related_barriers`` app."""

    name = "api.related_barriers"

    def ready(self):
        """Build the related-barriers warehouse and register it as the module db."""
        # create_db() produces the warehouse; set_db() stores it in `model`.
        model.set_db(database=model.create_db())

86 changes: 86 additions & 0 deletions api/related_barriers/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import List, Dict, Optional

from django.db.models import CharField
from django.db.models import Value as V
from django.db.models.functions import Concat
from sentence_transformers import SentenceTransformer, util


class RelatedBarrierModelWarehouse:
    """Holds the sentence-embedding model and the tensors derived from it.

    Constructing an instance loads the ``paraphrase-MiniLM-L3-v2``
    SentenceTransformer, encodes the ``barrier_corpus`` text of every item in
    ``data`` and computes the pairwise cosine similarity of those embeddings.
    """

    # Class-level defaults; __init__ overwrites them on each instance.
    __model: "Optional[SentenceTransformer]" = None
    __embeddings = None
    __cosine_sim = None
    __redis = None

    def __init__(self, data: List[Dict]):
        """Load the model and precompute embeddings and similarities.

        Args:
            data: Iterable of mappings, each with a ``"barrier_corpus"`` key
                holding the text to embed.
        """
        model = SentenceTransformer("paraphrase-MiniLM-L3-v2")
        # Encode with the freshly loaded local model: the instance attribute
        # is still None at this point, so it cannot be used here.
        embeddings = model.encode(
            [d["barrier_corpus"] for d in data], convert_to_tensor=True
        )
        cosine_sim = util.cos_sim(embeddings, embeddings)

        self.__model = model
        self.__embeddings = embeddings
        self.__cosine_sim = cosine_sim
        self.__redis = None

    def get_model(self):
        """Return the loaded SentenceTransformer model."""
        return self.__model

    def get_embeddings(self):
        """Return the embeddings computed from the barrier corpora."""
        if self.__redis:
            # Placeholder: no Redis-backed lookup is implemented yet.
            pass
        return self.__embeddings

    def get_cosine_sim(self):
        """Return the pairwise cosine-similarity matrix of the embeddings."""
        if self.__redis:
            # Placeholder: no Redis-backed lookup is implemented yet.
            pass
        return self.__cosine_sim

    def update_model(self):
        """Placeholder for refreshing the model; currently performs no update.

        Design notes carried over from the original author: updates are meant
        to be queued to avoid race conditions. Only the latest update matters,
        so when an update task is already waiting there is no need to schedule
        another one; the waiting task will pick up the newest state.
        """
        if self.__redis and self.__redis.get_waiting_task_count("update_model"):
            # A waiting task already covers this change.
            return

        # TODO: schedule or run the actual update; not implemented yet.


# Module-level warehouse instance; starts as None and is assigned by set_db().
db: Optional[RelatedBarrierModelWarehouse] = None


def set_db(database: RelatedBarrierModelWarehouse):
    """Register ``database`` as the module-level ``db``.

    Raises:
        Exception: if a truthy ``db`` has already been registered.
    """
    global db

    already_registered = bool(db)
    if already_registered:
        raise Exception("DB already set, please stop db or restart application")

    db = database


def get_data() -> List[Dict]:
    """Return ``id`` and ``barrier_corpus`` for non-archived, non-draft barriers.

    ``barrier_corpus`` is the barrier title and summary joined with ". ".
    The result of ``.values(...)`` is returned as-is, without being
    converted to a list.
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid loading models at import time - TODO confirm).
    from api.barriers.models import Barrier

    corpus = Concat("title", V(". "), "summary", output_field=CharField())
    visible_barriers = Barrier.objects.filter(archived=False).exclude(draft=True)
    annotated = visible_barriers.annotate(barrier_corpus=corpus)
    return annotated.values("id", "barrier_corpus")


def create_db() -> RelatedBarrierModelWarehouse:
    """Build a warehouse from the rows returned by ``get_data()``."""
    return RelatedBarrierModelWarehouse(get_data())

24 changes: 23 additions & 1 deletion api/related_barriers/views.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
# from django.shortcuts import get_object_or_404
# from rest_framework.decorators import api_view
# from rest_framework.response import Response
#
# from api.barriers.models import Barrier
# from api.related_barriers.serializers import BarrierRelatedListSerializer
# from api.related_barriers.service import SimilarityScoreMatrix
#
#
# @api_view(["GET"])
# def related_barriers(request, pk) -> Response:
# """
# Return a list of related barriers
# """
# barrier_object = get_object_or_404(Barrier, pk=pk)
# similarity_score_matrix = SimilarityScoreMatrix.retrieve_matrix()
# barriers = similarity_score_matrix.retrieve_similar_barriers(barrier_object)
# serializer = BarrierRelatedListSerializer(barriers, many=True)
# return Response(serializer.data)

from django.shortcuts import get_object_or_404
from rest_framework.decorators import api_view
from rest_framework.response import Response

from api.barriers.models import Barrier
from api.related_barriers.serializers import BarrierRelatedListSerializer
from api.related_barriers.model import db
from api.related_barriers.service import SimilarityScoreMatrix

# Use celery Queue to manage race conditions when updating

@api_view(["GET"])
def related_barriers(request, pk) -> Response:
Expand All @@ -16,4 +38,4 @@ def related_barriers(request, pk) -> Response:
similarity_score_matrix = SimilarityScoreMatrix.retrieve_matrix()
barriers = similarity_score_matrix.retrieve_similar_barriers(barrier_object)
serializer = BarrierRelatedListSerializer(barriers, many=True)
return Response(serializer.data)
return Response(serializer.data)
25 changes: 25 additions & 0 deletions tests/related_barriers/test_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest

from tests.barriers.factories import BarrierFactory
from api.related_barriers import service


# Apply the django_db mark to every test in this module.
pytestmark = [pytest.mark.django_db]


def test_related_barriers():
    """Create three barriers, build the similarity matrix and print it."""
    for priority in ("LOW", "MEDIUM", "HIGH"):
        BarrierFactory(priority=priority)

    matrix = service.SimilarityScoreMatrix.create_matrix()

    print("Hello___1")
    print(matrix)

    # NOTE(review): this unconditional failure looks like a debugging
    # placeholder (it forces pytest to show the printed output); replace it
    # with real assertions on the matrix before merging.
    assert 0

0 comments on commit a3814b6

Please sign in to comment.