diff --git a/src/milvus_haystack/document_store.py b/src/milvus_haystack/document_store.py index 2399842..9787747 100644 --- a/src/milvus_haystack/document_store.py +++ b/src/milvus_haystack/document_store.py @@ -76,6 +76,9 @@ def __init__( # Default search params when one is not provided. self.default_search_params = { + "GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, + "GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, + "GPU_CAGRA": {"metric_type": "L2", "params": {"itopk_size": 128}}, "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, @@ -323,6 +326,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D total_count = len(vectors) batch_size = 1000 + wrote_ids = [] if not isinstance(self.col, Collection): raise MilvusException(message="Collection is not initialized") for i in range(0, total_count, batch_size): @@ -334,12 +338,12 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D try: # res: Collection res = self.col.insert(insert_list, timeout=None, **kwargs) - ids.extend(res.primary_keys) + wrote_ids.extend(res.primary_keys) except MilvusException as err: logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count) raise err self.col.flush() - return len(ids) + return len(wrote_ids) def delete_documents(self, document_ids: List[str]) -> None: """ diff --git a/tests/test_document_store.py b/tests/test_document_store.py index cef74a7..605de20 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -26,10 +26,11 @@ def document_store(self) -> MilvusDocumentStore: ) def test_write_documents(self, document_store: DocumentStore): - document_store.write_documents( + return_value = document_store.write_documents( [Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")] ) assert document_store.count_documents() == 3 + assert return_value == 3 def test_delete_documents(self, document_store: DocumentStore): """