From fe7e3d2accb96586a2363cc83458d2029ceb4a2b Mon Sep 17 00:00:00 2001 From: Max Isom Date: Tue, 4 Mar 2025 12:59:55 -0800 Subject: [PATCH] Add get impl --- rust/frontend/src/impls/in_memory_frontend.rs | 123 +++++++++++++++++- rust/types/src/api_types.rs | 12 +- 2 files changed, 123 insertions(+), 12 deletions(-) diff --git a/rust/frontend/src/impls/in_memory_frontend.rs b/rust/frontend/src/impls/in_memory_frontend.rs index 4d6df392c14..2138d71c1ea 100644 --- a/rust/frontend/src/impls/in_memory_frontend.rs +++ b/rust/frontend/src/impls/in_memory_frontend.rs @@ -1,9 +1,9 @@ use chroma_error::ChromaError; use chroma_segment::test::TestReferenceSegment; -use chroma_types::operator::Scan; -use chroma_types::plan::Count; +use chroma_types::operator::{Filter, Limit, Projection, Scan}; +use chroma_types::plan::{Count, Get}; use chroma_types::{ - test_segment, Collection, CollectionAndSegments, CollectionUuid, Database, Segment, + test_segment, Collection, CollectionAndSegments, CollectionUuid, Database, Include, Segment, }; use parking_lot::Mutex; use std::collections::HashSet; @@ -403,9 +403,63 @@ impl InMemoryFrontend { pub async fn get( &mut self, - _request: chroma_types::GetRequest, + request: chroma_types::GetRequest, ) -> Result { - todo!() + let inner = self.inner.lock(); + let collection = inner + .collections + .iter() + .find(|c| { + c.collection.collection_id == request.collection_id + && c.collection.tenant == request.tenant_id + && c.collection.database == request.database_name + }) + .ok_or( + chroma_types::GetCollectionError::NotFound(request.collection_id.to_string()) + .boxed(), + )?; + + let chroma_types::GetRequest { + ids, + include, + r#where, + offset, + limit, + .. + } = request; + + let filter = Filter { + query_ids: ids, + where_clause: r#where, + }; + + let get_response = collection + .reference_impl + .get(Get { + scan: Scan { + collection_and_segments: CollectionAndSegments { + collection: collection.collection.clone(), + metadata_segment: collection.metadata_segment.clone(), + vector_segment: collection.vector_segment.clone(), + record_segment: collection.record_segment.clone(), + }, + }, + filter, + limit: Limit { + skip: offset, + fetch: limit, + }, + proj: Projection { + document: include.0.contains(&Include::Document), + embedding: include.0.contains(&Include::Embedding), + // If URI is requested, metadata is also requested so we can extract the URI. + metadata: (include.0.contains(&Include::Metadata) + || include.0.contains(&Include::Uri)), + }, + }) + .map_err(|e| e.boxed())?; + + Ok((get_response, include).into()) } pub async fn query( @@ -424,6 +478,11 @@ impl InMemoryFrontend { #[cfg(test)] mod tests { + use chroma_types::{ + DocumentExpression, IncludeList, Metadata, MetadataComparison, MetadataExpression, + MetadataValue, PrimitiveOperator, Where, + }; + use super::*; #[tokio::test] @@ -458,6 +517,16 @@ mod tests { let embeddings = vec![vec![1.0, 1.0, 1.0], vec![2.0, 2.0, 2.0]]; let documents = vec![Some("doc1".to_string()), Some("doc2".to_string())]; + let mut metadata1 = Metadata::new(); + metadata1.insert("key1".to_string(), MetadataValue::Str("value1".to_string())); + metadata1.insert("key2".to_string(), MetadataValue::Int(16)); + + let mut metadata2 = Metadata::new(); + metadata2.insert("key1".to_string(), MetadataValue::Str("value2".to_string())); + metadata2.insert("key2".to_string(), MetadataValue::Int(32)); + + let metadatas = vec![Some(metadata1), Some(metadata2)]; + let request = chroma_types::AddCollectionRecordsRequest::try_new( tenant_name.clone(), database_name.clone(), @@ -466,11 +535,12 @@ mod tests { Some(embeddings), Some(documents), None, - None, + Some(metadatas), ) .unwrap(); frontend.add(request).await.unwrap(); + // Test count let count = frontend .count( chroma_types::CountRequest::try_new( @@ -483,5 +553,46 @@ mod tests { .await .unwrap(); assert_eq!(count, 2); + + // Test metadata filter + let request = chroma_types::GetRequest::try_new( + tenant_name.clone(), + database_name.clone(), + collection.collection_id, + None, + Some(Where::Metadata(MetadataExpression { + key: "key1".to_string(), + comparison: MetadataComparison::Primitive( + PrimitiveOperator::Equal, + MetadataValue::Str("value1".to_string()), + ), + })), + None, + 0, + IncludeList::default_get(), + ) + .unwrap(); + let response = frontend.get(request).await.unwrap(); + assert_eq!(response.ids.len(), 1); + assert_eq!(response.ids[0], "id1"); + + // Test full text index + let request = chroma_types::GetRequest::try_new( + tenant_name.clone(), + database_name.clone(), + collection.collection_id, + None, + Some(Where::Document(DocumentExpression { + operator: chroma_types::DocumentOperator::Contains, + text: "doc2".to_string(), + })), + None, + 0, + IncludeList::default_get(), + ) + .unwrap(); + let response = frontend.get(request).await.unwrap(); + assert_eq!(response.ids.len(), 1); + assert_eq!(response.ids[0], "id2"); } } diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 0ece57481d6..7b529641a1a 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -1140,13 +1140,13 @@ impl GetRequest { #[derive(Clone, Deserialize, Serialize, Debug, ToSchema)] #[cfg_attr(feature = "pyo3", pyo3::pyclass)] pub struct GetResponse { - ids: Vec, - embeddings: Option>>, - documents: Option>>, - uris: Option>>, + pub ids: Vec, + pub embeddings: Option>>, + pub documents: Option>>, + pub uris: Option>>, // TODO(hammadb): Add metadata & include to the response - metadatas: Option>>, - include: Vec, + pub metadatas: Option>>, + pub include: Vec, } #[cfg(feature = "pyo3")]