-
Notifications
You must be signed in to change notification settings - Fork 889
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add Milvus vectorDB #1171
Open
zc277584121
wants to merge
2
commits into
meta-llama:main
Choose a base branch
from
zc277584121:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+301
−2
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
--- | ||
orphan: true | ||
--- | ||
# Milvus | ||
|
||
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It | ||
allows you to store and query vectors directly within a Milvus database. | ||
That means you're not limited to storing vectors in memory or in a separate service. | ||
|
||
## Features | ||
|
||
- Easy to use | ||
- Fully integrated with Llama Stack | ||
|
||
## Usage | ||
|
||
To use Milvus in your Llama Stack project, follow these steps: | ||
|
||
1. Install the necessary dependencies. | ||
2. Configure your Llama Stack project to use Milvus. | ||
3. Start storing and querying vectors. | ||
|
||
## Installation | ||
|
||
You can install Milvus using pymilvus: | ||
|
||
```bash | ||
pip install pymilvus | ||
``` | ||
## Documentation | ||
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the terms described in the LICENSE file in | ||
# the root directory of this source tree. | ||
|
||
from typing import Dict | ||
|
||
from llama_stack.providers.datatypes import Api, ProviderSpec | ||
|
||
from .config import MilvusVectorIOConfig | ||
|
||
|
||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]): | ||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter | ||
|
||
impl = MilvusVectorIOAdapter(config, deps[Api.inference]) | ||
await impl.initialize() | ||
return impl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the terms described in the LICENSE file in | ||
# the root directory of this source tree. | ||
|
||
from typing import Any, Dict | ||
|
||
from pydantic import BaseModel | ||
|
||
from llama_stack.schema_utils import json_schema_type | ||
|
||
|
||
@json_schema_type | ||
class MilvusVectorIOConfig(BaseModel): | ||
db_path: str | ||
|
||
@classmethod | ||
def sample_config(cls) -> Dict[str, Any]: | ||
return {"db_path": "{env.MILVUS_ENDPOINT}"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the terms described in the LICENSE file in | ||
# the root directory of this source tree. | ||
|
||
from typing import Dict | ||
|
||
from llama_stack.providers.datatypes import Api, ProviderSpec | ||
|
||
from .config import MilvusVectorIOConfig | ||
|
||
|
||
async def get_adapter_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]): | ||
from .milvus import MilvusVectorIOAdapter | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the inline and remote implementations are equivalent. If that is, indeed, the case can you follow the pattern that's done by Chroma? In short, they just import the remote implementation in the |
||
|
||
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}" | ||
|
||
impl = MilvusVectorIOAdapter(config, deps[Api.inference]) | ||
await impl.initialize() | ||
return impl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the terms described in the LICENSE file in | ||
# the root directory of this source tree. | ||
|
||
from typing import Any, Dict, Optional | ||
|
||
from pydantic import BaseModel | ||
|
||
from llama_stack.schema_utils import json_schema_type | ||
|
||
|
||
@json_schema_type | ||
class MilvusVectorIOConfig(BaseModel): | ||
uri: str | ||
token: Optional[str] = None | ||
consistency_level: str = "Strong" | ||
|
||
@classmethod | ||
def sample_config(cls) -> Dict[str, Any]: | ||
return {"uri": "{env.MILVUS_ENDPOINT}", "token": "{env.MILVUS_TOKEN}"} |
165 changes: 165 additions & 0 deletions
165
llama_stack/providers/remote/vector_io/milvus/milvus.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the terms described in the LICENSE file in | ||
# the root directory of this source tree. | ||
|
||
import logging | ||
import os | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
import hashlib | ||
import uuid | ||
from numpy.typing import NDArray | ||
from pymilvus import MilvusClient | ||
|
||
from llama_stack.apis.inference import InterleavedContent | ||
from llama_stack.apis.vector_dbs import VectorDB | ||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO | ||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate | ||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig | ||
from llama_stack.providers.utils.memory.vector_store import ( | ||
EmbeddingIndex, | ||
VectorDBWithIndex, | ||
) | ||
|
||
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class MilvusIndex(EmbeddingIndex): | ||
def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"): | ||
self.client = client | ||
self.collection_name = collection_name.replace("-", "_") | ||
self.consistency_level = consistency_level | ||
|
||
async def delete(self): | ||
if self.client.has_collection(self.collection_name): | ||
self.client.drop_collection(collection_name=self.collection_name) | ||
|
||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): | ||
assert len(chunks) == len(embeddings), ( | ||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" | ||
) | ||
if not self.client.has_collection(self.collection_name): | ||
self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True, consistency_level=self.consistency_level) | ||
|
||
data = [] | ||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)): | ||
chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content) | ||
|
||
data.append( | ||
{ | ||
"chunk_id": chunk_id, | ||
"vector": embedding, | ||
"chunk_content": chunk.model_dump(), | ||
} | ||
) | ||
try: | ||
self.client.insert( | ||
self.collection_name, | ||
data=data, | ||
) | ||
except Exception as e: | ||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") | ||
raise e | ||
|
||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: | ||
search_res = self.client.search( | ||
collection_name=self.collection_name, | ||
data=[embedding], | ||
limit=k, | ||
output_fields=["*"], | ||
search_params={"params": {"radius": score_threshold}}, | ||
) | ||
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]] | ||
scores = [res["distance"] for res in search_res[0]] | ||
return QueryChunksResponse(chunks=chunks, scores=scores) | ||
|
||
|
||
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): | ||
def __init__(self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference) -> None: | ||
self.config = config | ||
self.cache = {} | ||
self.client = None | ||
self.inference_api = inference_api | ||
|
||
async def initialize(self) -> None: | ||
if isinstance(self.config, RemoteMilvusVectorIOConfig): | ||
logger.info(f"Connecting to Milvus server at {self.config.uri}") | ||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) | ||
else: | ||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") | ||
uri = self.config.model_dump(exclude_none=True)["db_path"] | ||
uri = os.path.expanduser(uri) | ||
self.client = MilvusClient(uri=uri) | ||
|
||
async def shutdown(self) -> None: | ||
self.client.close() | ||
|
||
async def register_vector_db( | ||
self, | ||
vector_db: VectorDB, | ||
) -> None: | ||
if isinstance(self.config, RemoteMilvusVectorIOConfig): | ||
consistency_level = self.config.consistency_level | ||
else: | ||
consistency_level = "Strong" | ||
index = VectorDBWithIndex( | ||
vector_db=vector_db, | ||
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level), | ||
inference_api=self.inference_api, | ||
) | ||
|
||
self.cache[vector_db.identifier] = index | ||
|
||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]: | ||
if vector_db_id in self.cache: | ||
return self.cache[vector_db_id] | ||
|
||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id) | ||
if not vector_db: | ||
raise ValueError(f"Vector DB {vector_db_id} not found") | ||
|
||
index = VectorDBWithIndex( | ||
vector_db=vector_db, | ||
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier), | ||
inference_api=self.inference_api, | ||
) | ||
self.cache[vector_db_id] = index | ||
return index | ||
|
||
async def unregister_vector_db(self, vector_db_id: str) -> None: | ||
if vector_db_id in self.cache: | ||
await self.cache[vector_db_id].index.delete() | ||
del self.cache[vector_db_id] | ||
|
||
async def insert_chunks( | ||
self, | ||
vector_db_id: str, | ||
chunks: List[Chunk], | ||
ttl_seconds: Optional[int] = None, | ||
) -> None: | ||
index = await self._get_and_cache_vector_db_index(vector_db_id) | ||
if not index: | ||
raise ValueError(f"Vector DB {vector_db_id} not found") | ||
|
||
await index.insert_chunks(chunks) | ||
|
||
async def query_chunks( | ||
self, | ||
vector_db_id: str, | ||
query: InterleavedContent, | ||
params: Optional[Dict[str, Any]] = None, | ||
) -> QueryChunksResponse: | ||
index = await self._get_and_cache_vector_db_index(vector_db_id) | ||
if not index: | ||
raise ValueError(f"Vector DB {vector_db_id} not found") | ||
|
||
return await index.query_chunks(query, params) | ||
|
||
def generate_chunk_id(document_id: str, chunk_text: str) -> str: | ||
"""Generate a unique chunk ID using a hash of document ID and chunk text.""" | ||
hash_input = f"{document_id}:{chunk_text}".encode("utf-8") | ||
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest())) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
"faiss", | ||
# TODO: add sqlite_vec to templates | ||
# "sqlite_vec", | ||
# "milvus", | ||
] | ||
|
||
|
||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please also update the docs section as well under https://github.com/meta-llama/llama-stack/tree/main/docs/source/providers/vector_io/
See this PR as a reference #1195