Skip to content

Commit

Permalink
Merge pull request #20 from lowener/fix-write
Browse files Browse the repository at this point in the history
Fix `write_documents` return value
  • Loading branch information
zc277584121 authored May 30, 2024
2 parents b7448d8 + 8ca5be7 commit 2964f40
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/milvus_haystack/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def __init__(

# Default search params when one is not provided.
self.default_search_params = {
"GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
"GPU_CAGRA": {"metric_type": "L2", "params": {"itopk_size": 128}},
"IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
Expand Down Expand Up @@ -323,6 +326,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
total_count = len(vectors)

batch_size = 1000
wrote_ids = []
if not isinstance(self.col, Collection):
raise MilvusException(message="Collection is not initialized")
for i in range(0, total_count, batch_size):
Expand All @@ -334,12 +338,12 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
try:
# res: Collection
res = self.col.insert(insert_list, timeout=None, **kwargs)
ids.extend(res.primary_keys)
wrote_ids.extend(res.primary_keys)
except MilvusException as err:
logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count)
raise err
self.col.flush()
return len(ids)
return len(wrote_ids)

def delete_documents(self, document_ids: List[str]) -> None:
"""
Expand Down
3 changes: 2 additions & 1 deletion tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ def document_store(self) -> MilvusDocumentStore:
)

def test_write_documents(self, document_store: DocumentStore):
document_store.write_documents(
return_value = document_store.write_documents(
[Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")]
)
assert document_store.count_documents() == 3
assert return_value == 3

def test_delete_documents(self, document_store: DocumentStore):
"""
Expand Down

0 comments on commit 2964f40

Please sign in to comment.