Skip to content

Commit

Permalink
cache
Browse files Browse the repository at this point in the history
  • Loading branch information
abarolo authored and abarolo committed Feb 23, 2024
1 parent a3814b6 commit 52ec06c
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 71 deletions.
11 changes: 7 additions & 4 deletions api/barriers/signals/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from api.metadata.constants import TOP_PRIORITY_BARRIER_STATUS
from api.related_barriers import service as related_barrier_service
from api.related_barriers import model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -352,7 +353,7 @@ def barrier_completion_top_priority_barrier_resolved(


@receiver(pre_save, sender=Barrier)
def related_barrier_update_embeddings(sender, instance, *args, **kwargs):
    """Pre-save signal: refresh the related-barriers embedding for a barrier
    whose relevant text fields changed.

    Compares the incoming ``instance`` against the row currently stored in the
    database; brand-new barriers (no stored row) and drafts are skipped because
    they are not yet part of the related-barriers corpus.
    """
    try:
        current_barrier_object = sender.objects.get(pk=instance.pk)
    except sender.DoesNotExist:
        # New barrier: nothing stored yet, so there is no embedding to update.
        return

    changed = any(
        getattr(current_barrier_object, field) != getattr(instance, field)
        for field in related_barrier_service.RELEVANT_BARRIER_FIELDS
    )
    if changed and not current_barrier_object.draft:
        # model.db is the module-level warehouse singleton; guard against it
        # not having been initialised so a plain save never crashes here.
        if model.db is None:
            logger.warning(
                "Related-barrier warehouse not initialised; skipping embedding update"
            )
            return
        model.db.update_barrier(
            {
                'id': str(current_barrier_object.id),
                'barrier_corpus': model.barrier_to_corpus(current_barrier_object),
            }
        )


def barrier_changed_after_published(sender, instance, **kwargs):
Expand Down
164 changes: 131 additions & 33 deletions api/related_barriers/model.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,133 @@
import time
from functools import wraps
from typing import List, Dict, Optional
import logging

import numpy
import pandas
from django.core.cache import cache
from django.db.models import CharField
from django.db.models import Value as V
from django.db.models.functions import Concat
from sentence_transformers import SentenceTransformer, util


logger = logging.getLogger(__name__)


def timing(f):
    """Decorator that logs the wall-clock duration of each call to *f*.

    Uses ``time.perf_counter`` (monotonic, high resolution) and logs at INFO
    with lazy %-style arguments, so the message is only formatted when the
    logger is actually enabled — the original f-string paid that cost on
    every call.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = f(*args, **kwargs)
        total_time = time.perf_counter() - start
        logger.info(
            "(%s): Function %s%s %s Took %.4f seconds",
            __name__, f.__name__, args, kwargs, total_time,
        )
        return result
    return wrapper


@timing
def load_transformer():
    """Load and return the sentence-transformer used to embed barrier text."""
    transformer = SentenceTransformer("paraphrase-MiniLM-L3-v2")
    return transformer


SIMILARITY_THRESHOLD = 0.19
SIMILAR_BARRIERS_LIMIT = 5


EMBEDDINGS_CACHE_KEY = 'EMBEDDINGS_CACHE_KEY'
BARRIER_IDS_CACHE_KEY = 'BARRIER_IDS_CACHE_KEY'


class RelatedBarrierModelWarehouse:
__model: Optional[SentenceTransformer] = None
__embeddings = None
__cosine_sim = None
__redis = None

def __init__(self, data: List[Dict]):
model = SentenceTransformer("paraphrase-MiniLM-L3-v2")
embeddings = self.__model.model.encode(
[d["barrier_corpus"] for d in data], convert_to_tensor=True
)
cosine_sim = util.cos_sim(embeddings, embeddings)
self.__model = load_transformer()

self.__model = model
self.__embeddings = embeddings
self.__cosine_sim = cosine_sim
self.__redis = None
@timing
def set_data():
barrier_ids = [str(d['id']) for d in data]
barrier_data = [d['barrier_corpus'] for d in data]
embeddings = self.__model.encode(barrier_data, convert_to_tensor=True)

def get_model(self):
return self.__model
self.set_embeddings(embeddings.numpy())
self.set_barrier_ids(barrier_ids)

if not self.get_barrier_ids() or not isinstance(self.get_embeddings(), numpy.ndarray):
set_data()

    @staticmethod
    def set_embeddings(embeddings):
        # Persist the embedding matrix in the shared Django cache;
        # timeout=None means the entry never expires.
        cache.set(EMBEDDINGS_CACHE_KEY, embeddings, timeout=None)

    @staticmethod
    def set_barrier_ids(barrier_ids):
        # Persist the barrier-id list that indexes the embedding matrix rows;
        # timeout=None means the entry never expires.
        cache.set(BARRIER_IDS_CACHE_KEY, barrier_ids, timeout=None)

    @staticmethod
    def get_embeddings():
        # Returns the cached embedding matrix, or None when the cache is cold.
        return cache.get(EMBEDDINGS_CACHE_KEY)

    @staticmethod
    def get_barrier_ids():
        # Returns the cached barrier-id list, or None when the cache is cold.
        return cache.get(BARRIER_IDS_CACHE_KEY)

def get_embeddings(self):
if self.__redis:
pass
    @property
    def model(self):
        # Read-only access to the loaded SentenceTransformer instance.
        return self.__model

@timing
def get_cosine_sim(self):
if self.__redis:
pass
return self.__model
embeddings = self.get_embeddings()
barrier_ids = self.get_barrier_ids()
return pandas.DataFrame(
util.cos_sim(embeddings, embeddings),
index=barrier_ids,
columns=barrier_ids,
)

@timing
def add_barrier(self, barrier):
barrier_ids = self.get_barrier_ids()
embeddings = self.get_embeddings()

@timing
def encode_barrier_corpus():
return self.model.encode(barrier['barrier_corpus'], convert_to_tensor=True).numpy()

new_embedding = encode_barrier_corpus()
new_embeddings = numpy.vstack([embeddings, new_embedding]) # append embedding
new_barrier_ids = barrier_ids + [barrier['id']] # append barrier_id

self.set_embeddings(new_embeddings)
self.set_barrier_ids(new_barrier_ids)

@timing
def remove_barrier(self, barrier):
embeddings = self.get_embeddings()
barrier_ids = self.get_barrier_ids()

index = None
for i in range(len(barrier_ids)):
if barrier_ids[i] == barrier['id']:
index = i
break

if index is not None:
# If the barrier exists, delete it from embeddings and barrier ids cache.
embeddings = numpy.delete(embeddings, index, axis=0)
del barrier_ids[index]

def update_model(self):
# Queued to avoid race conditions
# The truth is only the latest update, so if an update is already running,
# we don't need to schedule a new job. this way if 100 requests come in,
# the task knows it will rerun
#
# O(T) = max(0, T(update_model()))
# Check if a task exist
if self.__redis:
if self.__redis.get_waiting_task_count('update_model'):
# Task is already waiting to be run that includes this timestamp
return
self.set_embeddings(embeddings)
self.set_barrier_ids(barrier_ids)

pass
@timing
def update_barrier(self, barrier):
if barrier['id'] in db.get_barrier_ids():
self.remove_barrier(barrier)
self.add_barrier(barrier)


db: Optional[RelatedBarrierModelWarehouse] = None
Expand Down Expand Up @@ -84,3 +161,24 @@ def create_db() -> RelatedBarrierModelWarehouse:

return RelatedBarrierModelWarehouse(data)


@timing
def get_similar_barriers(barrier: Dict):
    """Return the ids of barriers most similar to ``barrier``.

    ``barrier`` is ``{"id": str, "barrier_corpus": str}``. The barrier is
    added to the warehouse on first sight so it has a row/column in the
    similarity matrix.

    Raises:
        Exception: when the module-level warehouse has not been created.
    """
    if not db:
        raise Exception('Related Barrier DB not set')

    if barrier['id'] not in db.get_barrier_ids():
        db.add_barrier(barrier)

    df = db.get_cosine_sim()

    # Drop the barrier's self-similarity (always the maximum, ~1.0) BEFORE
    # taking the top-N; the original limited first and dropped afterwards,
    # so the self row consumed one of the N slots and at most N-1 similar
    # barriers were ever returned.
    scores = (
        df[barrier['id']]
        .drop(barrier['id'])
        .sort_values(ascending=False)[:SIMILAR_BARRIERS_LIMIT]
    )
    return scores[scores > SIMILARITY_THRESHOLD].index


def barrier_to_corpus(barrier):
    """Build the text corpus for a barrier: its title, a separator, then its
    summary. Like the original ``+`` concatenation, raises TypeError when
    either attribute is not a string."""
    parts = (barrier.title, ". ", barrier.summary)
    return "".join(parts)
41 changes: 14 additions & 27 deletions api/related_barriers/views.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,11 @@
# from django.shortcuts import get_object_or_404
# from rest_framework.decorators import api_view
# from rest_framework.response import Response
#
# from api.barriers.models import Barrier
# from api.related_barriers.serializers import BarrierRelatedListSerializer
# from api.related_barriers.service import SimilarityScoreMatrix
#
#
# @api_view(["GET"])
# def related_barriers(request, pk) -> Response:
# """
# Return a list of related barriers
# """
# barrier_object = get_object_or_404(Barrier, pk=pk)
# similarity_score_matrix = SimilarityScoreMatrix.retrieve_matrix()
# barriers = similarity_score_matrix.retrieve_similar_barriers(barrier_object)
# serializer = BarrierRelatedListSerializer(barriers, many=True)
# return Response(serializer.data)

from django.shortcuts import get_object_or_404
from rest_framework.decorators import api_view
from rest_framework.response import Response

from api.barriers.models import Barrier
from api.related_barriers.model import barrier_to_corpus, get_similar_barriers
from api.related_barriers.serializers import BarrierRelatedListSerializer
from api.related_barriers.model import db
from api.related_barriers.service import SimilarityScoreMatrix


# Use celery Queue to manage race conditions when updating

@api_view(["GET"])
def related_barriers(request, pk) -> Response:
    """
    Return a list of related barriers
    """
    barrier = get_object_or_404(Barrier, pk=pk)
    # The warehouse works on plain dicts, not model instances.
    payload = {'id': str(barrier.id), 'barrier_corpus': barrier_to_corpus(barrier)}

    similar_barrier_ids = get_similar_barriers(payload)

    return Response(
        BarrierRelatedListSerializer(
            Barrier.objects.filter(id__in=similar_barrier_ids),
            many=True,
        ).data
    )
17 changes: 10 additions & 7 deletions tests/related_barriers/test_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import pytest

from tests.barriers.factories import BarrierFactory
from api.related_barriers import service


pytestmark = [pytest.mark.django_db]


def test_related_barriers():
def test_cosine_sim():
barrier1 = BarrierFactory(priority="LOW")
barrier2 = BarrierFactory(priority="MEDIUM")
barrier3 = BarrierFactory(priority="HIGH")
Expand All @@ -17,9 +15,14 @@ def test_related_barriers():

# assert 1 == 1

similarity_score_matrix = service.SimilarityScoreMatrix.create_matrix()
# similarity_score_matrix = service.SimilarityScoreMatrix.create_matrix()

from api.related_barriers import model

db = model.create_db()

print("Hello___1")
print(similarity_score_matrix)
assert len(db.get_cosine_sim()) == 3

assert 0
new_barrier = BarrierFactory(title='Test Title')
data = {'id': new_barrier.id, 'barrier_corpus': f'{new_barrier.title}. {new_barrier.summary}'}
model.add_barrier(data)

0 comments on commit 52ec06c

Please sign in to comment.