From 08efd2f8acaa8f6b61314bcf6965145f61900ffe Mon Sep 17 00:00:00 2001 From: Max Isom Date: Fri, 7 Mar 2025 15:31:26 -0800 Subject: [PATCH] [BUG]: correctly handle replays on local HNSW writer --- chromadb/test/api/test_invalid_update.py | 16 +++++++++ rust/segment/src/local_hnsw.rs | 46 ++++++++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 chromadb/test/api/test_invalid_update.py diff --git a/chromadb/test/api/test_invalid_update.py b/chromadb/test/api/test_invalid_update.py new file mode 100644 index 00000000000..c05a84afdb9 --- /dev/null +++ b/chromadb/test/api/test_invalid_update.py @@ -0,0 +1,16 @@ +import numpy as np +from chromadb.api import ClientAPI +from chromadb.api.types import IncludeEnum + + +def test_invalid_update(client: ClientAPI) -> None: + collection = client.create_collection("test") + + # Update is invalid because ID does not exist + collection.update(ids=["foo"], embeddings=[[0.0, 0.0, 0.0]]) + + collection.add(ids=["foo"], embeddings=[[1.0, 1.0, 1.0]]) + result = collection.get(ids=["foo"], include=[IncludeEnum.embeddings]) + # Embeddings should be the same as what was provided to .add() + assert result["embeddings"] is not None + assert np.allclose(result["embeddings"][0], np.array([1.0, 1.0, 1.0])) diff --git a/rust/segment/src/local_hnsw.rs b/rust/segment/src/local_hnsw.rs index 12591d6ae22..0ba1bfc008a 100644 --- a/rust/segment/src/local_hnsw.rs +++ b/rust/segment/src/local_hnsw.rs @@ -109,6 +109,23 @@ impl LocalHnswSegmentReader { chroma_index::IndexUuid(segment.id.0), ) .map_err(|_| LocalHnswSegmentReaderError::HnswIndexLoadError)?; + + let current_seq_id = { + let (query, values) = Query::select() + .column(MaxSeqId::SeqId) + .from(MaxSeqId::Table) + .and_where( + Expr::col(MaxSeqId::SegmentId).eq(segment.id.to_string()), + ) + .build_sqlx(SqliteQueryBuilder); + let row = sqlx::query_with(&query, values) + .fetch_optional(sql_db.get_conn()) + .await?; + row.map(|row| row.try_get::<u64, _>(0)) + .transpose()? 
+ .unwrap_or_default() + }; + // TODO(Sanket): Set allow reset appropriately. return Ok(Self { index: LocalHnswIndex { @@ -118,6 +135,7 @@ impl LocalHnswSegmentReader { index_init: true, allow_reset: false, num_elements_since_last_persist: 0, + last_seen_seq_id: current_seq_id, sync_threshold: hnsw_configuration.sync_threshold, persist_path: Some(index_folder_str.to_string()), sqlite: sql_db, @@ -158,6 +176,7 @@ impl LocalHnswSegmentReader { index_init: true, allow_reset: false, num_elements_since_last_persist: 0, + last_seen_seq_id: 0, sync_threshold: hnsw_configuration.sync_threshold, persist_path: None, sqlite: sql_db, @@ -279,6 +298,7 @@ pub struct Inner { index_init: bool, allow_reset: bool, num_elements_since_last_persist: u64, + last_seen_seq_id: u64, sync_threshold: usize, persist_path: Option<String>, sqlite: SqliteDb, @@ -425,6 +445,23 @@ impl LocalHnswSegmentWriter { chroma_index::IndexUuid(segment.id.0), ) .map_err(|_| LocalHnswSegmentWriterError::HnswIndexLoadError)?; + + let current_seq_id = { + let (query, values) = Query::select() + .column(MaxSeqId::SeqId) + .from(MaxSeqId::Table) + .and_where( + Expr::col(MaxSeqId::SegmentId).eq(segment.id.to_string()), + ) + .build_sqlx(SqliteQueryBuilder); + let row = sqlx::query_with(&query, values) + .fetch_optional(sql_db.get_conn()) + .await?; + row.map(|row| row.try_get::<u64, _>(0)) + .transpose()? + .unwrap_or_default() + }; + // TODO(Sanket): Set allow reset appropriately. 
return Ok(Self { index: LocalHnswIndex { @@ -434,6 +471,7 @@ impl LocalHnswSegmentWriter { index_init: true, allow_reset: false, num_elements_since_last_persist: 0, + last_seen_seq_id: current_seq_id, sync_threshold: hnsw_configuration.sync_threshold, persist_path: Some(index_folder_str.to_string()), sqlite: sql_db, @@ -468,6 +506,7 @@ impl LocalHnswSegmentWriter { index_init: true, allow_reset: false, num_elements_since_last_persist: 0, + last_seen_seq_id: 0, sync_threshold: hnsw_configuration.sync_threshold, persist_path: Some(index_folder_str.to_string()), sqlite: sql_db, @@ -499,6 +538,7 @@ impl LocalHnswSegmentWriter { index_init: true, allow_reset: false, num_elements_since_last_persist: 0, + last_seen_seq_id: 0, sync_threshold: hnsw_configuration.sync_threshold, persist_path: None, sqlite: sql_db, @@ -525,6 +565,10 @@ impl LocalHnswSegmentWriter { let mut hnsw_batch: HashMap> = HashMap::with_capacity(log_chunk.len()); for (log, _) in log_chunk.iter() { + if log.log_offset <= guard.last_seen_seq_id as i64 { + continue; + } + guard.num_elements_since_last_persist += 1; max_seq_id = max_seq_id.max(log.log_offset as u64); match log.record.operation { @@ -688,6 +732,8 @@ impl LocalHnswSegmentWriter { guard.num_elements_since_last_persist = 0; } + guard.last_seen_seq_id = max_seq_id; + Ok(next_label) } }