-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
123 lines (95 loc) · 3.39 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import requests
from flask import Flask, request, abort
from flask_cors import CORS
import json
from qdrant_client import QdrantClient
from qdrant_client.models import Filter
from qdrant_client.http import models
from sentence_transformers import SentenceTransformer
# Name of the sentence-transformers model loaded into `encoder` below.
EMBEDDING_MODEL = "neuralmind/bert-base-portuguese-cased"
# Host and port of the Qdrant server that `client` connects to.
QDRANT_HOST = "localhost"
QDRANT_PORT = 6333
# Flask application object; the routes in this file are registered on it.
app = Flask(__name__)
# Wrap the app with flask_cors using its default settings.
CORS(app)
# Module-level singletons created once at import time and shared by the
# route handlers: the Qdrant client and the text encoder (run on CPU).
client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
encoder = SentenceTransformer(model_name_or_path=EMBEDDING_MODEL, device="cpu")
def get_ids(args) -> list:
    """Extract item ids from a query-string mapping.

    The ids are read from the ``ids`` key of *args* as a comma-separated
    string (for example ``ids=a,b,c``).

    Args:
        args: Mapping of query-string parameters that supports
            ``.get(key, default)``; callers in this file pass
            ``request.args``.

    Returns:
        The ids as a list of strings, in the order given. Whitespace around
        each id is stripped and empty entries (``a,,b`` or a trailing comma)
        are dropped. A missing or empty ``ids`` parameter yields ``[]``.
    """
    # Read from the mapping that was passed in instead of the global request,
    # so the helper actually honours its own parameter.
    raw = args.get("ids", "")
    if not raw:
        return []
    # str.split already returns a one-element list when there is no comma,
    # so no separate single-id branch is needed.
    return [token.strip() for token in raw.split(",") if token.strip()]
def get_filters(args) -> "Filter":
    """Build a Qdrant filter from ``filter.<field>=<value>`` query parameters.

    Args:
        args: Mapping of query-string parameters; callers in this file pass
            ``request.args``.

    Returns:
        A ``Filter`` whose ``must`` list holds one exact-match
        ``FieldCondition`` per parameter whose name starts with ``filter.``.
        The list is empty when no such parameter is present.
    """
    prefix = "filter."
    # Iterate and read from the same mapping (the one passed in), and treat
    # "filter." strictly as a prefix: only names that start with it count,
    # and only that leading occurrence is removed to obtain the field name.
    must = [
        models.FieldCondition(
            key=name[len(prefix):],
            match=models.MatchValue(value=args.get(name)),
        )
        for name in args
        if name.startswith(prefix)
    ]
    return Filter(must=must)
@app.get("/<string:collection>/item")
def get_item(collection: str, seed_ids: "list | None" = None) -> list:
    """Fetch the stored points whose ``item_id`` payload matches the given ids.

    Serves ``GET /<collection>/item?ids=a,b,c`` and is also called directly
    by ``get_similars`` with an explicit id list.

    Args:
        collection: Name of the Qdrant collection to read from.
        seed_ids: Item ids to look up. When ``None`` they are taken from the
            ``ids`` query-string parameter of the current request.

    Returns:
        A list with one JSON-compatible dict per matching point, carrying the
        ``item_id``, ``title`` and ``description`` payload fields and no
        vectors.
    """
    if seed_ids is None:
        seed_ids = get_ids(request.args)
    records, _next_offset = client.scroll(
        collection_name=collection,
        scroll_filter=Filter(
            must=[
                models.FieldCondition(
                    key="item_id", match=models.MatchAny(any=seed_ids)
                ),
            ]
        ),
        # scroll() pages its results and only the first page is read here.
        # Without an explicit limit the client's default page size (10)
        # would silently drop matches when more ids are requested; never go
        # below that default so existing responses are unchanged.
        limit=max(len(seed_ids), 10),
        with_payload=["item_id", "title", "description"],
        with_vectors=False,
    )
    # Round-trip through JSON so the response contains only plain types.
    return [json.loads(record.model_dump_json()) for record in records]
@app.get("/<string:collection>/similars")
def get_similars(collection: str) -> list:
    """Recommend items similar to the ones named in the ``ids`` parameter.

    Query-string parameters:
        ids: Comma-separated item ids used as positive examples.
        lim: Maximum number of results (default 10).
        score: Score threshold passed to Qdrant (default 0).
        filter.<field>: Exact-match payload filters applied to the results.

    Args:
        collection: Name of the Qdrant collection to search.

    Returns:
        A list with one JSON-compatible dict per recommended point, carrying
        the ``item_id``, ``title`` and ``description`` payload fields and no
        vectors. The list is empty when none of the requested ids resolves
        to a stored point.
    """
    lim = request.args.get("lim", default=10, type=int)
    ids = get_ids(request.args)
    score = request.args.get("score", default=0, type=float)
    filters = get_filters(request.args)
    # The caller supplies payload item ids; recommend() needs the point ids,
    # so resolve them through get_item first.
    seeds = get_item(collection, seed_ids=ids)
    seed_ids = [seed["id"] for seed in seeds]
    if not seed_ids:
        # With no positive example there is nothing to recommend from;
        # answer with an empty result instead of sending an empty request.
        return []
    results = client.recommend(
        collection_name=collection,
        positive=seed_ids,
        negative=None,
        limit=lim,
        with_payload=["item_id", "title", "description"],
        with_vectors=False,
        score_threshold=score,
        query_filter=filters,
    )
    return [json.loads(result.model_dump_json()) for result in results]
@app.get("/<string:collection>/query")
def query(collection: str):
    """Search a collection for the points closest to a free-text query.

    Query-string parameters:
        text: Text to embed and search with (default: empty string).
        lim: Maximum number of hits (default 10).
        filter.<field>: Exact-match payload filters applied to the hits.

    Returns a list with one JSON-compatible dict per hit, carrying the
    ``item_id``, ``title`` and ``description`` payload fields and no vectors.
    """
    params = request.args
    phrase = params.get("text", default="", type=str)
    max_hits = params.get("lim", default=10, type=int)
    conditions = get_filters(params)
    # Embed the text with the module-level encoder before searching.
    embedding = encoder.encode(phrase)
    hits = client.search(
        collection_name=collection,
        query_vector=embedding,
        query_filter=conditions,
        limit=max_hits,
        with_payload=["item_id", "title", "description"],
        with_vectors=False,
    )
    payload = []
    for hit in hits:
        payload.append(json.loads(hit.model_dump_json()))
    return payload
@app.get("/<string:collection>")
def info(collection: str):
    """Return the Qdrant description of a collection as a JSON-compatible dict."""
    # Keep the URL parameter and the fetched description under separate names.
    details = client.get_collection(collection_name=collection)
    serialized = details.model_dump_json()
    return json.loads(serialized)
@app.get("/")
def ping():
    """Check the Qdrant connection by returning its collection listing."""
    # The call only succeeds if the Qdrant server answers.
    listing = client.get_collections()
    serialized = listing.model_dump_json()
    return json.loads(serialized)