Skip to content

Commit

Permalink
fix(open_embedding): use embedding for openai in batches as well (#53)
Browse files Browse the repository at this point in the history
* fix(open_embedding): use embedding for openai in batches as well

* fix: test cases
  • Loading branch information
ArslanSaleem authored Feb 18, 2025
1 parent 165d421 commit 4f2efb9
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 21 deletions.
4 changes: 3 additions & 1 deletion backend/app/processing/file_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def process_segmentation(project_id: int, asset_id: int, asset_file_name: str):

vectorstore.add_docs(
docs=docs,
metadatas=metadatas
metadatas=metadatas,
batch_size=100
)

project_repository.update_asset_content_status(
Expand All @@ -67,6 +68,7 @@ def preprocess_file(asset_id: int):
# Get asset details from the database first
with SessionLocal() as db:
asset = project_repository.get_asset(db=db, asset_id=asset_id)

if asset is None:
logger.error(f"Asset with id {asset_id} not found in the database")
return
Expand Down
2 changes: 1 addition & 1 deletion backend/app/processing/process_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,4 +406,4 @@ def vectorize_extraction_process_step(project_id: int, process_step_id: int, fil
]

# Add documents to vectorstore
vectorstore.add_docs(docs=docs, metadatas=metadatas)
vectorstore.add_docs(docs=docs, metadatas=metadatas, batch_size=100)
25 changes: 8 additions & 17 deletions backend/app/vectorstore/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,25 +101,16 @@ def add_docs(
filename = metadatas[0].get('filename', 'unknown')
logger.info(f"Adding {len(docs)} sentences to the vector store for file {filename}")

# If using OpenAI embeddings, add all documents at once
if self.settings.use_openai_embeddings and self.settings.openai_api_key:
logger.info("Using OpenAI embeddings")
# Batching the document processing
batch_size = batch_size or self._batch_size

for i in range(0, len(docs), batch_size):
logger.info(f"Processing batch {i} to {i + batch_size}")
self._docs_collection.add(
documents=list(docs),
metadatas=metadatas,
ids=ids,
documents=docs[i : i + batch_size],
metadatas=metadatas[i : i + batch_size],
ids=ids[i : i + batch_size],
)
else:
logger.info("Using default embedding function")
batch_size = batch_size or self._batch_size

for i in range(0, len(docs), batch_size):
logger.info(f"Processing batch {i} to {i + batch_size}")
self._docs_collection.add(
documents=docs[i : i + batch_size],
metadatas=metadatas[i : i + batch_size],
ids=ids[i : i + batch_size],
)

return list(ids)

Expand Down
6 changes: 4 additions & 2 deletions backend/tests/processing/test_process_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def test_vectorize_extraction_process_step_single_reference(mock_chroma_db):
# Assertions
mock_vectorstore.add_docs.assert_called_once_with(
docs=expected_docs,
metadatas=expected_metadatas
metadatas=expected_metadatas,
batch_size=100
)

@patch('app.processing.process_queue.ChromaDB')
Expand Down Expand Up @@ -261,7 +262,8 @@ def test_vectorize_extraction_process_step_multiple_references_concatenation(moc
# Assertions
mock_vectorstore.add_docs.assert_called_once_with(
docs=expected_docs,
metadatas=expected_metadatas
metadatas=expected_metadatas,
batch_size=100
)

@patch('app.processing.process_queue.ChromaDB') # Replace with the correct module path
Expand Down

0 comments on commit 4f2efb9

Please sign in to comment.