-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for open source models based on text-embeddings-inference (…
…#66) Add support for broader integration of embedding models. This update leverages the open-source embedding inference project [text-embeddings-inference](https://github.com/huggingface/text-embeddings-inference) by Hugging Face. Signed-off-by: wileyzhang <[email protected]> --------- Signed-off-by: wileyzhang <[email protected]> Co-authored-by: wiley <[email protected]> Co-authored-by: codingjaguar <[email protected]>
- Loading branch information
1 parent
e4726cd
commit 4974e2d
Showing
5 changed files
with
83 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from typing import List, Optional | ||
|
||
import numpy as np | ||
import requests | ||
|
||
from pymilvus.model.base import BaseEmbeddingFunction | ||
|
||
|
||
class TEIEmbeddingFunction(BaseEmbeddingFunction):
    """Embedding function backed by a text-embeddings-inference (TEI) server.

    Texts are sent to the server's OpenAI-compatible ``/v1/embeddings``
    endpoint and the returned vectors are converted to NumPy arrays.
    """

    def __init__(
        self,
        api_url: str,
        dimensions: Optional[int] = None,
    ):
        """Create a client for a TEI server.

        Args:
            api_url: Base URL of the TEI server, with or without a trailing
                slash (e.g. ``"http://localhost:8080"``).
            dimensions: Known embedding dimension. When ``None`` the dimension
                is discovered lazily on first access of :attr:`dim`.
        """
        # Strip trailing slashes so a base URL such as "http://host/" does not
        # produce "http://host//v1/embeddings".
        self.api_url = api_url.rstrip("/") + "/v1/embeddings"
        self._session = requests.Session()
        self._dim = dimensions

    @property
    def dim(self):
        """Embedding dimension, probed from the server on first use if unknown."""
        if self._dim is None:
            # The API does not expose the vector size directly, so embed a
            # dummy message once and cache the length of the returned vector.
            self._dim = self._call_api(["get dim"])[0].shape[0]
        return self._dim

    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
        """Embed query strings; returns one vector per query, in input order."""
        return self._call_api(queries)

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        """Embed document strings; returns one vector per document, in input order."""
        return self._call_api(documents)

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        """Embed arbitrary texts; returns one vector per text, in input order."""
        return self._call_api(texts)

    def _call_api(self, texts: List[str]) -> List[np.ndarray]:
        """POST ``texts`` to the embeddings endpoint and return the vectors.

        Args:
            texts: Strings to embed.

        Returns:
            A list of 1-D arrays ordered to match ``texts``. An empty input
            yields an empty list without contacting the server.

        Raises:
            RuntimeError: If the response does not contain a ``"data"`` field.
                The message is taken from the payload's ``"message"`` or
                ``"error"`` field when present, otherwise the raw payload.
        """
        if not texts:
            # Nothing to embed; avoid a pointless round trip.
            return []

        payload = self._session.post(self.api_url, json={"input": texts}).json()

        if not isinstance(payload, dict) or "data" not in payload:
            # Surface whatever error detail the server supplied instead of
            # failing with a KeyError/TypeError on an unexpected payload shape.
            detail = payload
            if isinstance(payload, dict):
                detail = payload.get("message") or payload.get("error") or payload
            raise RuntimeError(detail)

        # Each item carries the index of its input; sort so the output order
        # matches the input order regardless of how the server returned them.
        ordered = sorted(payload["data"], key=lambda item: item["index"])
        return [np.array(item["embedding"]) for item in ordered]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import List | ||
|
||
import requests | ||
|
||
from pymilvus.model.base import BaseRerankFunction, RerankResult | ||
|
||
|
||
class TEIRerankFunction(BaseRerankFunction):
    """Rerank function backed by a text-embeddings-inference (TEI) server.

    Query/document pairs are scored by the server's ``/rerank`` endpoint.
    """

    def __init__(self, api_url: str):
        """Create a client for a TEI reranker.

        Args:
            api_url: Base URL of the TEI server, with or without a trailing
                slash (e.g. ``"http://localhost:8080"``).
        """
        # Strip trailing slashes so "http://host/" does not become
        # "http://host//rerank".
        self.api_url = api_url.rstrip("/") + "/rerank"
        self._session = requests.Session()

    def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
        """Score ``documents`` against ``query`` and return the best matches.

        Args:
            query: The query text.
            documents: Candidate texts to rank.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` results ordered by descending score. Each result
            holds the document text, its score, and its index in ``documents``.
            An empty ``documents`` list yields an empty result without
            contacting the server.

        Raises:
            RuntimeError: If the server responds with an ``"error"`` payload.
        """
        if not documents:
            # Nothing to rank; avoid a pointless round trip.
            return []

        payload = self._session.post(
            self.api_url,
            json={
                "query": query,
                "return_text": True,
                "texts": documents,
            },
        ).json()

        # A successful response is a list of ranked items; only a dict payload
        # can carry an "error" key.
        if isinstance(payload, dict) and "error" in payload:
            raise RuntimeError(payload["error"])

        # Sort explicitly (stable, so an already-sorted response is unchanged)
        # and honour the caller's top_k rather than a hard-coded limit.
        ranked = sorted(payload, key=lambda item: item["score"], reverse=True)
        return [
            RerankResult(text=item["text"], score=item["score"], index=item["index"])
            for item in ranked[:top_k]
        ]