feat: add Milvus vectorDB #1171

Open · wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion docs/source/concepts/index.md
@@ -26,7 +26,7 @@ We are working on adding a few more APIs to complete the application lifecycle.

The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
Contributor commented:

Please also update the docs section under https://github.com/meta-llama/llama-stack/tree/main/docs/source/providers/vector_io/. See PR #1195 as a reference.

- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)

Providers come in two flavors:
1 change: 1 addition & 0 deletions docs/source/index.md
@@ -69,6 +69,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
| FAISS | Single Node |
| SQLite-Vec| Single Node |
| Chroma | Hosted and Single Node |
| Milvus | Hosted and Single Node |
| Postgres (PGVector) | Hosted and Single Node |
| Weaviate | Hosted |

3 changes: 2 additions & 1 deletion docs/source/providers/index.md
@@ -2,7 +2,7 @@

The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)

Providers come in two flavors:
@@ -55,5 +55,6 @@ vector_io/sqlite-vec
vector_io/chromadb
vector_io/pgvector
vector_io/qdrant
vector_io/milvus
vector_io/weaviate
```
31 changes: 31 additions & 0 deletions docs/source/providers/vector_io/milvus.md
@@ -0,0 +1,31 @@
---
orphan: true
---
# Milvus

[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database, either on a hosted Milvus
server (remote) or in a local, file-backed Milvus Lite instance (inline), so you are not limited to
keeping vectors in memory or running a separate service.

## Features

- Easy to use
- Supports both remote (hosted Milvus server) and inline (local Milvus Lite) deployments
- Fully integrated with Llama Stack

## Usage

To use Milvus in your Llama Stack project, follow these steps:

1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Milvus.
3. Start storing and querying vectors.

## Installation

You can install the `pymilvus` client library, which also provides Milvus Lite for inline use:

```bash
pip install pymilvus
```
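
To sanity-check the installation, here is a minimal sketch against Milvus Lite, the embedded file-backed mode the inline provider uses. The collection name, file path, and vector values are illustrative only:

```python
# Minimal Milvus Lite smoke test (illustrative; not part of the provider code).
from pymilvus import MilvusClient

client = MilvusClient(uri="./milvus_demo.db")  # a local file path starts Milvus Lite
client.create_collection("demo", dimension=4)
client.insert("demo", data=[{"id": 0, "vector": [0.1, 0.2, 0.3, 0.4], "text": "hello"}])
hits = client.search("demo", data=[[0.1, 0.2, 0.3, 0.4]], limit=1, output_fields=["text"])
print(hits[0][0]["entity"]["text"])  # -> "hello"
```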
## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
19 changes: 19 additions & 0 deletions llama_stack/providers/inline/vector_io/milvus/__init__.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import MilvusVectorIOConfig


async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter

    impl = MilvusVectorIOAdapter(config, deps[Api.inference])
    await impl.initialize()
    return impl
20 changes: 20 additions & 0 deletions llama_stack/providers/inline/vector_io/milvus/config.py
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from pydantic import BaseModel

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class MilvusVectorIOConfig(BaseModel):
    db_path: str  # local file path for Milvus Lite, not a server endpoint

    @classmethod
    def sample_config(cls) -> Dict[str, Any]:
        return {"db_path": "{env.MILVUS_DB_PATH}"}
18 changes: 18 additions & 0 deletions llama_stack/providers/registry/vector_io.py
@@ -110,4 +110,22 @@ def available_providers() -> List[ProviderSpec]:
            ),
            api_dependencies=[Api.inference],
        ),
        remote_provider_spec(
            Api.vector_io,
            AdapterSpec(
                adapter_type="milvus",
                pip_packages=["pymilvus"],
                module="llama_stack.providers.remote.vector_io.milvus",
                config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
            ),
            api_dependencies=[Api.inference],
        ),
        InlineProviderSpec(
            api=Api.vector_io,
            provider_type="inline::milvus",
            pip_packages=["pymilvus"],
            module="llama_stack.providers.inline.vector_io.milvus",
            config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
            api_dependencies=[Api.inference],
        ),
    ]
21 changes: 21 additions & 0 deletions llama_stack/providers/remote/vector_io/milvus/__init__.py
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import MilvusVectorIOConfig


async def get_adapter_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
    from .milvus import MilvusVectorIOAdapter
@franciscojavierarceo (Contributor) commented on Feb 24, 2025:

I think the inline and remote implementations are equivalent. If that is indeed the case, can you follow the pattern used by Chroma? See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/vector_io/chroma/__init__.py#L15. In short, they just import the remote implementation in `__init__.py` and don't have a separate inline chromadb.py file.

    assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"

    impl = MilvusVectorIOAdapter(config, deps[Api.inference])
    await impl.initialize()
    return impl
22 changes: 22 additions & 0 deletions llama_stack/providers/remote/vector_io/milvus/config.py
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class MilvusVectorIOConfig(BaseModel):
    uri: str
    token: Optional[str] = None
    consistency_level: str = "Strong"

    @classmethod
    def sample_config(cls) -> Dict[str, Any]:
        return {"uri": "{env.MILVUS_ENDPOINT}", "token": "{env.MILVUS_TOKEN}"}
165 changes: 165 additions & 0 deletions llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import hashlib
import logging
import os
import uuid
from typing import Any, Dict, List, Optional, Union

from numpy.typing import NDArray
from pymilvus import MilvusClient

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (
    EmbeddingIndex,
    VectorDBWithIndex,
)

from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig

logger = logging.getLogger(__name__)


class MilvusIndex(EmbeddingIndex):
    def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"):
        self.client = client
        # Milvus collection names may only contain letters, digits, and underscores.
        self.collection_name = collection_name.replace("-", "_")
        self.consistency_level = consistency_level

    async def delete(self):
        if self.client.has_collection(self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(embeddings), (
            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
        )
        if not self.client.has_collection(self.collection_name):
            self.client.create_collection(
                self.collection_name,
                dimension=len(embeddings[0]),
                auto_id=True,
                consistency_level=self.consistency_level,
            )

        data = []
        for chunk, embedding in zip(chunks, embeddings, strict=False):
            chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content)
            data.append(
                {
                    "chunk_id": chunk_id,
                    "vector": embedding,
                    "chunk_content": chunk.model_dump(),
                }
            )
        try:
            self.client.insert(
                self.collection_name,
                data=data,
            )
        except Exception as e:
            logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
            raise

    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        search_res = self.client.search(
            collection_name=self.collection_name,
            data=[embedding],
            limit=k,
            output_fields=["*"],
            # Range search: only return hits within the given score radius.
            search_params={"params": {"radius": score_threshold}},
        )
        chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
        scores = [res["distance"] for res in search_res[0]]
        return QueryChunksResponse(chunks=chunks, scores=scores)


class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    def __init__(
        self,
        config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig],
        inference_api: Api.inference,
    ) -> None:
        self.config = config
        self.cache = {}
        self.client = None
        self.inference_api = inference_api

    async def initialize(self) -> None:
        if isinstance(self.config, RemoteMilvusVectorIOConfig):
            logger.info(f"Connecting to Milvus server at {self.config.uri}")
            self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
        else:
            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
            uri = os.path.expanduser(self.config.db_path)
            self.client = MilvusClient(uri=uri)

    async def shutdown(self) -> None:
        self.client.close()

    async def register_vector_db(
        self,
        vector_db: VectorDB,
    ) -> None:
        if isinstance(self.config, RemoteMilvusVectorIOConfig):
            consistency_level = self.config.consistency_level
        else:
            consistency_level = "Strong"
        index = VectorDBWithIndex(
            vector_db=vector_db,
            index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
            inference_api=self.inference_api,
        )

        self.cache[vector_db.identifier] = index

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
        if not vector_db:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        index = VectorDBWithIndex(
            vector_db=vector_db,
            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier),
            inference_api=self.inference_api,
        )
        self.cache[vector_db_id] = index
        return index

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        if vector_db_id in self.cache:
            await self.cache[vector_db_id].index.delete()
            del self.cache[vector_db_id]

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: List[Chunk],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryChunksResponse:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        return await index.query_chunks(query, params)


def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    """Generate a deterministic chunk ID from an MD5 hash of the document ID and chunk text."""
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
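
As an end-to-end illustration of how the pieces above fit together, here is a hypothetical smoke test for `MilvusIndex` against Milvus Lite (not part of this PR). The `Chunk` construction follows its usage in `add_chunks` and may not match the actual API exactly:

```python
# Hypothetical smoke test for MilvusIndex; Chunk field names are inferred from add_chunks.
import asyncio

import numpy as np
from pymilvus import MilvusClient

from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex


async def main() -> None:
    # "smoke-test" is normalized to "smoke_test" by the index constructor.
    index = MilvusIndex(MilvusClient(uri="./milvus_smoke.db"), "smoke-test")
    chunks = [Chunk(content="hello world", metadata={"document_id": "doc-1"})]
    embeddings = np.random.rand(1, 384).astype(np.float32)  # any fixed dimension works
    await index.add_chunks(chunks, embeddings)
    response = await index.query(embeddings[0], k=1, score_threshold=0.0)
    print(response.chunks[0].content, response.scores[0])
    await index.delete()  # drop the test collection


asyncio.run(main())
```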
1 change: 1 addition & 0 deletions tests/client-sdk/vector_io/test_vector_io.py
@@ -12,6 +12,7 @@
"faiss",
# TODO: add sqlite_vec to templates
# "sqlite_vec",
# "milvus",
]

