Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize code and fix bug #40

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 34 additions & 37 deletions src/milvus_haystack/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,21 @@
from haystack.document_stores.types import DuplicatePolicy
from haystack.errors import FilterError
from haystack.utils import Secret, deserialize_secrets_inplace
from pymilvus import AnnSearchRequest, MilvusException, RRFRanker
from pymilvus import (
AnnSearchRequest,
Collection,
CollectionSchema,
DataType,
FieldSchema,
MilvusClient,
MilvusException,
RRFRanker,
connections,
utility,
)
from pymilvus.client.abstract import BaseRanker
from pymilvus.client.types import LoadState
from pymilvus.orm.types import infer_dtype_bydata

from milvus_haystack.filters import parse_filters

Expand Down Expand Up @@ -122,17 +135,9 @@ def __init__(
:param replica_number: Number of replicas. Defaults to 1.
:param timeout: Timeout in seconds. Defaults to None.
"""
try:
from pymilvus import Collection, utility
except ImportError as err:
err_msg = "Could not import pymilvus python package. Please install it with `pip install pymilvus`."
raise ValueError(err_msg) from err

# Default search params when one is not provided.
self.default_search_params = {
"GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
"GPU_CAGRA": {"metric_type": "L2", "params": {"itopk_size": 128}},
"FLAT": {"metric_type": "L2", "params": {}},
"IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
Expand All @@ -142,7 +147,16 @@ def __init__(
"RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
"IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
"SCANN": {"metric_type": "L2", "params": {"search_k": 10}},
"AUTOINDEX": {"metric_type": "L2", "params": {}},
"GPU_CAGRA": {"metric_type": "L2", "params": {"itopk_size": 128}},
"GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
"SPARSE_INVERTED_INDEX": {
"metric_type": "IP",
"params": {"drop_ratio_build": 0.2},
},
"SPARSE_WAND": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}},
}

self.collection_name = collection_name
Expand All @@ -169,6 +183,9 @@ def __init__(
# Create the connection to the server
if connection_args is None:
self.connection_args = DEFAULT_MILVUS_CONNECTION
self._milvus_client = MilvusClient(
**self.connection_args,
)
self.alias = self._create_connection_alias(self.connection_args) # type: ignore[arg-type]
self.col: Optional[Collection] = None

Expand All @@ -193,6 +210,11 @@ def __init__(
)
self._dummy_value = 999.0

@property
def client(self) -> MilvusClient:
"""Get client."""
return self._milvus_client

def count_documents(self) -> int:
"""
Returns how many documents are present in the document store.
Expand Down Expand Up @@ -311,8 +333,6 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
:return: Number of documents written.
"""

from pymilvus import Collection, MilvusException

documents_cp = [MilvusDocumentStore._discard_invalid_meta(doc) for doc in deepcopy(documents)]
if len(documents_cp) > 0 and not isinstance(documents_cp[0], Document):
err_msg = "param 'documents' must contain a list of objects of type Document"
Expand Down Expand Up @@ -484,8 +504,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "MilvusDocumentStore":

def _create_connection_alias(self, connection_args: dict) -> str:
"""Create the connection to the Milvus server."""
from pymilvus import MilvusException, connections

connection_args_cp = copy.deepcopy(connection_args)
# Grab the connection arguments that are used for checking existing connection
host: str = connection_args_cp.get("host", None)
Expand Down Expand Up @@ -568,15 +586,6 @@ def _init(
)

def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = None) -> None:
from pymilvus import (
Collection,
CollectionSchema,
DataType,
FieldSchema,
MilvusException,
)
from pymilvus.orm.types import infer_dtype_bydata

# Determine embedding dim
dim = len(embeddings[0])
fields = []
Expand Down Expand Up @@ -630,17 +639,13 @@ def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = Non

def _extract_fields(self) -> None:
"""Grab the existing fields from the Collection"""
from pymilvus import Collection

if isinstance(self.col, Collection):
schema = self.col.schema
for x in schema.fields:
self.fields.append(x.name)

def _create_index(self) -> None:
"""Create an index on the collection"""
from pymilvus import Collection, MilvusException

if isinstance(self.col, Collection) and self._get_index() is None:
try:
# If no index params, use a default HNSW based one
Expand Down Expand Up @@ -694,8 +699,6 @@ def _create_index(self) -> None:

def _create_search_params(self) -> None:
"""Generate search params based on the current index type"""
from pymilvus import Collection

if isinstance(self.col, Collection) and self.search_params is None:
index = self._get_index()
if index is not None:
Expand All @@ -706,8 +709,6 @@ def _create_search_params(self) -> None:

def _get_index(self) -> Optional[Dict[str, Any]]:
"""Return the vector index information if it exists"""
from pymilvus import Collection

if isinstance(self.col, Collection):
for x in self.col.indexes:
if x.field_name == self._vector_field:
Expand All @@ -721,9 +722,6 @@ def _load(
timeout: Optional[float] = None,
) -> None:
"""Load the collection if available."""
from pymilvus import Collection, utility
from pymilvus.client.types import LoadState

if (
isinstance(self.col, Collection)
and self._get_index() is not None
Expand Down Expand Up @@ -901,6 +899,8 @@ def _map_ip_to_similarity(ip_score: float) -> float:
"""
return (ip_score + 1) / 2.0

if not self.index_params:
return lambda x: x
metric_type = self.index_params.get("metric_type", None)
if metric_type == "L2":
return _map_l2_to_similarity
Expand Down Expand Up @@ -942,9 +942,6 @@ def _discard_invalid_meta(document: Document):
"""
Remove metadata fields with unsupported types from the document.
"""
from pymilvus import DataType
from pymilvus.orm.types import infer_dtype_bydata

if not isinstance(document, Document):
msg = f"Invalid document type: {type(document)}"
raise ValueError(msg)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ def test_to_and_from_dict(self, document_store: MilvusDocumentStore):
assert document_store_dict == expected_dict
reconstructed_document_store = MilvusDocumentStore.from_dict(document_store_dict)
for field in vars(reconstructed_document_store):
if field.startswith("__") or field == "alias":
if field.startswith("__") or field in ["alias", "_milvus_client"]:
continue
assert getattr(reconstructed_document_store, field) == getattr(document_store, field)
6 changes: 3 additions & 3 deletions tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore):
continue
elif field == "document_store":
for doc_store_field in vars(document_store):
if doc_store_field.startswith("__") or doc_store_field == "alias":
if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]:
continue
assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr(
document_store, doc_store_field
Expand Down Expand Up @@ -286,7 +286,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore):
continue
elif field == "document_store":
for doc_store_field in vars(document_store):
if doc_store_field.startswith("__") or doc_store_field == "alias":
if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]:
continue
assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr(
document_store, doc_store_field
Expand Down Expand Up @@ -433,7 +433,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore):
continue
elif field == "document_store":
for doc_store_field in vars(document_store):
if doc_store_field.startswith("__") or doc_store_field == "alias":
if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]:
continue
assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr(
document_store, doc_store_field
Expand Down
Loading