Skip to content

Commit

Permalink
Add support for open source models based on text-embeddings-inference (
Browse files Browse the repository at this point in the history
…#66)

Add support for broader integration of embedding models. This update
leverages the open-source embedding inference project
[text-embeddings-inference](https://github.com/huggingface/text-embeddings-inference)
by Hugging Face.

Signed-off-by: wileyzhang <[email protected]>

---------

Signed-off-by: wileyzhang <[email protected]>
Co-authored-by: wiley <[email protected]>
Co-authored-by: codingjaguar <[email protected]>
  • Loading branch information
3 people authored Jan 23, 2025
1 parent e4726cd commit 4974e2d
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

The `milvus-model` library provides the integration with common embedding and reranker models for Milvus, a high performance open-source vector database built for AI applications. `milvus-model` lib is included as a dependency in `pymilvus`, the Python SDK of Milvus.

`milvus-model` supports embedding and reranker models from service providers like OpenAI, Voyage AI, Cohere, and open-source models through SentenceTransformers.
`milvus-model` supports embedding and reranker models from service providers like OpenAI, Voyage AI, Cohere, and open-source models through SentenceTransformers or Hugging Face [Text Embeddings Inference (TEI)](https://github.com/huggingface/text-embeddings-inference) .

`milvus-model` supports Python 3.8 and above.

Expand Down
4 changes: 3 additions & 1 deletion src/pymilvus/model/dense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pymilvus.model.dense.sentence_transformer import SentenceTransformerEmbeddingFunction
from pymilvus.model.dense.voyageai import VoyageEmbeddingFunction
from pymilvus.model.dense.jinaai import JinaEmbeddingFunction
from pymilvus.model.dense.tei import TEIEmbeddingFunction
from pymilvus.model.dense.onnx import OnnxEmbeddingFunction
from pymilvus.model.dense.cohere import CohereEmbeddingFunction
from pymilvus.model.dense.mistralai import MistralAIEmbeddingFunction
Expand All @@ -13,9 +14,10 @@
"SentenceTransformerEmbeddingFunction",
"VoyageEmbeddingFunction",
"JinaEmbeddingFunction",
"TEIEmbeddingFunction",
"OnnxEmbeddingFunction",
"CohereEmbeddingFunction",
"MistralAIEmbeddingFunction",
"NomicEmbeddingFunction",
"InstructorEmbeddingFunction"
"InstructorEmbeddingFunction",
]
49 changes: 49 additions & 0 deletions src/pymilvus/model/dense/tei.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import List, Optional

import numpy as np
import requests

from pymilvus.model.base import BaseEmbeddingFunction


class TEIEmbeddingFunction(BaseEmbeddingFunction):
def __init__(
self,
api_url: str,
dimensions: Optional[int] = None,
):
self.api_url = api_url + "/v1/embeddings"
self._session = requests.Session()
self._dim = dimensions

@property
def dim(self):
if self._dim is None:
# This works by sending a dummy message to the API to retrieve the vector dimension,
# as the original API does not directly provide this information
self._dim = self._call_api(["get dim"])[0].shape[0]
return self._dim

def encode_queries(self, queries: List[str]) -> List[np.array]:
return self._call_api(queries)

def encode_documents(self, documents: List[str]) -> List[np.array]:
return self._call_api(documents)

def __call__(self, texts: List[str]) -> List[np.array]:
return self._call_api(texts)

def _call_api(self, texts: List[str]):
data = {"input": texts}
resp = self._session.post( # type: ignore[assignment]
self.api_url,
json=data,
).json()
if "data" not in resp:
raise RuntimeError(resp["message"])

embeddings = resp["data"]

# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore[no-any-return]
return [np.array(result["embedding"]) for result in sorted_embeddings]
2 changes: 2 additions & 0 deletions src/pymilvus/model/reranker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from pymilvus.model.reranker.voyageai import VoyageRerankFunction
from pymilvus.model.reranker.cross_encoder import CrossEncoderRerankFunction
from pymilvus.model.reranker.jinaai import JinaRerankFunction
from pymilvus.model.reranker.tei import TEIRerankFunction

__all__ = [
"CohereRerankFunction",
"BGERerankFunction",
"VoyageRerankFunction",
"CrossEncoderRerankFunction",
"JinaRerankFunction",
"TEIRerankFunction",
]
28 changes: 28 additions & 0 deletions src/pymilvus/model/reranker/tei.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import List

import requests

from pymilvus.model.base import BaseRerankFunction, RerankResult


class TEIRerankFunction(BaseRerankFunction):
def __init__(self, api_url: str):
self.api_url = api_url + "/rerank"
self._session = requests.Session()

def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
resp = self._session.post( # type: ignore[assignment]
self.api_url,
json={
"query": query,
"return_text": True,
"texts": documents,
},
).json()
if "error" in resp:
raise RuntimeError(resp["error"])

results = []
for res in resp[:5]:
results.append(RerankResult(text=res["text"], score=res["score"], index=res["index"]))
return results

0 comments on commit 4974e2d

Please sign in to comment.