-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
89 lines (74 loc) · 3.15 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# This file is part of Ranking Server.
#
# Copyright (C) 2022 Vít Škrhák <[email protected]>
# Tomáš Souček <[email protected]>
#
# Ranking Server is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 2 of the License, or (at your option)
# any later version.
#
# Ranking Server is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# Ranking Server. If not, see <https://www.gnu.org/licenses/>.
import os
import torch
import numpy as np
from flask import Flask, make_response
from dotenv import load_dotenv
from sklearn.metrics.pairwise import cosine_similarity
# Flask application object; its URL routes are attached in the __main__ block.
app = Flask(__name__)
# Callable that maps a text query to its unit-length CLIP embedding.
# Stays None until the __main__ block assigns the result of init_clip().
clip_function = None
# CLIP normalized features
# (row-normalised matrix; stays None until the __main__ block assigns the
# result of load_clip_features()).
clip_features = None
def init_clip():
    """Load the CLIP "RN50x4" model on the CPU and return a text encoder.

    Returns:
        A function taking a single query string and returning its CLIP text
        embedding as a 1-D NumPy array scaled to unit Euclidean norm.
    """
    # Imported here rather than at module level, as in the original code.
    import clip

    # The second value returned by clip.load is not used by this server.
    model, _ = clip.load("RN50x4", device="cpu")

    def encode_query(query):
        """Embed one query string and normalise the resulting vector."""
        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            tokens = clip.tokenize([query])
            # A batch of one query goes in; take its single row back out.
            embedding = model.encode_text(tokens).numpy()[0]
        return embedding / np.linalg.norm(embedding)

    return encode_query
def load_clip_features():
    """Load the precomputed CLIP feature matrix and L2-normalise its rows.

    Raw ``float32`` values are read from the file named by the
    ``CLIP_FEATURES`` environment variable and reshaped into rows of
    ``CLIP_DIMENSION`` values each.

    Returns:
        A 2-D array of shape ``(n, CLIP_DIMENSION)``. Every non-zero row has
        unit Euclidean norm; all-zero rows are returned unchanged (as zeros).

    Raises:
        RuntimeError: if ``CLIP_FEATURES`` or ``CLIP_DIMENSION`` is not set.
        ValueError: if ``CLIP_DIMENSION`` is not a positive integer, or the
            number of values in the file is not a multiple of it.
    """
    path = os.getenv('CLIP_FEATURES')
    raw_dim = os.getenv('CLIP_DIMENSION')
    if not path or not raw_dim:
        raise RuntimeError(
            "CLIP_FEATURES and CLIP_DIMENSION environment variables must be set"
        )

    dim = int(raw_dim)
    if dim <= 0:
        raise ValueError(f"CLIP_DIMENSION must be positive, got {dim}")

    features = np.fromfile(path, dtype=np.float32)
    if features.size % dim != 0:
        raise ValueError(
            f"feature file holds {features.size} values, "
            f"which is not a multiple of CLIP_DIMENSION={dim}"
        )
    features = features.reshape(-1, dim)

    # Norm the features for faster cos sim: with unit-length rows a plain dot
    # product against a unit-length query is the cosine similarity.
    norms = np.linalg.norm(features, axis=-1, keepdims=True)
    # Dividing an all-zero row by its zero norm would produce NaN, and NaN
    # similarities end up at the front of a descending argsort. Dividing by 1
    # instead keeps such rows at zero (similarity 0 to every query).
    norms[norms == 0] = 1.0
    return features / norms
def get_clip(query):
    """Serve the CLIP embedding of ``query`` as a binary payload.

    Args:
        query: Text taken from the ``<query>`` part of the request URL.

    Returns:
        A Flask response whose body is the embedding produced by the
        module-level ``clip_function``, cast to ``float32`` and serialised in
        C order, with content type ``application/octet-stream``.
    """
    print(f"clip: '{query}'")
    embedding = clip_function(query)
    payload = embedding.astype(np.float32).tobytes("C")
    response = make_response(payload)
    response.headers.set("Content-Type", "application/octet-stream")
    return response
def get_clip_results(query):
    """Rank the stored features against ``query`` and return the best matches.

    Args:
        query: Text taken from the ``<query>`` part of the request URL.

    Returns:
        A Flask response with content type ``application/octet-stream``. The
        body is the indices of the top results as ``int32`` followed directly
        by their similarities as ``float32``, best match first. At most
        ``HOW_MANY_RESULTS`` (environment variable) entries are returned.
    """
    print(f"clip-results: '{query}'")

    # Embed the query, then score every stored feature row against it with a
    # single matrix-vector product.
    query_vector = clip_function(query)
    scores = clip_features @ query_vector

    # Keep only the N highest-scoring rows, in descending order of score.
    limit = int(os.getenv('HOW_MANY_RESULTS'))
    ranking = np.argsort(scores)[::-1]
    top_indices = ranking[:limit]
    top_scores = scores[top_indices]

    index_bytes = top_indices.astype(np.int32).tobytes("C")
    score_bytes = top_scores.astype(np.float32).tobytes("C")
    response = make_response(index_bytes + score_bytes)
    response.headers.set("Content-Type", "application/octet-stream")
    return response
if __name__ == "__main__":
    # load_dotenv() runs first because the loaders and the settings below are
    # all read through os.getenv.
    load_dotenv()
    clip_features = load_clip_features()
    clip_function = init_clip()

    # Register the two GET endpoints on the application.
    app.add_url_rule("/clip/<query>", view_func=get_clip, methods=["GET"])
    app.add_url_rule(
        "/clip-results/<query>", view_func=get_clip_results, methods=["GET"]
    )

    # Debug mode is enabled only when ENV is exactly "debug".
    env_type = os.getenv('ENV')
    debug = env_type == "debug"
    app.debug = debug
    app.run(host='0.0.0.0', port=int(os.getenv('PORT')), debug=debug)