From 0e23f084f119567ab740abc622d3cfa01cbbfe9c Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Sat, 8 Mar 2025 18:31:47 -0800 Subject: [PATCH 1/4] [ENH] Towards enabling spann in staging --- rust/frontend/src/frontend.rs | 41 +++- rust/index/src/spann/types.rs | 150 +++++------- rust/segment/src/distributed_spann.rs | 133 ++++++----- rust/segment/src/types.rs | 97 ++++++-- rust/types/src/api_types.rs | 4 + rust/types/src/hnsw_parameters.rs | 217 +++++++++++++++++- rust/worker/benches/spann.rs | 14 +- .../operators/spann_centers_search.rs | 21 +- .../src/execution/orchestration/compact.rs | 93 ++++++-- .../worker/src/execution/orchestration/mod.rs | 2 +- .../src/execution/orchestration/spann_knn.rs | 9 - rust/worker/src/server.rs | 76 ++++-- 12 files changed, 597 insertions(+), 260 deletions(-) diff --git a/rust/frontend/src/frontend.rs b/rust/frontend/src/frontend.rs index e478dcc5b60..da564bf5c41 100644 --- a/rust/frontend/src/frontend.rs +++ b/rust/frontend/src/frontend.rs @@ -21,7 +21,8 @@ use chroma_types::{ CreateTenantError, CreateTenantRequest, CreateTenantResponse, DeleteCollectionError, DeleteCollectionRecordsError, DeleteCollectionRecordsRequest, DeleteCollectionRecordsResponse, DeleteCollectionRequest, DeleteDatabaseError, DeleteDatabaseRequest, DeleteDatabaseResponse, - DistributedHnswParameters, GetCollectionError, GetCollectionRequest, GetCollectionResponse, + DistributedHnswParameters, DistributedIndexType, DistributedIndexTypeParam, + DistributedSpannParameters, GetCollectionError, GetCollectionRequest, GetCollectionResponse, GetCollectionsError, GetDatabaseError, GetDatabaseRequest, GetDatabaseResponse, GetRequest, GetResponse, GetTenantError, GetTenantRequest, GetTenantResponse, HealthCheckResponse, HeartbeatError, HeartbeatResponse, Include, ListCollectionsRequest, ListCollectionsResponse, @@ -422,18 +423,36 @@ impl Frontend { let collection_id = CollectionUuid::new(); let segments = match self.executor { Executor::Distributed(_) => { - let hnsw_metadata = - Metadata::try_from(DistributedHnswParameters::try_from(&metadata)?)?; + let index_type = DistributedIndexTypeParam::try_from(&metadata)?; + let vector_segment = match index_type.index_type { + DistributedIndexType::Hnsw => { + let validated_metadata = + Metadata::try_from(DistributedHnswParameters::try_from(&metadata)?)?; + Segment { + id: SegmentUuid::new(), + r#type: SegmentType::HnswDistributed, + scope: SegmentScope::VECTOR, + collection: collection_id, + metadata: Some(validated_metadata), + file_path: Default::default(), + } + } + DistributedIndexType::Spann => { + let validated_metadata = + Metadata::try_from(DistributedSpannParameters::try_from(&metadata)?)?; + Segment { + id: SegmentUuid::new(), + r#type: SegmentType::Spann, + scope: SegmentScope::VECTOR, + collection: collection_id, + metadata: Some(validated_metadata), + file_path: Default::default(), + } + } + }; vec![ - Segment { - id: SegmentUuid::new(), - r#type: SegmentType::HnswDistributed, - scope: SegmentScope::VECTOR, - collection: collection_id, - metadata: Some(hnsw_metadata), - file_path: Default::default(), - }, + vector_segment, Segment { id: SegmentUuid::new(), r#type: SegmentType::BlockfileMetadata, diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index 406668a723e..70022fcbdd4 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -9,8 +9,8 @@ use chroma_blockstore::{ }; use chroma_distance::{normalize, DistanceFunction}; use chroma_error::{ChromaError, ErrorCodes}; -use 
chroma_types::CollectionUuid; use chroma_types::SpannPostingList; +use chroma_types::{CollectionUuid, DistributedSpannParameters}; use rand::seq::SliceRandom; use thiserror::Error; use uuid::Uuid; @@ -29,7 +29,7 @@ pub struct VersionsMapInner { pub versions_map: HashMap, } -#[allow(dead_code)] +#[derive(Clone)] // Note: Fields of this struct are public for testing. #[derive(Clone)] pub struct SpannIndexWriter { @@ -45,8 +45,8 @@ pub struct SpannIndexWriter { // Version number of each point. // TODO(Sanket): Finer grained locking for this map in future if perf is not satisfactory. pub versions_map: Arc>, - pub distance_function: DistanceFunction, pub dimensionality: usize, + pub params: DistributedSpannParameters, } // TODO(Sanket): Can compose errors whenever downstream returns Box. @@ -138,20 +138,6 @@ impl ChromaError for SpannIndexWriterError { const MAX_HEAD_OFFSET_ID: &str = "max_head_offset_id"; -// TODO(Sanket): Make these configurable. -#[allow(dead_code)] -const NUM_CENTROIDS_TO_SEARCH: u32 = 64; -#[allow(dead_code)] -const RNG_FACTOR: f32 = 1.0; -#[allow(dead_code)] -const SPLIT_THRESHOLD: usize = 100; -const NUM_SAMPLES_FOR_KMEANS: usize = 1000; -const INITIAL_LAMBDA: f32 = 100.0; -const REASSIGN_NBR_COUNT: usize = 8; -const QUERY_EPSILON: f32 = 10.0; -const MERGE_THRESHOLD: usize = 50; -const NUM_CENTERS_TO_MERGE_TO: usize = 8; - impl SpannIndexWriter { #[allow(clippy::too_many_arguments)] pub fn new( @@ -161,8 +147,8 @@ impl SpannIndexWriter { posting_list_writer: BlockfileWriter, next_head_id: u32, versions_map: VersionsMapInner, - distance_function: DistanceFunction, dimensionality: usize, + params: DistributedSpannParameters, ) -> Self { SpannIndexWriter { hnsw_index, @@ -171,8 +157,8 @@ impl SpannIndexWriter { posting_list_writer: Arc::new(tokio::sync::Mutex::new(posting_list_writer)), next_head_id: Arc::new(AtomicU32::new(next_head_id)), versions_map: Arc::new(parking_lot::RwLock::new(versions_map)), - distance_function, dimensionality, + params, } } @@ -275,14 +261,12 @@ impl SpannIndexWriter { versions_map_id: Option<&Uuid>, posting_list_id: Option<&Uuid>, max_head_id_bf_id: Option<&Uuid>, - m: Option, - ef_construction: Option, - ef_search: Option, collection_id: &CollectionUuid, - distance_function: DistanceFunction, dimensionality: usize, blockfile_provider: &BlockfileProvider, + params: DistributedSpannParameters, ) -> Result { + let distance_function = DistanceFunction::from(params.space.clone()); // Create the HNSW index. let hnsw_index = match hnsw_id { Some(hnsw_id) => { @@ -301,9 +285,9 @@ impl SpannIndexWriter { collection_id, distance_function.clone(), dimensionality, - m.unwrap(), // Safe since caller should always provide this. - ef_construction.unwrap(), // Safe since caller should always provide this. - ef_search.unwrap(), // Safe since caller should always provide this. + params.m, + params.construction_ef, + params.search_ef, ) .await? 
} @@ -348,8 +332,8 @@ impl SpannIndexWriter { posting_list_writer, max_head_id, versions_map, - distance_function, dimensionality, + params, )) } @@ -360,7 +344,6 @@ impl SpannIndexWriter { *write_lock.versions_map.get(&id).unwrap() } - #[allow(dead_code)] async fn rng_query( &self, query: &[f32], @@ -368,10 +351,10 @@ impl SpannIndexWriter { rng_query( query, self.hnsw_index.clone(), - NUM_CENTROIDS_TO_SEARCH as usize, - QUERY_EPSILON, - RNG_FACTOR, - self.distance_function.clone(), + self.params.write_nprobe as usize, + self.params.write_rng_epsilon, + self.params.write_rng_factor, + self.params.space.clone().into(), true, ) .await @@ -419,11 +402,12 @@ impl SpannIndexWriter { { continue; } - let old_dist = self.distance_function.distance( + let distance_function: DistanceFunction = self.params.space.clone().into(); + let old_dist = distance_function.distance( old_head_embedding, &doc_embeddings[index * self.dimensionality..(index + 1) * self.dimensionality], ); - let new_dist = self.distance_function.distance( + let new_dist = distance_function.distance( new_head_embeddings[k].unwrap(), &doc_embeddings[index * self.dimensionality..(index + 1) * self.dimensionality], ); @@ -581,15 +565,16 @@ impl SpannIndexWriter { { continue; } - let distance_from_curr_center = self.distance_function.distance( + let distance_function: DistanceFunction = self.params.space.clone().into(); + let distance_from_curr_center = distance_function.distance( &doc_embeddings[index * self.dimensionality..(index + 1) * self.dimensionality], head_embedding, ); - let distance_from_split_center1 = self.distance_function.distance( + let distance_from_split_center1 = distance_function.distance( &doc_embeddings[index * self.dimensionality..(index + 1) * self.dimensionality], new_head_embeddings[0].unwrap(), ); - let distance_from_split_center2 = self.distance_function.distance( + let distance_from_split_center2 = distance_function.distance( &doc_embeddings[index * self.dimensionality..(index + 1) * self.dimensionality], new_head_embeddings[1].unwrap(), ); @@ -598,7 +583,7 @@ impl SpannIndexWriter { { continue; } - let distance_from_old_head = self.distance_function.distance( + let distance_from_old_head = distance_function.distance( &doc_embeddings[index * self.dimensionality..(index + 1) * self.dimensionality], old_head_embedding, ); @@ -641,9 +626,9 @@ impl SpannIndexWriter { ) .await?; // Reassign neighbors of this center if applicable. - if REASSIGN_NBR_COUNT > 0 { + if self.params.reassign_nbr_count > 0 { let (nearby_head_ids, _, nearby_head_embeddings) = self - .get_nearby_heads(old_head_embedding, REASSIGN_NBR_COUNT) + .get_nearby_heads(old_head_embedding, self.params.reassign_nbr_count as usize) .await?; for (head_idx, head_id) in nearby_head_ids.iter().enumerate() { // Skip the current split heads. @@ -663,7 +648,6 @@ impl SpannIndexWriter { Ok(()) } - #[allow(dead_code)] async fn append( &self, head_id: u32, @@ -728,7 +712,7 @@ impl SpannIndexWriter { } } // If size is within threshold, write the new posting back and return. 
- if up_to_date_index <= SPLIT_THRESHOLD { + if up_to_date_index <= self.params.split_threshold as usize { for idx in 0..up_to_date_index { if local_indices[idx] == idx { continue; @@ -773,9 +757,9 @@ impl SpannIndexWriter { /* k */ 2, /* first */ 0, last, - NUM_SAMPLES_FOR_KMEANS, - self.distance_function.clone(), - INITIAL_LAMBDA, + self.params.num_samples_kmeans, + self.params.space.clone().into(), + self.params.initial_lambda, ); clustering_output = cluster(&mut kmeans_input).map_err(SpannIndexWriterError::KMeansClusteringError)?; @@ -832,12 +816,12 @@ impl SpannIndexWriter { ); } let mut same_head = false; + let distance_function: DistanceFunction = self.params.space.clone().into(); for k in 0..2 { // Update the existing head. // TODO(Sanket): Need to understand what this achieves. if !same_head - && self - .distance_function + && distance_function .distance(&clustering_output.cluster_centers[k], &head_embedding) < 1e-6 { @@ -922,7 +906,6 @@ impl SpannIndexWriter { .await } - #[allow(dead_code)] async fn add_to_postings_list( &self, id: u32, @@ -990,7 +973,8 @@ impl SpannIndexWriter { let version = self.add_versions_map(id); // Normalize the embedding in case of cosine. let mut normalized_embedding = embedding.to_vec(); - if self.distance_function == DistanceFunction::Cosine { + let distance_function: DistanceFunction = self.params.space.clone().into(); + if distance_function == DistanceFunction::Cosine { normalized_embedding = normalize(embedding); } // Add to the posting list. @@ -1173,7 +1157,7 @@ impl SpannIndexWriter { .await?; source_cluster_len = doc_offset_ids.len(); // Write the PL back and return if within the merge threshold. - if source_cluster_len > MERGE_THRESHOLD { + if source_cluster_len > self.params.merge_threshold as usize { let posting_list = SpannPostingList { doc_offset_ids: &doc_offset_ids, doc_versions: &doc_versions, @@ -1188,7 +1172,7 @@ impl SpannIndexWriter { } // Find candidates for merge. let (nearest_head_ids, _, nearest_head_embeddings) = self - .get_nearby_heads(head_embedding, NUM_CENTERS_TO_MERGE_TO) + .get_nearby_heads(head_embedding, self.params.num_centers_to_merge_to as usize) .await?; for (nearest_head_id, nearest_head_embedding) in nearest_head_ids .into_iter() @@ -1214,7 +1198,7 @@ impl SpannIndexWriter { .get_up_to_date_count(&nearest_head_doc_offset_ids, &nearest_head_doc_versions) .await?; // If the total count exceeds the max posting list size then skip. - if target_cluster_len + source_cluster_len >= SPLIT_THRESHOLD { + if target_cluster_len + source_cluster_len >= self.params.split_threshold as usize { continue; } // Merge the two PLs. @@ -1271,12 +1255,13 @@ impl SpannIndexWriter { if source_cluster_len > target_cluster_len { // target_cluster points were merged to source_cluster // so they are candidates for reassignment. + let distance_function: DistanceFunction = self.params.space.clone().into(); for idx in source_cluster_len..(source_cluster_len + target_cluster_len) { - let origin_dist = self.distance_function.distance( + let origin_dist = distance_function.distance( &doc_embeddings[idx * self.dimensionality..(idx + 1) * self.dimensionality], &target_embedding, ); - let new_dist = self.distance_function.distance( + let new_dist = distance_function.distance( &doc_embeddings[idx * self.dimensionality..(idx + 1) * self.dimensionality], head_embedding, ); @@ -1293,12 +1278,13 @@ impl SpannIndexWriter { } else { // source_cluster points were merged to target_cluster // so they are candidates for reassignment. 
+ let distance_function: DistanceFunction = self.params.space.clone().into(); for idx in 0..source_cluster_len { - let origin_dist = self.distance_function.distance( + let origin_dist = distance_function.distance( &doc_embeddings[idx * self.dimensionality..(idx + 1) * self.dimensionality], head_embedding, ); - let new_dist = self.distance_function.distance( + let new_dist = distance_function.distance( &doc_embeddings[idx * self.dimensionality..(idx + 1) * self.dimensionality], &target_embedding, ); @@ -1720,7 +1706,7 @@ mod tests { }; use chroma_cache::{new_cache_for_test, new_non_persistent_cache_for_test}; use chroma_storage::{local::LocalStorage, Storage}; - use chroma_types::{CollectionUuid, SpannPostingList}; + use chroma_types::{CollectionUuid, DistributedSpannParameters, SpannPostingList}; use rand::Rng; use tempfile::TempDir; @@ -1753,25 +1739,19 @@ mod tests { 16, rx, ); - let m = 16; - let ef_construction = 200; - let ef_search = 200; let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; let dimensionality = 2; + let params = DistributedSpannParameters::default(); let writer = SpannIndexWriter::from_id( &hnsw_provider, None, None, None, None, - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function, dimensionality, &blockfile_provider, + params, ) .await .expect("Error creating spann index writer"); @@ -1949,25 +1929,19 @@ mod tests { 16, rx, ); - let m = 16; - let ef_construction = 200; - let ef_search = 200; let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; let dimensionality = 2; + let params = DistributedSpannParameters::default(); let writer = SpannIndexWriter::from_id( &hnsw_provider, None, None, None, None, - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function, dimensionality, &blockfile_provider, + params, ) .await .expect("Error creating spann index writer"); @@ -2135,25 +2109,19 @@ mod tests { 16, rx, ); - let m = 16; - let ef_construction = 200; - let ef_search = 200; let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; let dimensionality = 2; + let params = DistributedSpannParameters::default(); let writer = SpannIndexWriter::from_id( &hnsw_provider, None, None, None, None, - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function, dimensionality, &blockfile_provider, + params, ) .await .expect("Error creating spann index writer"); @@ -2325,25 +2293,19 @@ mod tests { 16, rx, ); - let m = 16; - let ef_construction = 200; - let ef_search = 200; let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; let dimensionality = 2; + let params = DistributedSpannParameters::default(); let writer = SpannIndexWriter::from_id( &hnsw_provider, None, None, None, None, - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function, dimensionality, &blockfile_provider, + params, ) .await .expect("Error creating spann index writer"); @@ -2567,25 +2529,19 @@ mod tests { 16, rx, ); - let m = 16; - let ef_construction = 200; - let ef_search = 200; let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; let dimensionality = 2; + let params = DistributedSpannParameters::default(); let writer = SpannIndexWriter::from_id( &hnsw_provider, None, None, None, None, - 
Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function, dimensionality, &blockfile_provider, + params, ) .await .expect("Error creating spann index writer"); diff --git a/rust/segment/src/distributed_spann.rs b/rust/segment/src/distributed_spann.rs index a380e76bf60..0bf760fcf7c 100644 --- a/rust/segment/src/distributed_spann.rs +++ b/rust/segment/src/distributed_spann.rs @@ -8,13 +8,17 @@ use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::spann::types::{ SpannIndexFlusher, SpannIndexReader, SpannIndexReaderError, SpannIndexWriterError, SpannPosting, }; +use chroma_index::spann::utils::rng_query; +use chroma_index::spann::utils::RngQueryError; use chroma_index::IndexUuid; use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter}; -use chroma_types::DistributedHnswParameters; -use chroma_types::HnswParametersFromSegmentError; +use chroma_types::DistributedSpannParameters; +use chroma_types::DistributedSpannParametersFromSegmentError; use chroma_types::SegmentUuid; use chroma_types::{MaterializedLogOperation, Segment, SegmentScope, SegmentType}; use std::collections::HashMap; +use std::fmt::Debug; +use std::fmt::Formatter; use thiserror::Error; use uuid::Uuid; @@ -23,10 +27,18 @@ const VERSION_MAP_PATH: &str = "version_map_path"; const POSTING_LIST_PATH: &str = "posting_list_path"; const MAX_HEAD_ID_BF_PATH: &str = "max_head_id_path"; -pub(crate) struct SpannSegmentWriter { +#[derive(Clone)] +pub struct SpannSegmentWriter { index: SpannIndexWriter, - #[allow(dead_code)] - id: SegmentUuid, + pub id: SegmentUuid, +} + +impl Debug for SpannSegmentWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DistributedSpannSegmentWriter") + .field("id", &self.id) + .finish() + } } // TODO(Sanket): Better error composability here. 
@@ -34,8 +46,8 @@ pub(crate) struct SpannSegmentWriter { pub enum SpannSegmentWriterError { #[error("Invalid argument")] InvalidArgument, - #[error("Could not parse HNSW configuration: {0}")] - InvalidHnswConfiguration(#[from] HnswParametersFromSegmentError), + #[error("Could not parse spann configuration: {0}")] + InvalidConfiguration(#[from] DistributedSpannParametersFromSegmentError), #[error("Error parsing index uuid from string")] IndexIdParsingError, #[error("Invalid file path for HNSW index")] @@ -63,7 +75,6 @@ impl ChromaError for SpannSegmentWriterError { match self { Self::InvalidArgument => ErrorCodes::InvalidArgument, Self::IndexIdParsingError => ErrorCodes::Internal, - Self::InvalidHnswConfiguration(_) => ErrorCodes::Internal, Self::HnswInvalidFilePath => ErrorCodes::Internal, Self::VersionMapInvalidFilePath => ErrorCodes::Internal, Self::PostingListInvalidFilePath => ErrorCodes::Internal, @@ -73,12 +84,12 @@ impl ChromaError for SpannSegmentWriterError { Self::SpannSegmentWriterCommitError => ErrorCodes::Internal, Self::SpannSegmentWriterFlushError => ErrorCodes::Internal, Self::SpannSegmentWriterAddRecordError(e) => e.code(), + Self::InvalidConfiguration(e) => e.code(), } } } impl SpannSegmentWriter { - #[allow(dead_code)] pub async fn from_segment( segment: &Segment, blockfile_provider: &BlockfileProvider, @@ -88,9 +99,9 @@ impl SpannSegmentWriter { if segment.r#type != SegmentType::Spann || segment.scope != SegmentScope::VECTOR { return Err(SpannSegmentWriterError::InvalidArgument); } - let hnsw_configuration = DistributedHnswParameters::try_from(segment)?; + let params = DistributedSpannParameters::try_from(segment)?; - let (hnsw_id, m, ef_construction, ef_search) = match segment.file_path.get(HNSW_PATH) { + let hnsw_id = match segment.file_path.get(HNSW_PATH) { Some(hnsw_path) => match hnsw_path.first() { Some(index_id) => { let index_uuid = match Uuid::parse_str(index_id) { @@ -99,18 +110,13 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::IndexIdParsingError); } }; - (Some(IndexUuid(index_uuid)), None, None, None) + Some(IndexUuid(index_uuid)) } None => { return Err(SpannSegmentWriterError::HnswInvalidFilePath); } }, - None => ( - None, - Some(hnsw_configuration.m), - Some(hnsw_configuration.construction_ef), - Some(hnsw_configuration.search_ef), - ), + None => None, }; let versions_map_id = match segment.file_path.get(VERSION_MAP_PATH) { Some(version_map_path) => match version_map_path.first() { @@ -171,13 +177,10 @@ impl SpannSegmentWriter { versions_map_id.as_ref(), posting_list_id.as_ref(), max_head_id_bf_id.as_ref(), - m, - ef_construction, - ef_search, &segment.collection, - hnsw_configuration.space.into(), dimensionality, blockfile_provider, + params, ) .await { @@ -223,7 +226,6 @@ impl SpannSegmentWriter { .map_err(SpannSegmentWriterError::SpannSegmentWriterAddRecordError) } - #[allow(dead_code)] pub async fn apply_materialized_log_chunk( &self, record_segment_reader: &Option>, @@ -263,7 +265,6 @@ impl SpannSegmentWriter { Ok(()) } - #[allow(dead_code)] pub async fn commit(self) -> Result> { let index_flusher = self .index @@ -281,14 +282,18 @@ impl SpannSegmentWriter { } pub struct SpannSegmentFlusher { - #[allow(dead_code)] - id: SegmentUuid, + pub id: SegmentUuid, index_flusher: SpannIndexFlusher, } +impl Debug for SpannSegmentFlusher { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SpannSegmentFlusher").finish() + } +} + impl SpannSegmentFlusher { - #[allow(dead_code)] - async fn flush(self) -> Result>, Box> { 
+ pub async fn flush(self) -> Result>, Box> { let index_flusher_res = self .index_flusher .flush() @@ -321,8 +326,8 @@ impl SpannSegmentFlusher { pub enum SpannSegmentReaderError { #[error("Invalid argument")] InvalidArgument, - #[error("Could not parse HNSW configuration: {0}")] - InvalidHnswConfiguration(#[from] HnswParametersFromSegmentError), + #[error("Could not parse configuration: {0}")] + InvalidConfiguration(#[from] DistributedSpannParametersFromSegmentError), #[error("Error parsing index uuid from string")] IndexIdParsingError, #[error("Invalid file path for HNSW index")] @@ -337,6 +342,8 @@ pub enum SpannSegmentReaderError { UninitializedSegment, #[error("Error reading key")] KeyReadError, + #[error("Error doing rng {0}")] + RngError(#[from] RngQueryError), } impl ChromaError for SpannSegmentReaderError { @@ -344,13 +351,14 @@ impl ChromaError for SpannSegmentReaderError { match self { Self::InvalidArgument => ErrorCodes::InvalidArgument, Self::IndexIdParsingError => ErrorCodes::Internal, - Self::InvalidHnswConfiguration(_) => ErrorCodes::Internal, Self::HnswInvalidFilePath => ErrorCodes::Internal, Self::VersionMapInvalidFilePath => ErrorCodes::Internal, Self::PostingListInvalidFilePath => ErrorCodes::Internal, Self::SpannSegmentReaderCreateError => ErrorCodes::Internal, Self::UninitializedSegment => ErrorCodes::Internal, Self::KeyReadError => ErrorCodes::Internal, + Self::InvalidConfiguration(e) => e.code(), + Self::RngError(e) => e.code(), } } } @@ -364,14 +372,14 @@ pub struct SpannSegmentReaderContext { } #[derive(Clone)] -#[allow(dead_code)] pub struct SpannSegmentReader<'me> { pub index_reader: SpannIndexReader<'me>, + #[allow(dead_code)] id: SegmentUuid, + pub params: DistributedSpannParameters, } impl<'me> SpannSegmentReader<'me> { - #[allow(dead_code)] pub async fn from_segment( segment: &Segment, blockfile_provider: &BlockfileProvider, @@ -381,7 +389,7 @@ impl<'me> SpannSegmentReader<'me> { if segment.r#type != SegmentType::Spann || segment.scope != SegmentScope::VECTOR { return Err(SpannSegmentReaderError::InvalidArgument); } - let hnsw_configuration = DistributedHnswParameters::try_from(segment)?; + let params = DistributedSpannParameters::try_from(segment)?; let hnsw_id = match segment.file_path.get(HNSW_PATH) { Some(hnsw_path) => match hnsw_path.first() { Some(index_id) => { @@ -438,7 +446,7 @@ impl<'me> SpannSegmentReader<'me> { hnsw_id.as_ref(), hnsw_provider, &segment.collection, - hnsw_configuration.space.into(), + params.space.clone().into(), dimensionality, posting_list_id.as_ref(), versions_map_id.as_ref(), @@ -460,6 +468,7 @@ impl<'me> SpannSegmentReader<'me> { Ok(SpannSegmentReader { index_reader, id: segment.id, + params, }) } @@ -472,6 +481,23 @@ impl<'me> SpannSegmentReader<'me> { .await .map_err(|_| SpannSegmentReaderError::KeyReadError) } + + pub async fn rng_query( + &self, + normalized_query: &[f32], + ) -> Result<(Vec, Vec, Vec>), SpannSegmentReaderError> { + let r = rng_query( + normalized_query, + self.index_reader.hnsw_index.clone(), + self.params.search_nprobe as usize, + self.params.search_rng_epsilon, + self.params.search_rng_factor, + self.params.space.clone().into(), + false, + ) + .await?; + Ok(r) + } } #[cfg(test)] @@ -483,11 +509,10 @@ mod test { provider::BlockfileProvider, }; use chroma_cache::{new_cache_for_test, new_non_persistent_cache_for_test}; - use chroma_distance::DistanceFunction; use chroma_index::{hnsw_provider::HnswIndexProvider, Index}; use chroma_storage::{local::LocalStorage, Storage}; use chroma_types::{ - Chunk, 
CollectionUuid, LogRecord, Metadata, MetadataValue, Operation, OperationRecord, + Chunk, CollectionUuid, DistributedSpannParameters, LogRecord, Operation, OperationRecord, SegmentUuid, SpannPostingList, }; @@ -521,20 +546,17 @@ mod test { ); let collection_id = CollectionUuid::new(); let segment_id = SegmentUuid::new(); - let mut metadata_hash_map = Metadata::new(); - metadata_hash_map.insert( - "hnsw:space".to_string(), - MetadataValue::Str("l2".to_string()), - ); - metadata_hash_map.insert("hnsw:M".to_string(), MetadataValue::Int(16)); - metadata_hash_map.insert("hnsw:construction_ef".to_string(), MetadataValue::Int(100)); - metadata_hash_map.insert("hnsw:search_ef".to_string(), MetadataValue::Int(100)); + let params = DistributedSpannParameters::default(); let mut spann_segment = chroma_types::Segment { id: segment_id, collection: collection_id, r#type: chroma_types::SegmentType::Spann, scope: chroma_types::SegmentScope::VECTOR, - metadata: Some(metadata_hash_map), + metadata: Some( + params + .try_into() + .expect("Error converting params to metadata"), + ), file_path: HashMap::new(), }; let spann_writer = SpannSegmentWriter::from_segment( @@ -619,8 +641,8 @@ mod test { .expect("Error creating spann segment writer"); assert_eq!(spann_writer.index.dimensionality, 3); assert_eq!( - spann_writer.index.distance_function, - DistanceFunction::Euclidean + spann_writer.index.params, + DistributedSpannParameters::default() ); // Next head id should be 2 since one centroid is already taken up. assert_eq!( @@ -710,20 +732,17 @@ mod test { ); let collection_id = CollectionUuid::new(); let segment_id = SegmentUuid::new(); - let mut metadata_hash_map = Metadata::new(); - metadata_hash_map.insert( - "hnsw:space".to_string(), - MetadataValue::Str("l2".to_string()), - ); - metadata_hash_map.insert("hnsw:M".to_string(), MetadataValue::Int(16)); - metadata_hash_map.insert("hnsw:construction_ef".to_string(), MetadataValue::Int(100)); - metadata_hash_map.insert("hnsw:search_ef".to_string(), MetadataValue::Int(100)); + let params = DistributedSpannParameters::default(); let mut spann_segment = chroma_types::Segment { id: segment_id, collection: collection_id, r#type: chroma_types::SegmentType::Spann, scope: chroma_types::SegmentScope::VECTOR, - metadata: Some(metadata_hash_map), + metadata: Some( + params + .try_into() + .expect("Error converting params to metadata"), + ), file_path: HashMap::new(), }; let spann_writer = SpannSegmentWriter::from_segment( diff --git a/rust/segment/src/types.rs b/rust/segment/src/types.rs index 33b7b525012..568161cb6d7 100644 --- a/rust/segment/src/types.rs +++ b/rust/segment/src/types.rs @@ -10,6 +10,8 @@ use std::sync::Arc; use thiserror::Error; use tracing::{Instrument, Span}; +use crate::distributed_spann::{SpannSegmentFlusher, SpannSegmentWriter}; + use super::blockfile_metadata::{MetadataSegmentFlusher, MetadataSegmentWriter}; use super::blockfile_record::{ ApplyMaterializedLogError, RecordSegmentFlusher, RecordSegmentReader, @@ -860,11 +862,68 @@ pub async fn materialize_logs( }) } +#[derive(Clone, Debug)] +pub enum VectorSegmentWriter { + Hnsw(Box), + Spann(SpannSegmentWriter), +} + +impl VectorSegmentWriter { + pub fn get_id(&self) -> SegmentUuid { + match self { + VectorSegmentWriter::Hnsw(writer) => writer.id, + VectorSegmentWriter::Spann(writer) => writer.id, + } + } + + pub fn get_name(&self) -> &'static str { + match self { + VectorSegmentWriter::Hnsw(_) => "DistributedHNSWSegmentWriter", + VectorSegmentWriter::Spann(_) => "SpannSegmentWriter", + } + } + + pub 
async fn apply_materialized_log_chunk( + &self, + record_segment_reader: &Option>, + materialized: &MaterializeLogsResult, + ) -> Result<(), ApplyMaterializedLogError> { + match self { + VectorSegmentWriter::Hnsw(writer) => { + writer + .apply_materialized_log_chunk(record_segment_reader, materialized) + .await + } + VectorSegmentWriter::Spann(writer) => { + writer + .apply_materialized_log_chunk(record_segment_reader, materialized) + .await + } + } + } + + pub async fn finish(&mut self) -> Result<(), Box> { + Ok(()) + } + + pub async fn commit(self) -> Result> { + match self { + VectorSegmentWriter::Hnsw(writer) => writer.commit().await.map(|w| { + ChromaSegmentFlusher::VectorSegment(VectorSegmentFlusher::Hnsw(Box::new(w))) + }), + VectorSegmentWriter::Spann(writer) => writer + .commit() + .await + .map(|w| ChromaSegmentFlusher::VectorSegment(VectorSegmentFlusher::Spann(w))), + } + } +} + #[derive(Clone, Debug)] pub enum ChromaSegmentWriter<'bf> { RecordSegment(RecordSegmentWriter), MetadataSegment(MetadataSegmentWriter<'bf>), - DistributedHNSWSegment(Box), + VectorSegment(VectorSegmentWriter), } impl ChromaSegmentWriter<'_> { @@ -872,7 +931,7 @@ impl ChromaSegmentWriter<'_> { match self { ChromaSegmentWriter::RecordSegment(writer) => writer.id, ChromaSegmentWriter::MetadataSegment(writer) => writer.id, - ChromaSegmentWriter::DistributedHNSWSegment(writer) => writer.id, + ChromaSegmentWriter::VectorSegment(writer) => writer.get_id(), } } @@ -880,7 +939,7 @@ impl ChromaSegmentWriter<'_> { match self { ChromaSegmentWriter::RecordSegment(_) => "RecordSegmentWriter", ChromaSegmentWriter::MetadataSegment(_) => "MetadataSegmentWriter", - ChromaSegmentWriter::DistributedHNSWSegment(_) => "DistributedHNSWSegmentWriter", + ChromaSegmentWriter::VectorSegment(writer) => writer.get_name(), } } @@ -900,7 +959,7 @@ impl ChromaSegmentWriter<'_> { .apply_materialized_log_chunk(record_segment_reader, materialized) .await } - ChromaSegmentWriter::DistributedHNSWSegment(writer) => { + ChromaSegmentWriter::VectorSegment(writer) => { writer .apply_materialized_log_chunk(record_segment_reader, materialized) .await @@ -912,7 +971,7 @@ impl ChromaSegmentWriter<'_> { match self { ChromaSegmentWriter::RecordSegment(_) => Ok(()), ChromaSegmentWriter::MetadataSegment(writer) => writer.finish().await, - ChromaSegmentWriter::DistributedHNSWSegment(_) => Ok(()), + ChromaSegmentWriter::VectorSegment(writer) => writer.finish().await, } } @@ -926,19 +985,22 @@ impl ChromaSegmentWriter<'_> { .commit() .await .map(ChromaSegmentFlusher::MetadataSegment), - ChromaSegmentWriter::DistributedHNSWSegment(writer) => writer - .commit() - .await - .map(|w| ChromaSegmentFlusher::DistributedHNSWSegment(Box::new(w))), + ChromaSegmentWriter::VectorSegment(writer) => writer.commit().await, } } } +#[derive(Debug)] +pub enum VectorSegmentFlusher { + Hnsw(Box), + Spann(SpannSegmentFlusher), +} + #[derive(Debug)] pub enum ChromaSegmentFlusher { RecordSegment(RecordSegmentFlusher), MetadataSegment(MetadataSegmentFlusher), - DistributedHNSWSegment(Box), + VectorSegment(VectorSegmentFlusher), } impl ChromaSegmentFlusher { @@ -946,7 +1008,10 @@ impl ChromaSegmentFlusher { match self { ChromaSegmentFlusher::RecordSegment(flusher) => flusher.id, ChromaSegmentFlusher::MetadataSegment(flusher) => flusher.id, - ChromaSegmentFlusher::DistributedHNSWSegment(flusher) => flusher.id, + ChromaSegmentFlusher::VectorSegment(flusher) => match flusher { + VectorSegmentFlusher::Hnsw(writer) => writer.id, + VectorSegmentFlusher::Spann(writer) => writer.id, + }, 
} } @@ -954,7 +1019,10 @@ impl ChromaSegmentFlusher { match self { ChromaSegmentFlusher::RecordSegment(_) => "RecordSegmentFlusher", ChromaSegmentFlusher::MetadataSegment(_) => "MetadataSegmentFlusher", - ChromaSegmentFlusher::DistributedHNSWSegment(_) => "DistributedHNSWSegmentFlusher", + ChromaSegmentFlusher::VectorSegment(flusher) => match flusher { + VectorSegmentFlusher::Hnsw(_) => "DistributedHNSWSegmentFlusher", + VectorSegmentFlusher::Spann(_) => "SpannSegmentFlusher", + }, } } @@ -962,7 +1030,10 @@ impl ChromaSegmentFlusher { match self { ChromaSegmentFlusher::RecordSegment(flusher) => flusher.flush().await, ChromaSegmentFlusher::MetadataSegment(flusher) => flusher.flush().await, - ChromaSegmentFlusher::DistributedHNSWSegment(flusher) => flusher.flush().await, + ChromaSegmentFlusher::VectorSegment(flusher) => match flusher { + VectorSegmentFlusher::Hnsw(flusher) => flusher.flush().await, + VectorSegmentFlusher::Spann(flusher) => flusher.flush().await, + }, } } } diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 4878791c7e1..16d85bd5d94 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -9,6 +9,7 @@ use crate::validators::{ use crate::Collection; use crate::CollectionConversionError; use crate::CollectionUuid; +use crate::DistributedSpannParametersFromSegmentError; use crate::HnswParametersFromSegmentError; use crate::Metadata; use crate::SegmentConversionError; @@ -564,6 +565,8 @@ pub type CreateCollectionResponse = Collection; pub enum CreateCollectionError { #[error("Invalid HNSW parameters: {0}")] InvalidHnswParameters(#[from] HnswParametersFromSegmentError), + #[error("Invalid Spann parameters: {0}")] + InvalidSpannParameters(#[from] DistributedSpannParametersFromSegmentError), #[error("Collection [{0}] already exists")] AlreadyExists(String), #[error("Database [{0}] does not exist")] @@ -580,6 +583,7 @@ impl ChromaError for CreateCollectionError { fn code(&self) -> ErrorCodes { match self { CreateCollectionError::InvalidHnswParameters(_) => ErrorCodes::InvalidArgument, + CreateCollectionError::InvalidSpannParameters(_) => ErrorCodes::InvalidArgument, CreateCollectionError::AlreadyExists(_) => ErrorCodes::AlreadyExists, CreateCollectionError::DatabaseNotFound(_) => ErrorCodes::InvalidArgument, CreateCollectionError::Get(err) => err.code(), diff --git a/rust/types/src/hnsw_parameters.rs b/rust/types/src/hnsw_parameters.rs index c5d96545823..2fe8dfd437c 100644 --- a/rust/types/src/hnsw_parameters.rs +++ b/rust/types/src/hnsw_parameters.rs @@ -22,7 +22,7 @@ impl ChromaError for HnswParametersFromSegmentError { } } -#[derive(Default, Debug, Serialize, Deserialize)] +#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)] pub enum HnswSpace { #[default] #[serde(rename = "l2")] @@ -206,3 +206,218 @@ impl TryFrom for Metadata { Ok(parsed) } } + +fn default_search_nprobe() -> u32 { + 128 +} + +fn default_search_rng_factor() -> f32 { + 1.0 +} + +fn default_search_rng_epsilon() -> f32 { + 10.0 +} + +fn default_write_nprobe() -> u32 { + 128 +} + +fn default_write_rng_factor() -> f32 { + 1.0 +} + +fn default_write_rng_epsilon() -> f32 { + 10.0 +} + +fn default_split_threshold() -> u32 { + 100 +} + +fn default_num_samples_kmeans() -> usize { + 1000 +} + +fn default_initial_lambda() -> f32 { + 100.0 +} + +fn default_reassign_nbr_count() -> u32 { + 8 +} + +fn default_merge_threshold() -> u32 { + 50 +} + +fn default_num_centers_to_merge_to() -> u32 { + 8 +} + +fn default_construction_ef_spann() -> usize { + 200 +} + +fn 
default_search_ef_spann() -> usize { + 200 +} + +fn default_m_spann() -> usize { + 16 +} + +#[derive(Debug, Error)] +pub enum DistributedSpannParametersFromSegmentError { + #[error("Invalid metadata: {0}")] + InvalidMetadata(#[from] serde_json::Error), + #[error("Invalid parameters: {0}")] + InvalidParameters(#[from] validator::ValidationErrors), +} + +impl ChromaError for DistributedSpannParametersFromSegmentError { + fn code(&self) -> ErrorCodes { + match self { + DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => { + ErrorCodes::InvalidArgument + } + DistributedSpannParametersFromSegmentError::InvalidParameters(_) => { + ErrorCodes::InvalidArgument + } + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)] +pub struct DistributedSpannParameters { + #[serde(rename = "spann:search_nprobe", default = "default_search_nprobe")] + #[validate(range(min = 8))] + pub search_nprobe: u32, + #[serde( + rename = "spann:search_rng_factor", + default = "default_search_rng_factor" + )] + pub search_rng_factor: f32, + #[serde( + rename = "spann:search_rng_epsilon", + default = "default_search_rng_epsilon" + )] + pub search_rng_epsilon: f32, + #[serde( + rename = "spann:search_split_threshold", + default = "default_write_nprobe" + )] + #[validate(range(min = 8))] + pub write_nprobe: u32, + #[serde( + rename = "spann:write_rng_factor", + default = "default_write_rng_factor" + )] + pub write_rng_factor: f32, + #[serde( + rename = "spann:write_rng_epsilon", + default = "default_write_rng_epsilon" + )] + pub write_rng_epsilon: f32, + #[serde( + rename = "spann:write_split_threshold", + default = "default_split_threshold" + )] + #[validate(range(min = 50))] + pub split_threshold: u32, + #[serde( + rename = "spann:num_samples_kmeans", + default = "default_num_samples_kmeans" + )] + #[validate(range(min = 500))] + pub num_samples_kmeans: usize, + #[serde(rename = "spann:initial_lambda", default = "default_initial_lambda")] + pub initial_lambda: f32, + #[serde( + rename = "spann:reassign_nbr_count", + default = "default_reassign_nbr_count" + )] + pub reassign_nbr_count: u32, + #[serde(rename = "spann:merge_threshold", default = "default_merge_threshold")] + pub merge_threshold: u32, + #[serde( + rename = "spann:num_centers_to_merge_to", + default = "default_num_centers_to_merge_to" + )] + pub num_centers_to_merge_to: u32, + #[serde(rename = "spann:space", default)] + pub space: HnswSpace, + #[serde( + rename = "spann:construction_ef", + default = "default_construction_ef_spann" + )] + pub construction_ef: usize, + #[serde(rename = "spann:search_ef", default = "default_search_ef_spann")] + pub search_ef: usize, + #[serde(rename = "spann:M", default = "default_m_spann")] + pub m: usize, +} + +impl Default for DistributedSpannParameters { + fn default() -> Self { + serde_json::from_str("{}").unwrap() + } +} + +impl TryFrom<&Option> for DistributedSpannParameters { + type Error = DistributedSpannParametersFromSegmentError; + + fn try_from(value: &Option) -> Result { + let metadata_str = serde_json::to_string(value.as_ref().unwrap_or(&Metadata::default()))?; + let r = serde_json::from_str::(&metadata_str)?; + r.validate()?; + Ok(r) + } +} + +impl TryFrom<&Segment> for DistributedSpannParameters { + type Error = DistributedSpannParametersFromSegmentError; + + fn try_from(value: &Segment) -> Result { + DistributedSpannParameters::try_from(&value.metadata) + } +} + +impl TryFrom for Metadata { + type Error = DistributedSpannParametersFromSegmentError; + + fn try_from(value: 
DistributedSpannParameters) -> Result { + let metadata_str = serde_json::to_string(&value)?; + let r = serde_json::from_str::(&metadata_str)?; + Ok(r) + } +} + +fn default_index_type() -> DistributedIndexType { + DistributedIndexType::Hnsw +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DistributedIndexTypeParam { + #[serde(alias = "index_type", default = "default_index_type")] + pub index_type: DistributedIndexType, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub enum DistributedIndexType { + #[serde(rename = "hnsw")] + Hnsw, + #[serde(rename = "spann")] + Spann, +} + +impl TryFrom<&Option> for DistributedIndexTypeParam { + type Error = DistributedSpannParametersFromSegmentError; + + fn try_from(value: &Option) -> Result { + let metadata_str = serde_json::to_string(value.as_ref().unwrap_or(&Metadata::default()))?; + let r = serde_json::from_str::(&metadata_str)?; + Ok(r) + } +} diff --git a/rust/worker/benches/spann.rs b/rust/worker/benches/spann.rs index 7b3af6dadbf..01a54297e26 100644 --- a/rust/worker/benches/spann.rs +++ b/rust/worker/benches/spann.rs @@ -15,7 +15,7 @@ use chroma_index::{ }; use chroma_storage::{local::LocalStorage, Storage}; use chroma_system::Operator; -use chroma_types::CollectionUuid; +use chroma_types::{CollectionUuid, DistributedSpannParameters}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use futures::StreamExt; use rand::seq::SliceRandom; @@ -74,25 +74,19 @@ fn add_to_index_and_get_reader<'a>( 16, rx, ); - let m = 32; - let ef_construction = 100; - let ef_search = 100; let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; let dimensionality = 128; + let params = DistributedSpannParameters::default(); let writer = SpannIndexWriter::from_id( &hnsw_provider, None, None, None, None, - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function.clone(), dimensionality, &blockfile_provider, + params.clone(), ) .await .expect("Error creating spann index writer"); @@ -129,7 +123,7 @@ fn add_to_index_and_get_reader<'a>( Some(&paths.hnsw_id), &hnsw_provider, &collection_id, - distance_function, + params.space.into(), dimensionality, Some(&paths.pl_id), Some(&paths.versions_map_id), diff --git a/rust/worker/src/execution/operators/spann_centers_search.rs b/rust/worker/src/execution/operators/spann_centers_search.rs index 83350b35e4a..f13058599d9 100644 --- a/rust/worker/src/execution/operators/spann_centers_search.rs +++ b/rust/worker/src/execution/operators/spann_centers_search.rs @@ -1,7 +1,5 @@ use async_trait::async_trait; -use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_index::spann::utils::rng_query; use chroma_segment::distributed_spann::{SpannSegmentReader, SpannSegmentReaderContext}; use chroma_system::Operator; use thiserror::Error; @@ -12,10 +10,6 @@ pub(crate) struct SpannCentersSearchInput { pub(crate) reader_context: SpannSegmentReaderContext, // Assumes that query is already normalized in case of cosine. pub(crate) normalized_query: Vec, - pub(crate) k: usize, - pub(crate) rng_epsilon: f32, - pub(crate) rng_factor: f32, - pub(crate) distance_function: DistanceFunction, } #[allow(dead_code)] @@ -68,17 +62,10 @@ impl Operator for SpannCenter .await .map_err(|_| SpannCentersSearchError::SpannSegmentReaderCreationError)?; // RNG Query. 
- let res = rng_query( - &input.normalized_query, - spann_reader.index_reader.hnsw_index.clone(), - input.k, - input.rng_epsilon, - input.rng_factor, - input.distance_function.clone(), - false, - ) - .await - .map_err(|_| SpannCentersSearchError::RngQueryError)?; + let res = spann_reader + .rng_query(&input.normalized_query) + .await + .map_err(|_| SpannCentersSearchError::RngQueryError)?; Ok(SpannCentersSearchOutput { center_ids: res.0 }) } } diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index c5b1ad6759c..e7a9f4361ec 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ b/rust/worker/src/execution/orchestration/compact.rs @@ -40,9 +40,11 @@ use chroma_segment::blockfile_record::RecordSegmentReader; use chroma_segment::blockfile_record::RecordSegmentReaderCreationError; use chroma_segment::blockfile_record::RecordSegmentWriter; use chroma_segment::distributed_hnsw::DistributedHNSWSegmentWriter; +use chroma_segment::distributed_spann::SpannSegmentWriter; use chroma_segment::types::ChromaSegmentFlusher; use chroma_segment::types::ChromaSegmentWriter; use chroma_segment::types::MaterializeLogsResult; +use chroma_segment::types::VectorSegmentWriter; use chroma_sysdb::SysDb; use chroma_system::wrap; use chroma_system::ChannelError; @@ -59,6 +61,7 @@ use chroma_system::TaskResult; use chroma_types::Chunk; use chroma_types::GetCollectionsError; use chroma_types::GetSegmentsError; +use chroma_types::SegmentScope; use chroma_types::SegmentUuid; use chroma_types::{CollectionUuid, LogRecord, Segment, SegmentFlushInfo, SegmentType}; use core::panic; @@ -102,7 +105,7 @@ enum ExecutionState { pub(crate) struct CompactWriters { pub(crate) metadata: MetadataSegmentWriter<'static>, pub(crate) record: RecordSegmentWriter, - pub(crate) vector: Box, + pub(crate) vector: VectorSegmentWriter, } #[derive(Debug)] @@ -149,8 +152,10 @@ pub enum GetSegmentWritersError { RecordSegmentWriterError, #[error("Error creating Metadata Segment Writer")] MetadataSegmentWriterError, - #[error("Error creating HNSW Segment Writer")] - HnswSegmentWriterError, + #[error("Error creating Vector Segment Writer")] + VectorSegmentWriterError, + #[error("Error creating Vector Segment Writer. 
Unknown Vector Segment Type")] + UnknownVectorSegmentType, #[error("Collection not found")] CollectionNotFound, #[error("Error getting collection")] @@ -440,13 +445,13 @@ impl CompactOrchestrator { { self.num_uncompleted_tasks_by_segment - .entry(writers.vector.id) + .entry(writers.vector.get_id()) .and_modify(|v| { *v += 1; }) .or_insert(1); - let writer = ChromaSegmentWriter::DistributedHNSWSegment(writers.vector); + let writer = ChromaSegmentWriter::VectorSegment(writers.vector); let span = self.get_segment_writer_span(&writer); let operator = ApplyLogToSegmentWriterOperator::new(); let input = @@ -543,6 +548,24 @@ impl CompactOrchestrator { } } + async fn get_segment_from_scope( + &mut self, + segment_scope: SegmentScope, + ) -> Result { + let segments = self.get_all_segments().await?; + let segment = segments + .iter() + .find(|segment| segment.scope == segment_scope) + .cloned(); + + tracing::debug!("Found {:?} segment: {:?}", segment_scope, segment); + + match segment { + Some(segment) => Ok(segment), + None => Err(GetSegmentWritersError::NoSegmentsFound), + } + } + async fn get_segment_writers(&mut self) -> Result { // Care should be taken to use the same writers across the compaction process // Since the segment writers are stateful, we should not create new writers for each partition @@ -554,7 +577,7 @@ impl CompactOrchestrator { let record_segment = self.get_segment(SegmentType::BlockfileRecord).await?; let mt_segment = self.get_segment(SegmentType::BlockfileMetadata).await?; - let hnsw_segment = self.get_segment(SegmentType::HnswDistributed).await?; + let vector_segment = self.get_segment_from_scope(SegmentScope::VECTOR).await?; let borrowed_writers = self .writers @@ -587,7 +610,6 @@ impl CompactOrchestrator { tracing::debug!("Metadata Segment Writer created"); - // Create a hnsw segment writer let collection_res = sysdb .get_collections(Some(self.collection_id), None, None, None, None, 0) .await; @@ -603,31 +625,62 @@ impl CompactOrchestrator { return Err(GetSegmentWritersError::GetCollectionError(e)); } }; - let collection = &collection_res[0]; - if let Some(dimension) = collection.dimension { + let dim = match collection_res[0].dimension { + Some(dim) => dim, + None => { + tracing::error!( + "Error creating vector segment writer. 
Collection dim missing" + ); + return Err(GetSegmentWritersError::CollectionMissingDimension); + } + }; + + if vector_segment.r#type == SegmentType::HnswDistributed { + // Create a hnsw segment writer let hnsw_segment_writer = match DistributedHNSWSegmentWriter::from_segment( - &hnsw_segment, - dimension as usize, - hnsw_provider, + &vector_segment, + dim as usize, + self.hnsw_index_provider.clone(), ) .await { Ok(writer) => writer, Err(e) => { tracing::error!("Error creating HNSW segment writer: {:?}", e); - return Err(GetSegmentWritersError::HnswSegmentWriterError); + return Err(GetSegmentWritersError::VectorSegmentWriterError); + } + }; + Ok(CompactWriters { + metadata: mt_segment_writer, + record: record_segment_writer, + vector: VectorSegmentWriter::Hnsw(hnsw_segment_writer), + }) + } else if vector_segment.r#type == SegmentType::Spann { + let spann_writer = match SpannSegmentWriter::from_segment( + &vector_segment, + &blockfile_provider, + &hnsw_provider, + dim as usize, + ) + .await + { + Ok(writer) => writer, + Err(e) => { + tracing::error!("Error creating Spann segment writer: {:?}", e); + return Err(GetSegmentWritersError::VectorSegmentWriterError); } }; - return Ok(CompactWriters { + Ok(CompactWriters { metadata: mt_segment_writer, record: record_segment_writer, - vector: hnsw_segment_writer, - }); + vector: VectorSegmentWriter::Spann(spann_writer), + }) + } else { + tracing::error!("Error creating vector segment writer. Unknown segment type"); + Err(GetSegmentWritersError::UnknownVectorSegmentType) } - - Err(GetSegmentWritersError::CollectionMissingDimension) }) .await?; @@ -648,8 +701,8 @@ impl CompactOrchestrator { return Ok(ChromaSegmentWriter::RecordSegment(writers.record)); } - if writers.vector.id == segment_id { - return Ok(ChromaSegmentWriter::DistributedHNSWSegment(writers.vector)); + if writers.vector.get_id() == segment_id { + return Ok(ChromaSegmentWriter::VectorSegment(writers.vector)); } Err(GetSegmentWritersError::NoSegmentsFound) diff --git a/rust/worker/src/execution/orchestration/mod.rs b/rust/worker/src/execution/orchestration/mod.rs index fa28dc70a61..91024f40bdb 100644 --- a/rust/worker/src/execution/orchestration/mod.rs +++ b/rust/worker/src/execution/orchestration/mod.rs @@ -1,6 +1,6 @@ mod compact; mod count; -mod spann_knn; +pub mod spann_knn; pub(crate) use compact::*; pub(crate) use count::*; diff --git a/rust/worker/src/execution/orchestration/spann_knn.rs b/rust/worker/src/execution/orchestration/spann_knn.rs index fb97283ddbb..49c4760be6a 100644 --- a/rust/worker/src/execution/orchestration/spann_knn.rs +++ b/rust/worker/src/execution/orchestration/spann_knn.rs @@ -33,11 +33,6 @@ use crate::execution::operators::{ use super::knn_filter::{KnnError, KnnFilterOutput, KnnOutput, KnnResult}; -// TODO(Sanket): Make these configurable. 
-const RNG_FACTOR: f32 = 1.0; -const QUERY_EPSILON: f32 = 10.0; -const NUM_PROBE: usize = 64; - #[derive(Debug)] pub struct SpannKnnOrchestrator { // Orchestrator parameters @@ -172,10 +167,6 @@ impl Orchestrator for SpannKnnOrchestrator { SpannCentersSearchInput { reader_context, normalized_query: self.normalized_query_emb.clone(), - k: NUM_PROBE, - rng_epsilon: QUERY_EPSILON, - rng_factor: RNG_FACTOR, - distance_function: self.knn_filter_output.distance_function.clone(), }, ctx.receiver(), ); diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 77987f549c4..0518d74945d 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -16,7 +16,7 @@ use chroma_types::{ KnnBatchResult, KnnPlan, }, operator::Scan, - CollectionAndSegments, + CollectionAndSegments, SegmentType, }; use futures::{stream, StreamExt, TryStreamExt}; use tokio::signal::unix::{signal, SignalKind}; @@ -29,7 +29,7 @@ use crate::{ operators::{fetch_log::FetchLogOperator, knn_projection::KnnProjectionOperator}, orchestration::{ get::GetOrchestrator, knn::KnnOrchestrator, knn_filter::KnnFilterOrchestrator, - CountOrchestrator, + spann_knn::SpannKnnOrchestrator, CountOrchestrator, }, }, utils::convert::{from_proto_knn, to_proto_knn_batch_result}, @@ -246,6 +246,7 @@ impl WorkerServer { )?)); } + let vector_segment_type = collection_and_segments.vector_segment.r#type; let knn_filter_orchestrator = KnnFilterOrchestrator::new( self.blockfile_provider.clone(), dispatcher.clone(), @@ -264,28 +265,55 @@ impl WorkerServer { } }; - let knn_orchestrator_futures = from_proto_knn(knn)? - .into_iter() - .map(|knn| { - KnnOrchestrator::new( - self.blockfile_provider.clone(), - dispatcher.clone(), - // TODO: Make this configurable - 1000, - matching_records.clone(), - knn, - knn_projection.clone(), - ) - }) - .map(|knner| knner.run(system.clone())); - - match stream::iter(knn_orchestrator_futures) - .buffered(32) - .try_collect::>() - .await - { - Ok(results) => Ok(Response::new(to_proto_knn_batch_result(results)?)), - Err(err) => Err(Status::new(err.code().into(), err.to_string())), + if vector_segment_type == SegmentType::Spann { + tracing::info!("Running KNN on SPANN segment"); + let knn_orchestrator_futures = from_proto_knn(knn)? + .into_iter() + .map(|knn| { + SpannKnnOrchestrator::new( + self.blockfile_provider.clone(), + self.hnsw_index_provider.clone(), + dispatcher.clone(), + 1000, + matching_records.clone(), + knn.fetch as usize, + knn.embedding, + knn_projection.clone(), + ) + }) + .map(|knner| knner.run(system.clone())); + match stream::iter(knn_orchestrator_futures) + .buffered(32) + .try_collect::>() + .await + { + Ok(results) => Ok(Response::new(to_proto_knn_batch_result(results)?)), + Err(err) => Err(Status::new(err.code().into(), err.to_string())), + } + } else { + let knn_orchestrator_futures = from_proto_knn(knn)? 
+ .into_iter() + .map(|knn| { + KnnOrchestrator::new( + self.blockfile_provider.clone(), + dispatcher.clone(), + // TODO: Make this configurable + 1000, + matching_records.clone(), + knn, + knn_projection.clone(), + ) + }) + .map(|knner| knner.run(system.clone())); + + match stream::iter(knn_orchestrator_futures) + .buffered(32) + .try_collect::>() + .await + { + Ok(results) => Ok(Response::new(to_proto_knn_batch_result(results)?)), + Err(err) => Err(Status::new(err.code().into(), err.to_string())), + } } } From 2a0ec9cf53b1a4e76d4da4ebc0d82947e5918126 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Sun, 9 Mar 2025 14:43:31 -0700 Subject: [PATCH 2/4] Rebase --- rust/index/src/spann/types.rs | 70 ++++++++++------------------------- 1 file changed, 20 insertions(+), 50 deletions(-) diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index 70022fcbdd4..bafc1e87c84 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -31,7 +31,6 @@ pub struct VersionsMapInner { #[derive(Clone)] // Note: Fields of this struct are public for testing. -#[derive(Clone)] pub struct SpannIndexWriter { // HNSW index and its provider for centroid search. pub hnsw_index: HnswIndexRef, @@ -1003,7 +1002,8 @@ impl SpannIndexWriter { } // Normalize the embedding in case of cosine. let mut normalized_embedding = embedding.to_vec(); - if self.distance_function == DistanceFunction::Cosine { + let distance_function: DistanceFunction = self.params.space.clone().into(); + if distance_function == DistanceFunction::Cosine { normalized_embedding = normalize(embedding); } // Add to the posting list. @@ -1705,6 +1705,7 @@ mod tests { provider::BlockfileProvider, }; use chroma_cache::{new_cache_for_test, new_non_persistent_cache_for_test}; + use chroma_distance::DistanceFunction; use chroma_storage::{local::LocalStorage, Storage}; use chroma_types::{CollectionUuid, DistributedSpannParameters, SpannPostingList}; use rand::Rng; @@ -2780,11 +2781,9 @@ mod tests { let blockfile_provider = new_blockfile_provider_for_tests(max_block_size_bytes, storage.clone()); let hnsw_provider = new_hnsw_provider_for_tests(storage.clone(), &tmp_dir); - let m = 16; - let ef_construction = 200; - let ef_search = 200; let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; + let params = DistributedSpannParameters::default(); + let distance_function = params.space.clone().into(); let dimensionality = 1000; let writer = SpannIndexWriter::from_id( &hnsw_provider, @@ -2792,13 +2791,10 @@ mod tests { None, None, None, - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function.clone(), dimensionality, &blockfile_provider, + params, ) .await .expect("Error creating spann index writer"); @@ -2870,11 +2866,9 @@ mod tests { let blockfile_provider = new_blockfile_provider_for_tests(max_block_size_bytes, storage.clone()); let hnsw_provider = new_hnsw_provider_for_tests(storage.clone(), &tmp_dir); - let m = 16; - let ef_construction = 200; - let ef_search = 200; let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; + let params = DistributedSpannParameters::default(); + let distance_function = params.space.clone().into(); let dimensionality = 1000; let writer = SpannIndexWriter::from_id( &hnsw_provider, @@ -2882,13 +2876,10 @@ mod tests { None, None, None, - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function.clone(), 
dimensionality, &blockfile_provider, + params, ) .await .expect("Error creating spann index writer"); @@ -2974,11 +2965,9 @@ mod tests { let tmp_dir = tempfile::tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); let max_block_size_bytes = 8 * 1024 * 1024; - let m = 16; - let ef_construction = 200; - let ef_search = 200; let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; + let params = DistributedSpannParameters::default(); + let distance_function = params.space.clone().into(); let dimensionality = 1000; let mut hnsw_path = None; let mut versions_map_path = None; @@ -2996,13 +2985,10 @@ mod tests { versions_map_path.as_ref(), pl_path.as_ref(), max_bf_id_path.as_ref(), - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function.clone(), dimensionality, &blockfile_provider, + params.clone(), ) .await .expect("Error creating spann index writer"); @@ -3079,11 +3065,9 @@ mod tests { let tmp_dir = tempfile::tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); let max_block_size_bytes = 8 * 1024 * 1024; - let m = 16; - let ef_construction = 200; - let ef_search = 200; + let params = DistributedSpannParameters::default(); + let distance_function = params.space.clone().into(); let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; let dimensionality = 1000; let mut hnsw_path = None; let mut versions_map_path = None; @@ -3113,13 +3097,10 @@ mod tests { versions_map_path.as_ref(), pl_path.as_ref(), max_bf_id_path.as_ref(), - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function.clone(), dimensionality, &blockfile_provider, + params.clone(), ) .await .expect("Error creating spann index writer"); @@ -3209,11 +3190,9 @@ mod tests { let tmp_dir = tempfile::tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); let max_block_size_bytes = 8 * 1024 * 1024; - let m = 16; - let ef_construction = 200; - let ef_search = 200; + let params = DistributedSpannParameters::default(); + let distance_function: DistanceFunction = params.space.clone().into(); let collection_id = CollectionUuid::new(); - let distance_function = chroma_distance::DistanceFunction::Euclidean; let dimensionality = 1000; let mut hnsw_path = None; let mut versions_map_path = None; @@ -3243,13 +3222,10 @@ mod tests { versions_map_path.as_ref(), pl_path.as_ref(), max_bf_id_path.as_ref(), - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function.clone(), dimensionality, &blockfile_provider, + params.clone(), ) .await .expect("Error creating spann index writer"); @@ -3359,13 +3335,10 @@ mod tests { versions_map_path.as_ref(), pl_path.as_ref(), max_bf_id_path.as_ref(), - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function.clone(), dimensionality, &blockfile_provider, + params.clone(), ) .await .expect("Error creating spann index writer"); @@ -3477,13 +3450,10 @@ mod tests { versions_map_path.as_ref(), pl_path.as_ref(), max_bf_id_path.as_ref(), - Some(m), - Some(ef_construction), - Some(ef_search), &collection_id, - distance_function.clone(), dimensionality, &blockfile_provider, + params, ) .await .expect("Error creating spann index writer"); From 98ea5ef51fdd729045d967a6293d360bff022dbc Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Sun, 9 Mar 
2025 17:37:57 -0700 Subject: [PATCH 3/4] Change commit behavior --- rust/index/src/spann/types.rs | 106 +++++++++++--------------- rust/segment/src/distributed_spann.rs | 4 +- 2 files changed, 47 insertions(+), 63 deletions(-) diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index bafc1e87c84..4454550cd94 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -43,7 +43,7 @@ pub struct SpannIndexWriter { pub next_head_id: Arc, // Version number of each point. // TODO(Sanket): Finer grained locking for this map in future if perf is not satisfactory. - pub versions_map: Arc>, + pub versions_map: Arc>, pub dimensionality: usize, pub params: DistributedSpannParameters, } @@ -155,7 +155,7 @@ impl SpannIndexWriter { blockfile_provider, posting_list_writer: Arc::new(tokio::sync::Mutex::new(posting_list_writer)), next_head_id: Arc::new(AtomicU32::new(next_head_id)), - versions_map: Arc::new(parking_lot::RwLock::new(versions_map)), + versions_map: Arc::new(tokio::sync::RwLock::new(versions_map)), dimensionality, params, } @@ -336,9 +336,9 @@ impl SpannIndexWriter { )) } - fn add_versions_map(&self, id: u32) -> u32 { + async fn add_versions_map(&self, id: u32) -> u32 { // 0 means deleted. Version counting starts from 1. - let mut write_lock = self.versions_map.write(); + let mut write_lock = self.versions_map.write().await; write_lock.versions_map.insert(id, 1); *write_lock.versions_map.get(&id).unwrap() } @@ -365,7 +365,7 @@ impl SpannIndexWriter { doc_offset_id: u32, version: u32, ) -> Result { - let version_map_guard = self.versions_map.read(); + let version_map_guard = self.versions_map.read().await; let current_version = version_map_guard .versions_map .get(&doc_offset_id) @@ -482,7 +482,7 @@ impl SpannIndexWriter { // Increment version and trigger append. let next_version; { - let mut version_map_guard = self.versions_map.write(); + let mut version_map_guard = self.versions_map.write().await; let current_version = version_map_guard .versions_map .get(&doc_offset_id) @@ -696,7 +696,7 @@ impl SpannIndexWriter { let mut local_indices = vec![0; doc_offset_ids.len()]; let mut up_to_date_index = 0; { - let version_map_guard = self.versions_map.read(); + let version_map_guard = self.versions_map.read().await; for (index, doc_version) in doc_versions.iter().enumerate() { let current_version = version_map_guard .versions_map @@ -969,7 +969,7 @@ impl SpannIndexWriter { } pub async fn add(&self, id: u32, embedding: &[f32]) -> Result<(), SpannIndexWriterError> { - let version = self.add_versions_map(id); + let version = self.add_versions_map(id).await; // Normalize the embedding in case of cosine. let mut normalized_embedding = embedding.to_vec(); let distance_function: DistanceFunction = self.params.space.clone().into(); @@ -985,7 +985,7 @@ impl SpannIndexWriter { let inc_version; { // Increment version. 
- let mut version_map_guard = self.versions_map.write(); + let mut version_map_guard = self.versions_map.write().await; let curr_version = match version_map_guard.versions_map.get(&id) { Some(version) => *version, None => { @@ -1012,7 +1012,7 @@ impl SpannIndexWriter { } pub async fn delete(&self, id: u32) -> Result<(), SpannIndexWriterError> { - let mut version_map_guard = self.versions_map.write(); + let mut version_map_guard = self.versions_map.write().await; version_map_guard.versions_map.insert(id, 0); Ok(()) } @@ -1023,7 +1023,7 @@ impl SpannIndexWriter { doc_versions: &[u32], ) -> Result { let mut up_to_date_index = 0; - let version_map_guard = self.versions_map.read(); + let version_map_guard = self.versions_map.read().await; for (index, doc_version) in doc_versions.iter().enumerate() { let current_version = version_map_guard .versions_map @@ -1058,7 +1058,7 @@ impl SpannIndexWriter { let mut cluster_len = 0; let mut local_indices = vec![0; doc_offset_ids.len()]; { - let version_map_guard = self.versions_map.read(); + let version_map_guard = self.versions_map.read().await; for (index, doc_version) in doc_versions.iter().enumerate() { let current_version = version_map_guard .versions_map @@ -1335,18 +1335,14 @@ impl SpannIndexWriter { // TODO(Sanket): Change the error types. pub async fn commit(self) -> Result { + // NOTE(Sanket): This is not the best way to drain the writer but the orchestrator keeps a + // reference to the writer so cannot do an Arc::try_unwrap() here. // Pl list. - let pl_flusher = match Arc::try_unwrap(self.posting_list_writer) { - Ok(writer) => writer - .into_inner() - .commit::>() - .await - .map_err(|_| SpannIndexWriterError::PostingListCommitError)?, - Err(_) => { - // This should never happen. - panic!("Failed to unwrap posting list writer"); - } - }; + let pl_writer_clone = self.posting_list_writer.lock().await.clone(); + let pl_flusher = pl_writer_clone + .commit::>() + .await + .map_err(|_| SpannIndexWriterError::PostingListCommitError)?; // Versions map. Create a writer, write all the data and commit. let mut bf_options = BlockfileWriterOptions::new(); bf_options = bf_options.unordered_mutations(); @@ -1355,25 +1351,19 @@ impl SpannIndexWriter { .write::(bf_options) .await .map_err(|_| SpannIndexWriterError::VersionsMapWriterCreateError)?; - let versions_map_flusher = match Arc::try_unwrap(self.versions_map) { - Ok(writer) => { - let writer = writer.into_inner(); - for (doc_offset_id, doc_version) in writer.versions_map.into_iter() { - versions_map_bf_writer - .set("", doc_offset_id, doc_version) - .await - .map_err(|_| SpannIndexWriterError::VersionsMapSetError)?; - } + { + let mut version_map_guard = self.versions_map.write().await; + for (doc_offset_id, doc_version) in version_map_guard.versions_map.drain() { versions_map_bf_writer - .commit::() + .set("", doc_offset_id, doc_version) .await - .map_err(|_| SpannIndexWriterError::VersionsMapCommitError)? + .map_err(|_| SpannIndexWriterError::VersionsMapSetError)?; } - Err(_) => { - // This should never happen. - panic!("Failed to unwrap posting list writer"); - } - }; + } + let versions_map_flusher = versions_map_bf_writer + .commit::() + .await + .map_err(|_| SpannIndexWriterError::VersionsMapCommitError)?; // Next head. 
let mut bf_options = BlockfileWriterOptions::new(); bf_options = bf_options.unordered_mutations(); @@ -1382,23 +1372,15 @@ impl SpannIndexWriter { .write::<&str, u32>(bf_options) .await .map_err(|_| SpannIndexWriterError::MaxHeadIdWriterCreateError)?; - let max_head_id_flusher = match Arc::try_unwrap(self.next_head_id) { - Ok(value) => { - let value = value.into_inner(); - max_head_id_bf - .set("", MAX_HEAD_OFFSET_ID, value) - .await - .map_err(|_| SpannIndexWriterError::MaxHeadIdSetError)?; - max_head_id_bf - .commit::<&str, u32>() - .await - .map_err(|_| SpannIndexWriterError::MaxHeadIdCommitError)? - } - Err(_) => { - // This should never happen. - panic!("Failed to unwrap next head id"); - } - }; + let max_head_oid = self.next_head_id.load(std::sync::atomic::Ordering::SeqCst); + max_head_id_bf + .set("", MAX_HEAD_OFFSET_ID, max_head_oid) + .await + .map_err(|_| SpannIndexWriterError::MaxHeadIdSetError)?; + let max_head_id_flusher = max_head_id_bf + .commit::<&str, u32>() + .await + .map_err(|_| SpannIndexWriterError::MaxHeadIdCommitError)?; let hnsw_id = self.hnsw_index.inner.read().id; @@ -1995,7 +1977,7 @@ mod tests { } // Insert the points in the version map as well. { - let mut version_map_guard = writer.versions_map.write(); + let mut version_map_guard = writer.versions_map.write().await; for point in 1..=100 { version_map_guard.versions_map.insert(point as u32, 1); version_map_guard.versions_map.insert(100 + point as u32, 1); @@ -2014,7 +1996,7 @@ mod tests { } // Expect the version map to be properly updated. { - let version_map_guard = writer.versions_map.read(); + let version_map_guard = writer.versions_map.read().await; for point in 1..=40 { assert_eq!(version_map_guard.versions_map.get(&point), Some(&0)); assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&0)); @@ -2175,7 +2157,7 @@ mod tests { } // Insert the points in the version map as well. { - let mut version_map_guard = writer.versions_map.write(); + let mut version_map_guard = writer.versions_map.write().await; for point in 1..=100 { version_map_guard.versions_map.insert(point as u32, 1); version_map_guard.versions_map.insert(100 + point as u32, 1); @@ -2200,7 +2182,7 @@ mod tests { .expect("Error deleting from spann index writer"); // Expect the version map to be properly updated. { - let version_map_guard = writer.versions_map.read(); + let version_map_guard = writer.versions_map.read().await; for point in 1..=60 { assert_eq!(version_map_guard.versions_map.get(&point), Some(&0)); assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&0)); @@ -2417,7 +2399,7 @@ mod tests { } // Insert these 150 points to version map. { - let mut version_map_guard = writer.versions_map.write(); + let mut version_map_guard = writer.versions_map.write().await; for i in 1..=150 { version_map_guard.versions_map.insert(i as u32, 1); } @@ -2653,7 +2635,7 @@ mod tests { } // Initialize the versions map appropriately. 
{ - let mut version_map_guard = writer.versions_map.write(); + let mut version_map_guard = writer.versions_map.write().await; for i in 1..=160 { version_map_guard.versions_map.insert(i as u32, 1); } diff --git a/rust/segment/src/distributed_spann.rs b/rust/segment/src/distributed_spann.rs index 0bf760fcf7c..65ec75f0ad1 100644 --- a/rust/segment/src/distributed_spann.rs +++ b/rust/segment/src/distributed_spann.rs @@ -231,6 +231,7 @@ impl SpannSegmentWriter { record_segment_reader: &Option>, materialized_chunk: &MaterializeLogsResult, ) -> Result<(), ApplyMaterializedLogError> { + println!("(Sanket-temp) Applying materialized log chunk to spann segment writer"); for record in materialized_chunk { match record.get_operation() { MaterializedLogOperation::AddNew => { @@ -262,6 +263,7 @@ impl SpannSegmentWriter { ), } } + println!("(Sanket-temp) Finished applying materialized log chunk to spann segment writer"); Ok(()) } @@ -653,7 +655,7 @@ mod test { 2 ); { - let read_guard = spann_writer.index.versions_map.read(); + let read_guard = spann_writer.index.versions_map.read().await; assert_eq!(read_guard.versions_map.len(), 2); assert_eq!( *read_guard From 6d0943e25c813e1736612646cec5ed601b8f93f6 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Sun, 9 Mar 2025 18:10:03 -0700 Subject: [PATCH 4/4] Query orchestrator fix --- .../src/execution/orchestration/knn_filter.rs | 66 ++++++++++++------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/rust/worker/src/execution/orchestration/knn_filter.rs b/rust/worker/src/execution/orchestration/knn_filter.rs index adf334cc1c7..03ffc5a57a9 100644 --- a/rust/worker/src/execution/orchestration/knn_filter.rs +++ b/rust/worker/src/execution/orchestration/knn_filter.rs @@ -10,7 +10,10 @@ use chroma_system::{ wrap, ChannelError, ComponentContext, ComponentHandle, Dispatcher, Handler, Orchestrator, PanicError, TaskError, TaskMessage, TaskResult, }; -use chroma_types::{CollectionAndSegments, DistributedHnswParameters, Segment}; +use chroma_types::{ + CollectionAndSegments, DistributedHnswParameters, DistributedSpannParameters, Segment, + SegmentType, +}; use thiserror::Error; use tokio::sync::oneshot::{error::RecvError, Sender}; @@ -277,37 +280,54 @@ impl Handler> for KnnFilterOrchestrator { None => return, }; - let hnsw_configuration = match self.ok_or_terminate( - DistributedHnswParameters::try_from(&self.collection_and_segments.vector_segment) - .map_err(|_| KnnError::InvalidDistanceFunction), - ctx, - ) { - Some(hnsw_configuration) => hnsw_configuration, - None => return, - }; - let hnsw_reader = match DistributedHNSWSegmentReader::from_segment( - &self.collection_and_segments.vector_segment, - collection_dimension as usize, - self.hnsw_provider.clone(), - ) - .await + let (hnsw_reader, distance_function) = if self.collection_and_segments.vector_segment.r#type + == SegmentType::HnswDistributed { - Ok(hnsw_reader) => Some(hnsw_reader), - Err(err) if matches!(*err, DistributedHNSWSegmentFromSegmentError::Uninitialized) => { - None - } + let hnsw_configuration = match self.ok_or_terminate( + DistributedHnswParameters::try_from(&self.collection_and_segments.vector_segment) + .map_err(|_| KnnError::InvalidDistanceFunction), + ctx, + ) { + Some(hnsw_configuration) => hnsw_configuration, + None => return, + }; + match DistributedHNSWSegmentReader::from_segment( + &self.collection_and_segments.vector_segment, + collection_dimension as usize, + self.hnsw_provider.clone(), + ) + .await + { + Ok(hnsw_reader) => (Some(hnsw_reader), 
hnsw_configuration.space.into()), + Err(err) + if matches!(*err, DistributedHNSWSegmentFromSegmentError::Uninitialized) => + { + (None, hnsw_configuration.space.into()) + } - Err(err) => { - self.terminate_with_result(Err((*err).into()), ctx); - return; + Err(err) => { + self.terminate_with_result(Err((*err).into()), ctx); + return; + } } + } else { + let params = match self.ok_or_terminate( + DistributedSpannParameters::try_from(&self.collection_and_segments.vector_segment) + .map_err(|_| KnnError::InvalidDistanceFunction), + ctx, + ) { + Some(params) => params, + None => return, + }; + (None, params.space.into()) }; + let output = KnnFilterOutput { logs: self .fetched_logs .take() .expect("FetchLogOperator should have finished already"), - distance_function: hnsw_configuration.space.into(), + distance_function, filter_output: output, hnsw_reader, record_segment: self.collection_and_segments.record_segment.clone(),
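
A minimal sketch of the parameter-driven distance handling used in PATCH 2/4: the SPANN writer keeps DistributedSpannParameters and derives the distance function from params.space on demand, normalizing the embedding in add() only for cosine. The Space and DistanceFunction enums and the normalize helper below are simplified stand-ins for the chroma_distance/chroma_types items, not the real definitions.

// Stand-in types; the real ones live in chroma_distance / chroma_types.
#[derive(Clone, PartialEq)]
enum Space { L2, Cosine }

#[derive(Clone, PartialEq)]
enum DistanceFunction { Euclidean, Cosine }

impl From<Space> for DistanceFunction {
    fn from(space: Space) -> Self {
        match space {
            Space::L2 => DistanceFunction::Euclidean,
            Space::Cosine => DistanceFunction::Cosine,
        }
    }
}

// L2-normalize a vector (guarding against division by zero).
fn normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(f32::EPSILON);
    v.iter().map(|x| x / norm).collect()
}

// Mirrors the add() path: derive the distance function from the space and
// normalize only when it is cosine.
fn prepare_embedding(space: &Space, embedding: &[f32]) -> Vec<f32> {
    let distance_function: DistanceFunction = space.clone().into();
    if distance_function == DistanceFunction::Cosine {
        normalize(embedding)
    } else {
        embedding.to_vec()
    }
}

fn main() {
    let e = prepare_embedding(&Space::Cosine, &[3.0, 4.0]);
    println!("{:?}", e); // [0.6, 0.8]
}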
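
A self-contained sketch of the commit pattern from PATCH 3/4: because the orchestrator may still hold a reference to the writer, commit() cannot Arc::try_unwrap() its shared state; instead it drains the versions map in place under a tokio::sync::RwLock and snapshots the atomic max head id. Writer and VersionsMapInner below are reduced stand-ins for the real SpannIndexWriter state, and the example assumes tokio with the rt and macros features.

use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;

struct VersionsMapInner {
    versions_map: HashMap<u32, u32>,
}

struct Writer {
    versions_map: Arc<RwLock<VersionsMapInner>>,
    next_head_id: Arc<AtomicU32>,
}

impl Writer {
    async fn commit(&self) -> (Vec<(u32, u32)>, u32) {
        // Drain (offset id, version) pairs while other Arc clones may still exist.
        let mut drained = Vec::new();
        {
            let mut guard = self.versions_map.write().await;
            for (id, version) in guard.versions_map.drain() {
                drained.push((id, version));
            }
        }
        // Snapshot the max head id without consuming the Arc.
        let max_head_id = self.next_head_id.load(Ordering::SeqCst);
        (drained, max_head_id)
    }
}

#[tokio::main]
async fn main() {
    let writer = Writer {
        versions_map: Arc::new(RwLock::new(VersionsMapInner {
            versions_map: HashMap::from([(1, 1), (2, 1)]),
        })),
        next_head_id: Arc::new(AtomicU32::new(3)),
    };
    let (entries, max_head_id) = writer.commit().await;
    println!("flushed {} versions, max head id {}", entries.len(), max_head_id);
}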
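
A rough sketch of the selection added in PATCH 4/4 (knn_filter.rs): the distance function now comes from the parameters matching the vector segment type, and an HNSW segment reader is only opened for SegmentType::HnswDistributed, while the SPANN path returns no reader. All types below are illustrative stand-ins rather than the real chroma_types definitions, and error handling (for example the Uninitialized reader case) is omitted.

// Stand-ins for chroma_types; variants and fields are illustrative only.
#[derive(Clone, Copy, PartialEq, Debug)]
enum SegmentType { HnswDistributed, Spann }

#[derive(Clone, Copy, Debug)]
enum Space { L2, Cosine, Ip }

struct HnswParameters { space: Space }
struct SpannParameters { space: Space }
struct HnswReader; // stand-in for DistributedHNSWSegmentReader

// For HNSW segments, open a reader and take the space from the HNSW params;
// for SPANN segments, take the space from the SPANN params and skip the reader.
fn resolve_vector_index(
    segment_type: SegmentType,
    hnsw: HnswParameters,
    spann: SpannParameters,
) -> (Option<HnswReader>, Space) {
    if segment_type == SegmentType::HnswDistributed {
        (Some(HnswReader), hnsw.space)
    } else {
        (None, spann.space)
    }
}

fn main() {
    let (reader, space) = resolve_vector_index(
        SegmentType::Spann,
        HnswParameters { space: Space::L2 },
        SpannParameters { space: Space::Cosine },
    );
    println!("reader present: {}, space: {:?}", reader.is_some(), space);
}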