Skip to content

Commit

Permalink
correct the fastembed embedding function, double check in pytest it wo…
Browse files Browse the repository at this point in the history
…rks for bivec search interface
  • Loading branch information
ClaudeHu committed Dec 2, 2024
1 parent 13330c1 commit 7dc3e03
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 11 deletions.
4 changes: 3 additions & 1 deletion geniml/search/query2vec/text2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from typing import Union

import numpy as np

# from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from fastembed import TextEmbedding

from ...const import PKG_NAME
from ...text2bednn import Vec2VecFNN
from .abstract import Query2Vec
Expand Down Expand Up @@ -39,7 +41,7 @@ def forward(self, query: str) -> np.ndarray:
:return: the embedding vector of query
"""
# embed query string
query_embedding = np.array(self.text_embedder.embed_query(query))
query_embedding = list(self.text_embedder.embed(query))[0]
if self.v2v is None:
return query_embedding
else:
Expand Down
57 changes: 47 additions & 10 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from geniml.search import BED2BEDSearchInterface, BED2Vec, Text2BEDSearchInterface, Text2Vec
from geniml.search.backends import BiVectorBackend, HNSWBackend, QdrantBackend
from geniml.search.backends.filebackend import DEP_HNSWLIB
from geniml.search.interfaces.mlfree import BiVectorSearchInterface

DATA_FOLDER_PATH = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "data"
Expand Down Expand Up @@ -274,7 +275,7 @@ def query_bed():

def cosine_similarity(vec1: np.array, vec2: np.array) -> float:
# Ensure the vectors have shape (100,)
assert vec1.shape == (100,) and vec2.shape == (100,), "Both vectors must have shape (100,)"
assert vec1.shape == (100,) and vec2.shape == (100,)

# Compute the dot product of the two vectors
dot_product = np.dot(vec1, vec2)
Expand Down Expand Up @@ -474,6 +475,10 @@ def test_HNSWBackend_save(filenames, bed_hnswb, bed_embeddings, temp_bed_idx_pat
"not config.getoption('--qdrant')",
reason="Only run when --qdrant is given",
)
@pytest.mark.skipif(
"not config.getoption('--huggingface')",
reason="Only run when --huggingface is given",
)
def test_BiVectorBackend(
bed_hnswb,
metadata_hnswb,
Expand All @@ -483,16 +488,9 @@ def test_BiVectorBackend(
metadata_collection,
text_embeddings,
metadata_payloads,
nl_embed_repo,
):
def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
query_vec = np.random.random(
384,
)
search_results = bivec_backend.search(
query_vec, 2, with_payload=True, with_vectors=True, distance=dist, rank=rank
)
assert isinstance(search_results, list)
assert len(search_results) == 2
def search_result_test(search_results: Dict, rank: bool):
min_score = 100.0
max_rank = -1
for result in search_results:
Expand All @@ -517,6 +515,18 @@ def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
assert isinstance(result["payload"]["name"], str)
assert isinstance(result["payload"]["metadata"], dict)

def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
query_vec = np.random.random(
384,
)
search_results = bivec_backend.search(
query_vec, 2, with_payload=True, with_vectors=True, distance=dist, rank=rank
)
assert isinstance(search_results, list)
assert len(search_results) == 2

search_result_test(search_results, rank)

# test QdrantBackend
bed_backend = QdrantBackend(collection=bed_collection)
# load data
Expand All @@ -528,6 +538,20 @@ def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
bivec_qd_backend = BiVectorBackend(text_backend, bed_backend)
bivec_test(bivec_qd_backend, rank=True)
bivec_test(bivec_qd_backend, rank=False)
# test QdrantBackend + Interface
bivec_qd_interface = BiVectorSearchInterface(bivec_qd_backend, nl_embed_repo)
interface_result = bivec_qd_interface.query_search(
"lung cancer cell line", 2, with_payload=True, with_vectors=True, rank=True
)

search_result_test(interface_result, rank=True)

interface_result = bivec_qd_interface.query_search(
"lung cancer cell line", 2, with_payload=True, with_vectors=True, rank=False
)

search_result_test(interface_result, rank=False)

bivec_qd_backend.metadata_backend.qd_client.delete_collection(text_backend.collection)
bivec_qd_backend.bed_backend.qd_client.delete_collection(bed_backend.collection)

Expand All @@ -536,6 +560,19 @@ def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
bivec_hnsw_backend = BiVectorBackend(metadata_hnswb, bed_hnswb)
bivec_test(bivec_hnsw_backend, dist=True, rank=True)
bivec_test(bivec_hnsw_backend, dist=True, rank=False)
# test HNSWBackend + Interface
bivec_hnsw_interface = BiVectorSearchInterface(bivec_hnsw_backend, nl_embed_repo)
interface_result = bivec_hnsw_interface.query_search(
"lung cancer cell line", 2, with_payload=True, with_vectors=True, distance=True, rank=True
)

search_result_test(interface_result, rank=True)

interface_result = bivec_hnsw_interface.query_search(
"lung cancer cell line", 2, with_payload=True, with_vectors=True, distance=True, rank=False
)

search_result_test(interface_result, rank=False)


@pytest.mark.skipif(
Expand Down

0 comments on commit 7dc3e03

Please sign in to comment.