Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Mar 7, 2025
1 parent 6512fa2 commit aa6f286
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 43 deletions.
249 changes: 226 additions & 23 deletions rust/frontend/src/impls/in_memory_frontend.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use chroma_error::ChromaError;
use chroma_segment::test::TestReferenceSegment;
use chroma_types::operator::{Filter, Limit, Projection, Scan};
use chroma_types::plan::{Count, Get};
use chroma_types::operator::{Filter, KnnBatch, KnnProjection, Limit, Projection, Scan};
use chroma_types::plan::{Count, Get, Knn};
use chroma_types::{
test_segment, Collection, CollectionAndSegments, CollectionUuid, Database, Include, Segment,
test_segment, Collection, CollectionAndSegments, CollectionUuid, Database, Include,
IncludeList, Segment,
};
use parking_lot::Mutex;
use std::collections::HashSet;
Expand Down Expand Up @@ -267,7 +268,7 @@ impl InMemoryFrontend {
&mut self,
_request: chroma_types::UpdateCollectionRequest,
) -> Result<chroma_types::UpdateCollectionResponse, chroma_types::UpdateCollectionError> {
Ok(chroma_types::UpdateCollectionResponse {})
unimplemented!()
}

pub async fn delete_collection(
Expand Down Expand Up @@ -338,31 +339,154 @@ impl InMemoryFrontend {

/// Applies `Update` operation records to the collection named by the request.
///
/// The target collection is matched on collection id, tenant and database name.
/// The ids, embeddings, documents, URIs and metadatas from the request are
/// converted with `to_records` and handed to the collection's reference
/// implementation.
///
/// # Errors
///
/// Returns `UpdateCollectionRecordsError::Other` wrapping
/// `GetCollectionError::NotFound` when no collection matches, and propagates
/// (boxed) any error produced by `to_records`.
pub async fn update(
    &mut self,
    request: chroma_types::UpdateCollectionRecordsRequest,
) -> Result<
    chroma_types::UpdateCollectionRecordsResponse,
    chroma_types::UpdateCollectionRecordsError,
> {
    let mut inner = self.inner.lock();
    let collection = inner
        .collections
        .iter_mut()
        .find(|c| {
            c.collection.collection_id == request.collection_id
                && c.collection.tenant == request.tenant_id
                && c.collection.database == request.database_name
        })
        // Build the error lazily so the id is only stringified on the miss path.
        .ok_or_else(|| {
            chroma_types::UpdateCollectionRecordsError::Other(
                chroma_types::GetCollectionError::NotFound(request.collection_id.to_string())
                    .boxed(),
            )
        })?;

    let chroma_types::UpdateCollectionRecordsRequest {
        ids,
        embeddings,
        documents,
        metadatas,
        uris,
        ..
    } = request;

    // The second element of the tuple returned by `to_records` is not needed here.
    let (records, _) = to_records(
        ids,
        embeddings,
        documents,
        uris,
        metadatas,
        chroma_types::Operation::Update,
    )
    .map_err(|e| e.boxed())?;

    collection
        .reference_impl
        .apply_operation_records(records, collection.metadata_segment.id);

    Ok(chroma_types::UpdateCollectionRecordsResponse {})
}

/// Applies `Upsert` operation records to the collection named by the request.
///
/// The target collection is matched on collection id, tenant and database name.
/// Each embedding in the request is wrapped in `Some` before the request
/// fields are passed to `to_records`.
///
/// # Errors
///
/// Returns `UpsertCollectionRecordsError::Other` wrapping
/// `GetCollectionError::NotFound` when no collection matches, and propagates
/// (boxed) any error produced by `to_records`.
pub async fn upsert(
    &mut self,
    request: chroma_types::UpsertCollectionRecordsRequest,
) -> Result<
    chroma_types::UpsertCollectionRecordsResponse,
    chroma_types::UpsertCollectionRecordsError,
> {
    let mut inner = self.inner.lock();
    let collection = inner
        .collections
        .iter_mut()
        .find(|c| {
            c.collection.collection_id == request.collection_id
                && c.collection.tenant == request.tenant_id
                && c.collection.database == request.database_name
        })
        // Build the error lazily so the id is only stringified on the miss path.
        .ok_or_else(|| {
            chroma_types::UpsertCollectionRecordsError::Other(
                chroma_types::GetCollectionError::NotFound(request.collection_id.to_string())
                    .boxed(),
            )
        })?;

    let chroma_types::UpsertCollectionRecordsRequest {
        ids,
        embeddings,
        documents,
        metadatas,
        uris,
        ..
    } = request;

    let embeddings = embeddings.map(|embeddings| embeddings.into_iter().map(Some).collect());

    // Records must carry the `Upsert` operation; tagging them as `Add` would
    // submit the request as a plain add rather than an upsert.
    let (records, _) = to_records(
        ids,
        embeddings,
        documents,
        uris,
        metadatas,
        chroma_types::Operation::Upsert,
    )
    .map_err(|e| e.boxed())?;

    collection
        .reference_impl
        .apply_operation_records(records, collection.metadata_segment.id);

    Ok(chroma_types::UpsertCollectionRecordsResponse {})
}

/// Deletes the records selected by the request's ids and/or `where` filter.
///
/// The ids to delete are resolved by issuing a `get` with the same ids and
/// filter (no limit, offset 0, empty include list). One `Delete` operation
/// record is then applied per returned id to the matching collection.
///
/// # Errors
///
/// Propagates (boxed) any error returned by `get`, and returns
/// `DeleteCollectionRecordsError::Internal` wrapping
/// `GetCollectionError::NotFound` when no collection matches.
///
/// # Panics
///
/// Panics if `GetRequest::try_new` rejects the derived request.
/// NOTE(review): consider mapping that validation failure into an error
/// instead — confirm which error variant is appropriate.
pub async fn delete(
    &mut self,
    request: chroma_types::DeleteCollectionRecordsRequest,
) -> Result<
    chroma_types::DeleteCollectionRecordsResponse,
    chroma_types::DeleteCollectionRecordsError,
> {
    // Resolve the ids first; this await completes before the lock below is taken.
    let ids_to_delete = self
        .get(
            chroma_types::GetRequest::try_new(
                request.tenant_id.clone(),
                request.database_name.clone(),
                request.collection_id,
                request.ids,
                request.r#where,
                None,
                0,
                IncludeList::empty(),
            )
            .unwrap(),
        )
        .await
        .map_err(|e| e.boxed())
        .map(|response| response.ids)?;

    let mut inner = self.inner.lock();
    let collection = inner
        .collections
        .iter_mut()
        .find(|c| {
            c.collection.collection_id == request.collection_id
                && c.collection.tenant == request.tenant_id
                && c.collection.database == request.database_name
        })
        // Build the error lazily so the id is only stringified on the miss path.
        .ok_or_else(|| {
            chroma_types::DeleteCollectionRecordsError::Internal(
                chroma_types::GetCollectionError::NotFound(request.collection_id.to_string())
                    .boxed(),
            )
        })?;

    // A delete record carries only the id; every payload field is left empty.
    let records = ids_to_delete
        .into_iter()
        .map(|id| chroma_types::OperationRecord {
            id,
            operation: chroma_types::Operation::Delete,
            encoding: None,
            embedding: None,
            document: None,
            metadata: None,
        })
        .collect::<Vec<_>>();
    collection
        .reference_impl
        .apply_operation_records(records, collection.metadata_segment.id);

    Ok(chroma_types::DeleteCollectionRecordsResponse {})
}

Expand Down Expand Up @@ -464,14 +588,71 @@ impl InMemoryFrontend {

/// Runs a KNN query against the collection named by the request.
///
/// The target collection is matched on collection id, tenant and database name.
/// The request's ids and `where` clause form the filter, `embeddings` are the
/// query vectors, and `n_results` is passed as the KNN fetch size. The
/// projection flags are derived from the request's include list.
///
/// # Errors
///
/// Returns a boxed `GetCollectionError::NotFound` (converted into
/// `QueryError`) when no collection matches, and propagates (boxed) any error
/// returned by the reference implementation's `knn`.
pub async fn query(
    &mut self,
    request: chroma_types::QueryRequest,
) -> Result<chroma_types::QueryResponse, chroma_types::QueryError> {
    let inner = self.inner.lock();
    let collection = inner
        .collections
        .iter()
        .find(|c| {
            c.collection.collection_id == request.collection_id
                && c.collection.tenant == request.tenant_id
                && c.collection.database == request.database_name
        })
        // Build the error lazily so the id is only stringified on the miss path.
        .ok_or_else(|| {
            chroma_types::GetCollectionError::NotFound(request.collection_id.to_string())
                .boxed()
        })?;

    let chroma_types::QueryRequest {
        r#where,
        include,
        ids,
        embeddings,
        n_results,
        ..
    } = request;

    let filter = Filter {
        query_ids: ids,
        where_clause: r#where,
    };

    let query_response = collection
        .reference_impl
        .knn(Knn {
            scan: Scan {
                collection_and_segments: CollectionAndSegments {
                    collection: collection.collection.clone(),
                    metadata_segment: collection.metadata_segment.clone(),
                    vector_segment: collection.vector_segment.clone(),
                    record_segment: collection.record_segment.clone(),
                },
            },
            filter,
            knn: KnnBatch {
                embeddings,
                fetch: n_results,
            },
            proj: KnnProjection {
                projection: Projection {
                    document: include.0.contains(&Include::Document),
                    embedding: include.0.contains(&Include::Embedding),
                    // If URI is requested, metadata is also requested so we can extract the URI.
                    metadata: (include.0.contains(&Include::Metadata)
                        || include.0.contains(&Include::Uri)),
                },
                distance: include.0.contains(&Include::Distance),
            },
        })
        .map_err(|e| e.boxed())?;

    // The include list is passed along with the raw result for the conversion
    // into `QueryResponse`.
    Ok((query_response, include).into())
}

/// Returns a health-check response that always reports the executor as ready.
pub async fn healthcheck(&self) -> chroma_types::HealthCheckResponse {
    chroma_types::HealthCheckResponse {
        is_executor_ready: true,
    }
}
}
Expand All @@ -485,8 +666,7 @@ mod tests {

use super::*;

#[tokio::test]
async fn test_collection() {
async fn create_test_collection() -> (InMemoryFrontend, Collection) {
let tenant_name = "test".to_string();
let database_name = "test".to_string();
let collection_name = "test".to_string();
Expand All @@ -511,10 +691,17 @@ mod tests {
false,
)
.unwrap();
let collection = frontend.create_collection(request).await.unwrap();
(
frontend.clone(),
frontend.create_collection(request).await.unwrap(),
)
}

#[tokio::test]
async fn test_collection_get_query() {
let (mut frontend, collection) = create_test_collection().await;
let ids = vec!["id1".to_string(), "id2".to_string()];
let embeddings = vec![vec![1.0, 1.0, 1.0], vec![2.0, 2.0, 2.0]];
let embeddings = vec![vec![-1.0, -1.0, -1.0], vec![0.0, 0.0, 0.0]];
let documents = vec![Some("doc1".to_string()), Some("doc2".to_string())];

let mut metadata1 = Metadata::new();
Expand All @@ -528,8 +715,8 @@ mod tests {
let metadatas = vec![Some(metadata1), Some(metadata2)];

let request = chroma_types::AddCollectionRecordsRequest::try_new(
tenant_name.clone(),
database_name.clone(),
collection.tenant.clone(),
collection.database.clone(),
collection.collection_id,
ids,
Some(embeddings),
Expand All @@ -544,8 +731,8 @@ mod tests {
let count = frontend
.count(
chroma_types::CountRequest::try_new(
tenant_name.clone(),
database_name.clone(),
collection.tenant.clone(),
collection.database.clone(),
collection.collection_id,
)
.unwrap(),
Expand All @@ -556,8 +743,8 @@ mod tests {

// Test metadata filter
let request = chroma_types::GetRequest::try_new(
tenant_name.clone(),
database_name.clone(),
collection.tenant.clone(),
collection.database.clone(),
collection.collection_id,
None,
Some(Where::Metadata(MetadataExpression {
Expand All @@ -576,10 +763,10 @@ mod tests {
assert_eq!(response.ids.len(), 1);
assert_eq!(response.ids[0], "id1");

// Test full text index
// Test full text query
let request = chroma_types::GetRequest::try_new(
tenant_name.clone(),
database_name.clone(),
collection.tenant.clone(),
collection.database.clone(),
collection.collection_id,
None,
Some(Where::Document(DocumentExpression {
Expand All @@ -594,5 +781,21 @@ mod tests {
let response = frontend.get(request).await.unwrap();
assert_eq!(response.ids.len(), 1);
assert_eq!(response.ids[0], "id2");

// Test vector query
let request = chroma_types::QueryRequest::try_new(
collection.tenant.clone(),
collection.database.clone(),
collection.collection_id,
None,
None,
vec![vec![0.5, 0.5, 0.5]],
10,
IncludeList::default_query(),
)
.unwrap();
let response = frontend.query(request).await.unwrap();
assert_eq!(response.ids[0].len(), 2);
assert_eq!(response.ids[0], vec!["id2", "id1"]);
}
}
Loading

0 comments on commit aa6f286

Please sign in to comment.