From 7b62478687b4ba44aab806792b889b37c688de71 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Tue, 4 Mar 2025 13:59:48 -0800 Subject: [PATCH] Fix impl --- rust/frontend/src/impls/in_memory_frontend.rs | 256 +++++++++--------- rust/frontend/src/impls/mod.rs | 256 +----------------- rust/segment/src/test.rs | 83 ++++-- 3 files changed, 191 insertions(+), 404 deletions(-) diff --git a/rust/frontend/src/impls/in_memory_frontend.rs b/rust/frontend/src/impls/in_memory_frontend.rs index d36d56023f7..d840483765d 100644 --- a/rust/frontend/src/impls/in_memory_frontend.rs +++ b/rust/frontend/src/impls/in_memory_frontend.rs @@ -1,17 +1,16 @@ +use super::utils::to_records; +use chroma_distance::DistanceFunction; use chroma_error::ChromaError; use chroma_segment::test::TestReferenceSegment; use chroma_types::operator::{Filter, KnnBatch, KnnProjection, Limit, Projection, Scan}; use chroma_types::plan::{Count, Get, Knn}; use chroma_types::{ test_segment, Collection, CollectionAndSegments, CollectionUuid, Database, Include, - IncludeList, Segment, + IncludeList, Segment, SingleNodeHnswParameters, }; -use parking_lot::Mutex; use std::collections::HashSet; -use std::sync::Arc; - -use super::utils::to_records; +#[derive(Debug, Clone)] struct InMemoryCollection { collection: Collection, metadata_segment: Segment, @@ -20,16 +19,16 @@ struct InMemoryCollection { reference_impl: TestReferenceSegment, } -#[derive(Default)] +#[derive(Default, Debug, Clone)] struct Inner { tenants: HashSet, databases: Vec, collections: Vec, } -#[derive(Clone, Default)] +#[derive(Clone, Default, Debug)] pub struct InMemoryFrontend { - inner: Arc>, + inner: Inner, } impl InMemoryFrontend { @@ -37,14 +36,12 @@ impl InMemoryFrontend { Default::default() } - pub async fn reset(&mut self) -> Result { - let mut inner = self.inner.lock(); - *inner = Inner::default(); - + pub fn reset(&mut self) -> Result { + self.inner = Inner::default(); Ok(chroma_types::ResetResponse {}) } - pub async fn heartbeat( + pub fn heartbeat( &self, ) -> Result { Ok(chroma_types::HeartbeatResponse { @@ -53,16 +50,14 @@ impl InMemoryFrontend { } pub fn get_max_batch_size(&mut self) -> u32 { - 1024 // Example placeholder + 1024 } - pub async fn create_tenant( + pub fn create_tenant( &mut self, request: chroma_types::CreateTenantRequest, ) -> Result { - let mut inner = self.inner.lock(); - - let was_new = inner.tenants.insert(request.name.clone()); + let was_new = self.inner.tenants.insert(request.name.clone()); if !was_new { return Err(chroma_types::CreateTenantError::AlreadyExists(request.name)); } @@ -70,25 +65,22 @@ impl InMemoryFrontend { Ok(chroma_types::CreateTenantResponse {}) } - pub async fn get_tenant( + pub fn get_tenant( &mut self, request: chroma_types::GetTenantRequest, ) -> Result { - let inner = self.inner.lock(); - if inner.tenants.contains(&request.name) { + if self.inner.tenants.contains(&request.name) { Ok(chroma_types::GetTenantResponse { name: request.name }) } else { Err(chroma_types::GetTenantError::NotFound(request.name)) } } - pub async fn create_database( + pub fn create_database( &mut self, request: chroma_types::CreateDatabaseRequest, ) -> Result { - let mut inner = self.inner.lock(); - - if inner.databases.iter().any(|db| { + if self.inner.databases.iter().any(|db| { db.id == request.database_id || (db.name == request.database_name && db.tenant == request.tenant_id) }) { @@ -97,7 +89,7 @@ impl InMemoryFrontend { )); } - inner.databases.push(Database { + self.inner.databases.push(Database { id: request.database_id, name: request.database_name, tenant: request.tenant_id, @@ -106,12 +98,12 @@ impl InMemoryFrontend { Ok(chroma_types::CreateDatabaseResponse {}) } - pub async fn list_databases( + pub fn list_databases( &mut self, request: chroma_types::ListDatabasesRequest, ) -> Result { - let inner = self.inner.lock(); - let databases: Vec<_> = inner + let databases: Vec<_> = self + .inner .databases .iter() .filter(|db| db.tenant == request.tenant_id) @@ -123,12 +115,12 @@ impl InMemoryFrontend { .to_vec()) } - pub async fn get_database( + pub fn get_database( &mut self, request: chroma_types::GetDatabaseRequest, ) -> Result { - let inner = self.inner.lock(); - if let Some(db) = inner + if let Some(db) = self + .inner .databases .iter() .find(|db| db.name == request.database_name && db.tenant == request.tenant_id) @@ -141,17 +133,17 @@ impl InMemoryFrontend { } } - pub async fn delete_database( + pub fn delete_database( &mut self, request: chroma_types::DeleteDatabaseRequest, ) -> Result { - let mut inner = self.inner.lock(); - if let Some(pos) = inner + if let Some(pos) = self + .inner .databases .iter() .position(|db| db.name == request.database_name && db.tenant == request.tenant_id) { - inner.databases.remove(pos); + self.inner.databases.remove(pos); Ok(chroma_types::DeleteDatabaseResponse {}) } else { Err(chroma_types::DeleteDatabaseError::NotFound( @@ -160,12 +152,12 @@ impl InMemoryFrontend { } } - pub async fn list_collections( + pub fn list_collections( &mut self, request: chroma_types::ListCollectionsRequest, ) -> Result { - let inner = self.inner.lock(); - let collections: Vec<_> = inner + let collections: Vec<_> = self + .inner .collections .iter() .filter(|c| { @@ -180,12 +172,12 @@ impl InMemoryFrontend { .to_vec()) } - pub async fn count_collections( + pub fn count_collections( &mut self, request: chroma_types::CountCollectionsRequest, ) -> Result { - let inner = self.inner.lock(); - let count = inner + let count = self + .inner .collections .iter() .filter(|c| { @@ -197,12 +189,11 @@ impl InMemoryFrontend { Ok(count as u32) } - pub async fn get_collection( + pub fn get_collection( &mut self, request: chroma_types::GetCollectionRequest, ) -> Result { - let inner = self.inner.lock(); - if let Some(collection) = inner.collections.iter().find(|c| { + if let Some(collection) = self.inner.collections.iter().find(|c| { c.collection.name == request.collection_name && c.collection.tenant == request.tenant_id }) { Ok(collection.collection.clone()) @@ -213,13 +204,11 @@ impl InMemoryFrontend { } } - pub async fn create_collection( + pub fn create_collection( &mut self, request: chroma_types::CreateCollectionRequest, ) -> Result { - let mut inner = self.inner.lock(); - - if inner.collections.iter().any(|c| { + if self.inner.collections.iter().any(|c| { c.collection.name == request.name && c.collection.tenant == request.tenant_id && c.collection.database == request.database_name @@ -240,43 +229,48 @@ impl InMemoryFrontend { log_position: 0, version: 0, total_records_post_compaction: 0, + size_bytes_post_compaction: 0, + last_compaction_time_secs: 0, }; - let reference_impl = TestReferenceSegment::default(); - - inner.collections.push(InMemoryCollection { + let metadata_segment = test_segment( + collection.collection_id, + chroma_types::SegmentScope::METADATA, + ); + let vector_segment = + test_segment(collection.collection_id, chroma_types::SegmentScope::VECTOR); + let record_segment = + test_segment(collection.collection_id, chroma_types::SegmentScope::RECORD); + + let mut reference_impl = TestReferenceSegment::default(); + reference_impl.create_segment(metadata_segment.clone()); + reference_impl.create_segment(vector_segment.clone()); + reference_impl.create_segment(record_segment.clone()); + + self.inner.collections.push(InMemoryCollection { collection: collection.clone(), - metadata_segment: test_segment( - collection.collection_id, - chroma_types::SegmentScope::METADATA, - ), - vector_segment: test_segment( - collection.collection_id, - chroma_types::SegmentScope::VECTOR, - ), - record_segment: test_segment( - collection.collection_id, - chroma_types::SegmentScope::RECORD, - ), + metadata_segment, + vector_segment, + record_segment, reference_impl, }); Ok(collection) } - pub async fn update_collection( + pub fn update_collection( &mut self, _request: chroma_types::UpdateCollectionRequest, ) -> Result { unimplemented!() } - pub async fn delete_collection( + pub fn delete_collection( &mut self, request: chroma_types::DeleteCollectionRequest, ) -> Result { - let mut inner = self.inner.lock(); + let inner = &mut self.inner; if let Some(pos) = inner.collections.iter().position(|c| { c.collection.name == request.collection_name && c.collection.tenant == request.tenant_id @@ -291,13 +285,13 @@ impl InMemoryFrontend { } } - pub async fn add( + pub fn add( &mut self, request: chroma_types::AddCollectionRecordsRequest, ) -> Result { - let mut inner = self.inner.lock(); - let collection = inner + let collection = self + .inner .collections .iter_mut() .find(|c| { @@ -337,15 +331,15 @@ impl InMemoryFrontend { Ok(chroma_types::AddCollectionRecordsResponse {}) } - pub async fn update( + pub fn update( &mut self, request: chroma_types::UpdateCollectionRecordsRequest, ) -> Result< chroma_types::UpdateCollectionRecordsResponse, chroma_types::UpdateCollectionRecordsError, > { - let mut inner = self.inner.lock(); - let collection = inner + let collection = self + .inner .collections .iter_mut() .find(|c| { @@ -384,15 +378,15 @@ impl InMemoryFrontend { Ok(chroma_types::UpdateCollectionRecordsResponse {}) } - pub async fn upsert( + pub fn upsert( &mut self, request: chroma_types::UpsertCollectionRecordsRequest, ) -> Result< chroma_types::UpsertCollectionRecordsResponse, chroma_types::UpsertCollectionRecordsError, > { - let mut inner = self.inner.lock(); - let collection = inner + let collection = self + .inner .collections .iter_mut() .find(|c| { @@ -422,7 +416,7 @@ impl InMemoryFrontend { documents, uris, metadatas, - chroma_types::Operation::Add, + chroma_types::Operation::Upsert, ) .map_err(|e| e.boxed())?; @@ -433,7 +427,7 @@ impl InMemoryFrontend { Ok(chroma_types::UpsertCollectionRecordsResponse {}) } - pub async fn delete( + pub fn delete( &mut self, request: chroma_types::DeleteCollectionRecordsRequest, ) -> Result< @@ -454,12 +448,11 @@ impl InMemoryFrontend { ) .unwrap(), ) - .await .map_err(|e| e.boxed()) .map(|response| response.ids)?; - let mut inner = self.inner.lock(); - let collection = inner + let collection = self + .inner .collections .iter_mut() .find(|c| { @@ -490,12 +483,12 @@ impl InMemoryFrontend { Ok(chroma_types::DeleteCollectionRecordsResponse {}) } - pub async fn count( + pub fn count( &mut self, request: chroma_types::CountRequest, ) -> Result { - let inner = self.inner.lock(); - let collection = inner + let collection = self + .inner .collections .iter() .find(|c| { @@ -525,12 +518,12 @@ impl InMemoryFrontend { Ok(count) } - pub async fn get( - &mut self, + pub fn get( + &self, request: chroma_types::GetRequest, ) -> Result { - let inner = self.inner.lock(); - let collection = inner + let collection = self + .inner .collections .iter() .find(|c| { @@ -586,12 +579,12 @@ impl InMemoryFrontend { Ok((get_response, include).into()) } - pub async fn query( + pub fn query( &mut self, request: chroma_types::QueryRequest, ) -> Result { - let inner = self.inner.lock(); - let collection = inner + let collection = self + .inner .collections .iter() .find(|c| { @@ -618,39 +611,46 @@ impl InMemoryFrontend { where_clause: r#where, }; + let params = SingleNodeHnswParameters::try_from(&collection.vector_segment) + .map_err(|e| e.boxed())?; + let distance_function: DistanceFunction = params.space.into(); + let query_response = collection .reference_impl - .knn(Knn { - scan: Scan { - collection_and_segments: CollectionAndSegments { - collection: collection.collection.clone(), - metadata_segment: collection.metadata_segment.clone(), - vector_segment: collection.vector_segment.clone(), - record_segment: collection.record_segment.clone(), + .knn( + Knn { + scan: Scan { + collection_and_segments: CollectionAndSegments { + collection: collection.collection.clone(), + metadata_segment: collection.metadata_segment.clone(), + vector_segment: collection.vector_segment.clone(), + record_segment: collection.record_segment.clone(), + }, }, - }, - filter, - knn: KnnBatch { - embeddings, - fetch: n_results, - }, - proj: KnnProjection { - projection: Projection { - document: include.0.contains(&Include::Document), - embedding: include.0.contains(&Include::Embedding), - // If URI is requested, metadata is also requested so we can extract the URI. - metadata: (include.0.contains(&Include::Metadata) - || include.0.contains(&Include::Uri)), + filter, + knn: KnnBatch { + embeddings, + fetch: n_results, + }, + proj: KnnProjection { + projection: Projection { + document: include.0.contains(&Include::Document), + embedding: include.0.contains(&Include::Embedding), + // If URI is requested, metadata is also requested so we can extract the URI. + metadata: (include.0.contains(&Include::Metadata) + || include.0.contains(&Include::Uri)), + }, + distance: include.0.contains(&Include::Distance), }, - distance: include.0.contains(&Include::Distance), }, - }) + distance_function, + ) .map_err(|e| e.boxed())?; Ok((query_response, include).into()) } - pub async fn healthcheck(&self) -> chroma_types::HealthCheckResponse { + pub fn healthcheck(&self) -> chroma_types::HealthCheckResponse { chroma_types::HealthCheckResponse { is_executor_ready: true, } @@ -666,21 +666,21 @@ mod tests { use super::*; - async fn create_test_collection() -> (InMemoryFrontend, Collection) { + fn create_test_collection() -> (InMemoryFrontend, Collection) { let tenant_name = "test".to_string(); let database_name = "test".to_string(); let collection_name = "test".to_string(); let mut frontend = InMemoryFrontend::new(); let request = chroma_types::CreateTenantRequest::try_new(tenant_name.clone()).unwrap(); - frontend.create_tenant(request).await.unwrap(); + frontend.create_tenant(request).unwrap(); let request = chroma_types::CreateDatabaseRequest::try_new( tenant_name.clone(), database_name.clone(), ) .unwrap(); - frontend.create_database(request).await.unwrap(); + frontend.create_database(request).unwrap(); let request = chroma_types::CreateCollectionRequest::try_new( tenant_name.clone(), @@ -691,15 +691,14 @@ mod tests { false, ) .unwrap(); - ( - frontend.clone(), - frontend.create_collection(request).await.unwrap(), - ) + + let collection = frontend.create_collection(request).unwrap(); + (frontend, collection) } - #[tokio::test] - async fn test_collection_get_query() { - let (mut frontend, collection) = create_test_collection().await; + #[test] + fn test_collection_get_query() { + let (mut frontend, collection) = create_test_collection(); let ids = vec!["id1".to_string(), "id2".to_string()]; let embeddings = vec![vec![-1.0, -1.0, -1.0], vec![0.0, 0.0, 0.0]]; let documents = vec![Some("doc1".to_string()), Some("doc2".to_string())]; @@ -725,7 +724,7 @@ mod tests { Some(metadatas), ) .unwrap(); - frontend.add(request).await.unwrap(); + frontend.add(request).unwrap(); // Test count let count = frontend @@ -737,7 +736,6 @@ mod tests { ) .unwrap(), ) - .await .unwrap(); assert_eq!(count, 2); @@ -759,7 +757,7 @@ mod tests { IncludeList::default_get(), ) .unwrap(); - let response = frontend.get(request).await.unwrap(); + let response = frontend.get(request).unwrap(); assert_eq!(response.ids.len(), 1); assert_eq!(response.ids[0], "id1"); @@ -778,7 +776,7 @@ mod tests { IncludeList::default_get(), ) .unwrap(); - let response = frontend.get(request).await.unwrap(); + let response = frontend.get(request).unwrap(); assert_eq!(response.ids.len(), 1); assert_eq!(response.ids[0], "id2"); @@ -794,7 +792,7 @@ mod tests { IncludeList::default_query(), ) .unwrap(); - let response = frontend.query(request).await.unwrap(); + let response = frontend.query(request).unwrap(); assert_eq!(response.ids[0].len(), 2); assert_eq!(response.ids[0], vec!["id2", "id1"]); } diff --git a/rust/frontend/src/impls/mod.rs b/rust/frontend/src/impls/mod.rs index 69ddd0eba1c..b310c66e774 100644 --- a/rust/frontend/src/impls/mod.rs +++ b/rust/frontend/src/impls/mod.rs @@ -2,260 +2,6 @@ pub mod in_memory_frontend; pub mod service_based_frontend; mod utils; -use chroma_config::Configurable; -use chroma_error::ChromaError; -use chroma_system::System; -use in_memory_frontend::InMemoryFrontend; use service_based_frontend::ServiceBasedFrontend; -use crate::FrontendConfig; - -#[derive(Clone)] -pub enum Frontend { - ServiceBased(ServiceBasedFrontend), - InMemory(InMemoryFrontend), -} - -impl Frontend { - pub async fn reset(&mut self) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.reset().await, - Frontend::InMemory(frontend) => frontend.reset().await, - } - } - - pub async fn heartbeat( - &self, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.heartbeat().await, - Frontend::InMemory(frontend) => frontend.heartbeat().await, - } - } - - pub fn get_max_batch_size(&mut self) -> u32 { - match self { - Frontend::ServiceBased(frontend) => frontend.get_max_batch_size(), - Frontend::InMemory(frontend) => frontend.get_max_batch_size(), - } - } - - pub async fn create_tenant( - &mut self, - request: chroma_types::CreateTenantRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.create_tenant(request).await, - Frontend::InMemory(frontend) => frontend.create_tenant(request).await, - } - } - - pub async fn get_tenant( - &mut self, - request: chroma_types::GetTenantRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.get_tenant(request).await, - Frontend::InMemory(frontend) => frontend.get_tenant(request).await, - } - } - - pub async fn create_database( - &mut self, - request: chroma_types::CreateDatabaseRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.create_database(request).await, - Frontend::InMemory(frontend) => frontend.create_database(request).await, - } - } - - pub async fn list_databases( - &mut self, - request: chroma_types::ListDatabasesRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.list_databases(request).await, - Frontend::InMemory(frontend) => frontend.list_databases(request).await, - } - } - - pub async fn get_database( - &mut self, - request: chroma_types::GetDatabaseRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.get_database(request).await, - Frontend::InMemory(frontend) => frontend.get_database(request).await, - } - } - - pub async fn delete_database( - &mut self, - request: chroma_types::DeleteDatabaseRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.delete_database(request).await, - Frontend::InMemory(frontend) => frontend.delete_database(request).await, - } - } - - pub async fn list_collections( - &mut self, - request: chroma_types::ListCollectionsRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.list_collections(request).await, - Frontend::InMemory(frontend) => frontend.list_collections(request).await, - } - } - - pub async fn count_collections( - &mut self, - request: chroma_types::CountCollectionsRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.count_collections(request).await, - Frontend::InMemory(frontend) => frontend.count_collections(request).await, - } - } - - pub async fn get_collection( - &mut self, - request: chroma_types::GetCollectionRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.get_collection(request).await, - Frontend::InMemory(frontend) => frontend.get_collection(request).await, - } - } - - pub async fn create_collection( - &mut self, - request: chroma_types::CreateCollectionRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.create_collection(request).await, - Frontend::InMemory(frontend) => frontend.create_collection(request).await, - } - } - - pub async fn update_collection( - &mut self, - request: chroma_types::UpdateCollectionRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.update_collection(request).await, - Frontend::InMemory(frontend) => frontend.update_collection(request).await, - } - } - - pub async fn delete_collection( - &mut self, - request: chroma_types::DeleteCollectionRequest, - ) -> Result - { - match self { - Frontend::ServiceBased(frontend) => frontend.delete_collection(request).await, - Frontend::InMemory(frontend) => frontend.delete_collection(request).await, - } - } - - pub async fn add( - &mut self, - request: chroma_types::AddCollectionRecordsRequest, - ) -> Result - { - match self { - Frontend::ServiceBased(frontend) => frontend.add(request).await, - Frontend::InMemory(frontend) => frontend.add(request).await, - } - } - - pub async fn update( - &mut self, - request: chroma_types::UpdateCollectionRecordsRequest, - ) -> Result< - chroma_types::UpdateCollectionRecordsResponse, - chroma_types::UpdateCollectionRecordsError, - > { - match self { - Frontend::ServiceBased(frontend) => frontend.update(request).await, - Frontend::InMemory(frontend) => frontend.update(request).await, - } - } - - pub async fn upsert( - &mut self, - request: chroma_types::UpsertCollectionRecordsRequest, - ) -> Result< - chroma_types::UpsertCollectionRecordsResponse, - chroma_types::UpsertCollectionRecordsError, - > { - match self { - Frontend::ServiceBased(frontend) => frontend.upsert(request).await, - Frontend::InMemory(frontend) => frontend.upsert(request).await, - } - } - - pub async fn delete( - &mut self, - request: chroma_types::DeleteCollectionRecordsRequest, - ) -> Result< - chroma_types::DeleteCollectionRecordsResponse, - chroma_types::DeleteCollectionRecordsError, - > { - match self { - Frontend::ServiceBased(frontend) => frontend.delete(request).await, - Frontend::InMemory(frontend) => frontend.delete(request).await, - } - } - - pub async fn count( - &mut self, - request: chroma_types::CountRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.count(request).await, - Frontend::InMemory(frontend) => frontend.count(request).await, - } - } - - pub async fn get( - &mut self, - request: chroma_types::GetRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.get(request).await, - Frontend::InMemory(frontend) => frontend.get(request).await, - } - } - - pub async fn query( - &mut self, - request: chroma_types::QueryRequest, - ) -> Result { - match self { - Frontend::ServiceBased(frontend) => frontend.query(request).await, - Frontend::InMemory(frontend) => frontend.query(request).await, - } - } - - pub async fn healthcheck(&self) -> chroma_types::HealthCheckResponse { - match self { - Frontend::ServiceBased(frontend) => frontend.healthcheck().await, - Frontend::InMemory(frontend) => frontend.healthcheck().await, - } - } -} -#[async_trait::async_trait] -impl Configurable<(FrontendConfig, System)> for Frontend { - async fn try_from_config( - config_and_system: &(FrontendConfig, System), - registry: &chroma_config::registry::Registry, - ) -> Result> { - ServiceBasedFrontend::try_from_config(config_and_system, registry) - .await - .map(Frontend::ServiceBased) - } -} +pub type Frontend = ServiceBasedFrontend; diff --git a/rust/segment/src/test.rs b/rust/segment/src/test.rs index 880fe0d1d81..bbe2d89a8c4 100644 --- a/rust/segment/src/test.rs +++ b/rust/segment/src/test.rs @@ -16,6 +16,7 @@ use chroma_types::{ DocumentExpression, DocumentOperator, LogRecord, Metadata, MetadataComparison, MetadataExpression, MetadataSetValue, MetadataValue, Operation, OperationRecord, PrimitiveOperator, Segment, SegmentScope, SegmentUuid, SetOperator, UpdateMetadata, Where, + CHROMA_KEY, }; use std::collections::BinaryHeap; use std::{ @@ -149,13 +150,17 @@ impl ChromaError for TestReferenceSegmentError { } } -#[derive(Default)] +#[derive(Default, Debug, Clone)] pub struct TestReferenceSegment { max_id: u32, record: HashMap>, } impl TestReferenceSegment { + pub fn new() -> Self { + Self::default() + } + fn merge_meta(old_meta: Option, delta: Option) -> Option { let (deleted_keys, new_meta) = if let Some(m) = delta { let mut dk = HashSet::new(); @@ -174,6 +179,7 @@ impl TestReferenceSegment { } else { (HashSet::new(), None) }; + let new_meta = new_meta.and_then(|meta| if meta.is_empty() { None } else { Some(meta) }); match (old_meta, new_meta) { (None, None) => None, (None, Some(m)) | (Some(m), None) => Some(m), @@ -186,6 +192,24 @@ impl TestReferenceSegment { } } + fn filter_metadata(metadata: Option) -> Option { + metadata.and_then(|metadata| { + let filtered: UpdateMetadata = metadata + .into_iter() + .filter(|(k, _)| !k.starts_with(CHROMA_KEY)) + .collect(); + if filtered.is_empty() { + None + } else { + Some(filtered) + } + }) + } + + pub fn create_segment(&mut self, segment: Segment) { + self.record.insert(segment.id, HashMap::new()); + } + pub fn apply_logs(&mut self, logs: Vec, segment_id: SegmentUuid) { self.apply_operation_records(logs.into_iter().map(|l| l.record).collect(), segment_id); } @@ -214,27 +238,41 @@ impl TestReferenceSegment { match operation { Operation::Add => { if let Entry::Vacant(entry) = coll.entry(id) { - record.metadata = Self::merge_meta(None, metadata); + record.metadata = Self::merge_meta(None, Self::filter_metadata(metadata)); entry.insert((self.max_id, record)); self.max_id += 1; } } Operation::Update => { if let Some((_, old_record)) = coll.get_mut(&id) { - old_record.document = record.document; - old_record.embedding = record.embedding; - old_record.metadata = - Self::merge_meta(old_record.metadata.clone(), metadata); + if record.document.is_some() { + old_record.document = record.document; + } + + if record.embedding.is_some() { + old_record.embedding = record.embedding; + } + + old_record.metadata = Self::merge_meta( + old_record.metadata.clone(), + Self::filter_metadata(metadata), + ); } } Operation::Upsert => { if let Some((_, old_record)) = coll.get_mut(&id) { - old_record.document = record.document; + if record.document.is_some() { + old_record.document = record.document; + } + old_record.embedding = record.embedding; - old_record.metadata = - Self::merge_meta(old_record.metadata.clone(), metadata); + + old_record.metadata = Self::merge_meta( + old_record.metadata.clone(), + Self::filter_metadata(metadata), + ); } else { - record.metadata = Self::merge_meta(None, metadata); + record.metadata = Self::merge_meta(None, Self::filter_metadata(metadata)); coll.insert(id, (self.max_id, record)); self.max_id += 1; } @@ -303,7 +341,11 @@ impl TestReferenceSegment { }) } - pub fn knn(&self, plan: Knn) -> Result> { + pub fn knn( + &self, + plan: Knn, + distance_function: DistanceFunction, + ) -> Result> { let coll = self .record .get(&plan.scan.collection_and_segments.metadata_segment.id) @@ -340,24 +382,25 @@ impl TestReferenceSegment { } impl Ord for RecordWithDistance { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.partial_cmp(other).unwrap() + self.0.partial_cmp(&other.0).unwrap() } } let mut result = KnnBatchResult::default(); for embedding in plan.knn.embeddings { - let mut max_heap = BinaryHeap::with_capacity(plan.knn.fetch as usize); + let mut max_heap: BinaryHeap = + BinaryHeap::with_capacity(plan.knn.fetch as usize * 100); let target_vector = normalize(&embedding); for (_, record) in &filtered_records { - let distance = DistanceFunction::Cosine.distance( - &target_vector, - record - .embedding - .as_ref() - .expect("Embedding should be present"), - ); + let distance = match &distance_function { + DistanceFunction::Cosine => distance_function.distance( + &target_vector, + &normalize(record.embedding.as_ref().unwrap()), + ), + other => other.distance(&embedding, record.embedding.as_ref().unwrap()), + }; if max_heap.len() < plan.knn.fetch as usize { max_heap.push(RecordWithDistance(distance, record.clone()));