Skip to content

Commit

Permalink
Add get impl
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Mar 7, 2025
1 parent 46eb43c commit fe7e3d2
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 12 deletions.
123 changes: 117 additions & 6 deletions rust/frontend/src/impls/in_memory_frontend.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use chroma_error::ChromaError;
use chroma_segment::test::TestReferenceSegment;
use chroma_types::operator::Scan;
use chroma_types::plan::Count;
use chroma_types::operator::{Filter, Limit, Projection, Scan};
use chroma_types::plan::{Count, Get};
use chroma_types::{
test_segment, Collection, CollectionAndSegments, CollectionUuid, Database, Segment,
test_segment, Collection, CollectionAndSegments, CollectionUuid, Database, Include, Segment,
};
use parking_lot::Mutex;
use std::collections::HashSet;
Expand Down Expand Up @@ -403,9 +403,63 @@ impl InMemoryFrontend {

pub async fn get(
&mut self,
_request: chroma_types::GetRequest,
request: chroma_types::GetRequest,
) -> Result<chroma_types::GetResponse, chroma_types::QueryError> {
todo!()
let inner = self.inner.lock();
let collection = inner
.collections
.iter()
.find(|c| {
c.collection.collection_id == request.collection_id
&& c.collection.tenant == request.tenant_id
&& c.collection.database == request.database_name
})
.ok_or(
chroma_types::GetCollectionError::NotFound(request.collection_id.to_string())
.boxed(),
)?;

let chroma_types::GetRequest {
ids,
include,
r#where,
offset,
limit,
..
} = request;

let filter = Filter {
query_ids: ids,
where_clause: r#where,
};

let get_response = collection
.reference_impl
.get(Get {
scan: Scan {
collection_and_segments: CollectionAndSegments {
collection: collection.collection.clone(),
metadata_segment: collection.metadata_segment.clone(),
vector_segment: collection.vector_segment.clone(),
record_segment: collection.record_segment.clone(),
},
},
filter,
limit: Limit {
skip: offset,
fetch: limit,
},
proj: Projection {
document: include.0.contains(&Include::Document),
embedding: include.0.contains(&Include::Embedding),
// If URI is requested, metadata is also requested so we can extract the URI.
metadata: (include.0.contains(&Include::Metadata)
|| include.0.contains(&Include::Uri)),
},
})
.map_err(|e| e.boxed())?;

Ok((get_response, include).into())
}

pub async fn query(
Expand All @@ -424,6 +478,11 @@ impl InMemoryFrontend {

#[cfg(test)]
mod tests {
use chroma_types::{
DocumentExpression, IncludeList, Metadata, MetadataComparison, MetadataExpression,
MetadataValue, PrimitiveOperator, Where,
};

use super::*;

#[tokio::test]
Expand Down Expand Up @@ -458,6 +517,16 @@ mod tests {
let embeddings = vec![vec![1.0, 1.0, 1.0], vec![2.0, 2.0, 2.0]];
let documents = vec![Some("doc1".to_string()), Some("doc2".to_string())];

let mut metadata1 = Metadata::new();
metadata1.insert("key1".to_string(), MetadataValue::Str("value1".to_string()));
metadata1.insert("key2".to_string(), MetadataValue::Int(16));

let mut metadata2 = Metadata::new();
metadata2.insert("key1".to_string(), MetadataValue::Str("value2".to_string()));
metadata2.insert("key2".to_string(), MetadataValue::Int(32));

let metadatas = vec![Some(metadata1), Some(metadata2)];

let request = chroma_types::AddCollectionRecordsRequest::try_new(
tenant_name.clone(),
database_name.clone(),
Expand All @@ -466,11 +535,12 @@ mod tests {
Some(embeddings),
Some(documents),
None,
None,
Some(metadatas),
)
.unwrap();
frontend.add(request).await.unwrap();

// Test count
let count = frontend
.count(
chroma_types::CountRequest::try_new(
Expand All @@ -483,5 +553,46 @@ mod tests {
.await
.unwrap();
assert_eq!(count, 2);

// Test metadata filter
let request = chroma_types::GetRequest::try_new(
tenant_name.clone(),
database_name.clone(),
collection.collection_id,
None,
Some(Where::Metadata(MetadataExpression {
key: "key1".to_string(),
comparison: MetadataComparison::Primitive(
PrimitiveOperator::Equal,
MetadataValue::Str("value1".to_string()),
),
})),
None,
0,
IncludeList::default_get(),
)
.unwrap();
let response = frontend.get(request).await.unwrap();
assert_eq!(response.ids.len(), 1);
assert_eq!(response.ids[0], "id1");

// Test full text index
let request = chroma_types::GetRequest::try_new(
tenant_name.clone(),
database_name.clone(),
collection.collection_id,
None,
Some(Where::Document(DocumentExpression {
operator: chroma_types::DocumentOperator::Contains,
text: "doc2".to_string(),
})),
None,
0,
IncludeList::default_get(),
)
.unwrap();
let response = frontend.get(request).await.unwrap();
assert_eq!(response.ids.len(), 1);
assert_eq!(response.ids[0], "id2");
}
}
12 changes: 6 additions & 6 deletions rust/types/src/api_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1140,13 +1140,13 @@ impl GetRequest {
#[derive(Clone, Deserialize, Serialize, Debug, ToSchema)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
pub struct GetResponse {
ids: Vec<String>,
embeddings: Option<Vec<Vec<f32>>>,
documents: Option<Vec<Option<String>>>,
uris: Option<Vec<Option<String>>>,
pub ids: Vec<String>,
pub embeddings: Option<Vec<Vec<f32>>>,
pub documents: Option<Vec<Option<String>>>,
pub uris: Option<Vec<Option<String>>>,
// TODO(hammadb): Add metadata & include to the response
metadatas: Option<Vec<Option<Metadata>>>,
include: Vec<Include>,
pub metadatas: Option<Vec<Option<Metadata>>>,
pub include: Vec<Include>,
}

#[cfg(feature = "pyo3")]
Expand Down

0 comments on commit fe7e3d2

Please sign in to comment.