From b890db773578a66586fa483359fe26d1b9351b8a Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Tue, 19 Nov 2024 13:00:39 +0100 Subject: [PATCH] Implement bool index as filter & facets --- Cargo.lock | 15 ++ Cargo.toml | 1 + bool_index/Cargo.toml | 14 ++ bool_index/src/lib.rs | 98 +++++++++ collection_manager/Cargo.toml | 1 + collection_manager/src/collection.rs | 58 +++++- collection_manager/src/dto.rs | 2 + collection_manager/src/lib.rs | 301 +++++++++++++++++++-------- rustorama/src/main.rs | 7 +- web_server/src/lib.rs | 1 - 10 files changed, 404 insertions(+), 94 deletions(-) create mode 100644 bool_index/Cargo.toml create mode 100644 bool_index/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index bf820a1..a03a998 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -855,6 +855,20 @@ dependencies = [ "piper", ] +[[package]] +name = "bool_index" +version = "0.1.0" +dependencies = [ + "anyhow", + "bincode", + "dashmap", + "itertools 0.13.0", + "rayon", + "roaring", + "serde", + "types", +] + [[package]] name = "brotli" version = "6.0.0" @@ -1236,6 +1250,7 @@ name = "collection_manager" version = "0.1.0" dependencies = [ "anyhow", + "bool_index", "code_index", "code_parser", "cuid", diff --git a/Cargo.toml b/Cargo.toml index 55bf67c..72d4593 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,4 +17,5 @@ members = [ "utils", "llm", "number_index", + "bool_index", ] diff --git a/bool_index/Cargo.toml b/bool_index/Cargo.toml new file mode 100644 index 0000000..95212ab --- /dev/null +++ b/bool_index/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "bool_index" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1.0.90" +bincode = "1.3.3" +itertools = "0.13.0" +rayon = "1.10.0" +serde = { version = "1.0.210", features = ["derive"] } +types = { path = "../types" } +roaring = "0.10" +dashmap = "6" diff --git a/bool_index/src/lib.rs b/bool_index/src/lib.rs new file mode 100644 index 0000000..b616cd4 --- /dev/null +++ b/bool_index/src/lib.rs @@ -0,0 +1,98 @@ +use std::collections::HashSet; + +use dashmap::DashMap; +use types::{DocumentId, FieldId}; + +#[derive(Debug, Default)] +struct BoolIndexPerField { + true_docs: HashSet, + false_docs: HashSet, +} + +#[derive(Debug, Default)] +pub struct BoolIndex { + maps: DashMap, +} + +impl BoolIndex { + pub fn new() -> Self { + Self { + maps: Default::default(), + } + } + + pub fn add(&self, doc_id: DocumentId, field_id: FieldId, value: bool) { + let mut btree = self.maps.entry(field_id).or_default(); + if value { + btree.true_docs.insert(doc_id); + } else { + btree.false_docs.insert(doc_id); + } + } + + pub fn filter(&self, field_id: FieldId, val: bool) -> HashSet { + let btree = match self.maps.get(&field_id) { + Some(btree) => btree, + // This should never happen: if the field is not in the index, it means that the field + // was not indexed, and the filter should not have been created in the first place. + None => return HashSet::new(), + }; + + if val { + btree.true_docs.clone() + } else { + btree.false_docs.clone() + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use types::{DocumentId, FieldId}; + + use crate::BoolIndex; + + #[test] + fn test_bool_index_filter() { + let index = BoolIndex::new(); + + index.add(DocumentId(0), FieldId(0), true); + index.add(DocumentId(1), FieldId(0), false); + index.add(DocumentId(2), FieldId(0), true); + index.add(DocumentId(3), FieldId(0), false); + index.add(DocumentId(4), FieldId(0), true); + index.add(DocumentId(5), FieldId(0), false); + + let true_docs = index.filter(FieldId(0), true); + assert_eq!( + true_docs, + HashSet::from([DocumentId(0), DocumentId(2), DocumentId(4)]) + ); + + let false_docs = index.filter(FieldId(0), false); + assert_eq!( + false_docs, + HashSet::from([DocumentId(1), DocumentId(3), DocumentId(5)]) + ); + } + + #[test] + fn test_bool_index_filter_unknown_field() { + let index = BoolIndex::new(); + + index.add(DocumentId(0), FieldId(0), true); + index.add(DocumentId(1), FieldId(0), false); + index.add(DocumentId(2), FieldId(0), true); + index.add(DocumentId(3), FieldId(0), false); + index.add(DocumentId(4), FieldId(0), true); + index.add(DocumentId(5), FieldId(0), false); + + let true_docs = index.filter(FieldId(1), true); + assert_eq!(true_docs, HashSet::from([])); + + let false_docs = index.filter(FieldId(1), false); + assert_eq!(false_docs, HashSet::from([])); + } +} diff --git a/collection_manager/Cargo.toml b/collection_manager/Cargo.toml index c2c8506..4f96196 100644 --- a/collection_manager/Cargo.toml +++ b/collection_manager/Cargo.toml @@ -20,6 +20,7 @@ code_index = { path = "../code_index" } ordered-float = "4.4.0" num-traits = "0.2" number_index = { path = "../number_index" } +bool_index = { path = "../bool_index" } [dev-dependencies] tempdir = "0.3.7" diff --git a/collection_manager/src/collection.rs b/collection_manager/src/collection.rs index befbc8c..2e1bcdc 100644 --- a/collection_manager/src/collection.rs +++ b/collection_manager/src/collection.rs @@ -1,10 +1,11 @@ use std::{ cmp::Reverse, - collections::{BinaryHeap, HashMap}, + collections::{BinaryHeap, HashMap, HashSet}, sync::{atomic::AtomicU16, Arc}, }; use anyhow::{anyhow, Result}; +use bool_index::BoolIndex; use code_index::CodeIndex; use dashmap::DashMap; use document_storage::DocumentStorage; @@ -13,7 +14,7 @@ use nlp::TextParser; use num_traits::ToPrimitive; use number_index::NumberIndex; use ordered_float::NotNan; -use serde_json::{json, Value}; +use serde_json::Value; use storage::Storage; use string_index::{scorer::bm25::BM25Score, DocumentBatch, StringIndex}; use types::{ @@ -41,6 +42,8 @@ pub struct Collection { code_fields: DashMap, // Number number_index: NumberIndex, + // Bool + bool_index: BoolIndex, } impl Collection { @@ -66,6 +69,7 @@ impl Collection { code_index: CodeIndex::new(), code_fields: Default::default(), number_index: Default::default(), + bool_index: Default::default(), }; for (field_name, field_type) in typed_fields { @@ -108,6 +112,7 @@ impl Collection { let mut strings: DocumentBatch = HashMap::with_capacity(document_list.len()); let mut codes: HashMap<_, Vec<_>> = HashMap::with_capacity(document_list.len()); let mut numbers = Vec::new(); + let mut bools = Vec::new(); let mut documents = Vec::with_capacity(document_list.len()); for doc in document_list { let mut flatten = doc.into_flatten(); @@ -159,7 +164,7 @@ impl Collection { } else if field_type == ValueType::Scalar(ScalarType::Number) { let value = match flatten.remove(&key) { Some(Value::Number(value)) => value, - _ => Err(anyhow!("value is not string. This should never happen"))?, + _ => Err(anyhow!("value is not number. This should never happen"))?, }; let v: Option = value @@ -175,6 +180,14 @@ impl Collection { let field_id = self.get_field_id(key.clone()); numbers.push((internal_document_id, field_id, v)); + } else if field_type == ValueType::Scalar(ScalarType::Boolean) { + let value = match flatten.remove(&key) { + Some(Value::Bool(value)) => value, + _ => Err(anyhow!("value is not bool. This should never happen"))?, + }; + + let field_id = self.get_field_id(key.clone()); + bools.push((internal_document_id, field_id, value)); } } @@ -190,6 +203,9 @@ impl Collection { for (doc_id, field_id, value) in numbers { self.number_index.add(doc_id, field_id, value); } + for (doc_id, field_id, value) in bools { + self.bool_index.add(doc_id, field_id, value); + } Ok(()) } @@ -212,18 +228,18 @@ impl Collection { let mut doc_ids = match filter { Filter::Number(filter_number) => self.number_index.filter(field_id, filter_number), + Filter::Bool(filter_bool) => self.bool_index.filter(field_id, filter_bool), }; for (field_id, filter) in filters { let doc_ids_ = match filter { Filter::Number(filter_number) => { self.number_index.filter(field_id, filter_number) } + Filter::Bool(filter_bool) => self.bool_index.filter(field_id, filter_bool), }; doc_ids = doc_ids.intersection(&doc_ids_).copied().collect(); } - println!("doc_ids: {doc_ids:?}"); - Some(doc_ids) }; @@ -333,9 +349,12 @@ impl Collection { let mut values = HashMap::new(); for range in facet.ranges { - let facet = self + let facet: HashSet<_> = self .number_index - .filter(field_id, NumberFilter::Between(range.from, range.to)); + .filter(field_id, NumberFilter::Between(range.from, range.to)) + .into_iter() + .filter(|doc_id| token_scores.contains_key(doc_id)) + .collect(); values.insert(format!("{}-{}", range.from, range.to), facet.len()); } @@ -348,6 +367,31 @@ impl Collection { }, ); } + FacetDefinition::Bool => { + let true_facet: HashSet = self + .bool_index + .filter(field_id, true) + .into_iter() + .filter(|doc_id| token_scores.contains_key(doc_id)) + .collect(); + let false_facet: HashSet = self + .bool_index + .filter(field_id, false) + .into_iter() + .filter(|doc_id| token_scores.contains_key(doc_id)) + .collect(); + + facets.insert( + field_name, + FacetResult { + count: 2, + values: HashMap::from_iter([ + ("true".to_string(), true_facet.len()), + ("false".to_string(), false_facet.len()), + ]), + }, + ); + } } } Some(facets) diff --git a/collection_manager/src/dto.rs b/collection_manager/src/dto.rs index be45ba4..6871158 100644 --- a/collection_manager/src/dto.rs +++ b/collection_manager/src/dto.rs @@ -61,6 +61,7 @@ impl Default for Limit { #[derive(Debug, Serialize, Deserialize)] pub enum Filter { Number(NumberFilter), + Bool(bool), } #[derive(Debug, Serialize, Deserialize)] @@ -77,6 +78,7 @@ pub struct NumberFacetDefinition { #[derive(Debug, Serialize, Deserialize)] pub enum FacetDefinition { Number(NumberFacetDefinition), + Bool, } #[derive(Debug, Serialize, Deserialize)] diff --git a/collection_manager/src/lib.rs b/collection_manager/src/lib.rs index f630689..4b0be19 100644 --- a/collection_manager/src/lib.rs +++ b/collection_manager/src/lib.rs @@ -88,11 +88,11 @@ mod tests { use serde_json::json; use storage::Storage; use tempdir::TempDir; - use types::{CodeLanguage, Number, NumberFilter}; + use types::{Number, NumberFilter}; use crate::dto::{ CreateCollectionOptionDTO, FacetDefinition, Filter, Limit, NumberFacetDefinition, - NumberFacetDefinitionRange, SearchParams, TypedField, + NumberFacetDefinitionRange, SearchParams, }; use super::CollectionManager; @@ -374,88 +374,6 @@ mod tests { assert_eq!(output.hits[4].id, "95"); } - #[test] - fn test_foo() { - let manager = create_manager(); - let collection_id_str = "my-test-collection".to_string(); - - let collection_id = manager - .create_collection(CreateCollectionOptionDTO { - id: collection_id_str.clone(), - description: Some("Collection of songs".to_string()), - language: None, - typed_fields: vec![("code".to_string(), TypedField::Code(CodeLanguage::TSX))] - .into_iter() - .collect(), - }) - .expect("insertion should be successful"); - - manager.get(collection_id.clone(), |collection| { - collection.insert_batch( - vec![ - json!({ - "id": "1", - "code": r#" -import { TableController, type SortingState } from '@tanstack/lit-table' -//... -@state() -private _sorting: SortingState = [ - { - id: 'age', //you should get autocomplete for the `id` and `desc` properties - desc: true, - } -] -"#, - }), - json!({ - "id": "2", - "code": r#"export type RowSelectionState = Record - -export type RowSelectionTableState = { - rowSelection: RowSelectionState -}"#, - }), - json!({ - "id": "3", - "code": r#"initialState?: Partial< - VisibilityTableState & - ColumnOrderTableState & - ColumnPinningTableState & - FiltersTableState & - SortingTableState & - ExpandedTableState & - GroupingTableState & - ColumnSizingTableState & - PaginationTableState & - RowSelectionTableState ->"#, - }), - json!({ - "id": "4", - "code": r#"setColumnVisibility: (updater: Updater) => void"#, - }) - ] - .try_into() - .unwrap(), - ) - }); - - let output = manager - .get(collection_id, |collection| { - let search_params = SearchParams { - term: "SelectionTableState".to_string(), - limit: Limit(10), - boost: Default::default(), - properties: Default::default(), - where_filter: Default::default(), - facets: Default::default(), - }; - collection.search(search_params) - }) - .unwrap() - .unwrap(); - } - #[test] fn test_filter_number() { let manager = create_manager(); @@ -615,4 +533,219 @@ export type RowSelectionTableState = { ]) ); } + + #[test] + fn test_filter_bool() { + let manager = create_manager(); + let collection_id_str = "my-test-collection".to_string(); + + let collection_id = manager + .create_collection(CreateCollectionOptionDTO { + id: collection_id_str.clone(), + description: Some("Collection of songs".to_string()), + language: None, + typed_fields: Default::default(), + }) + .expect("insertion should be successful"); + + manager + .get(collection_id.clone(), |collection| { + collection.insert_batch( + (0..100) + .map(|i| { + json!({ + "id": i.to_string(), + "text": "text ".repeat(i + 1), + "bool": i % 2 == 0, + }) + }) + .collect::>() + .try_into() + .unwrap(), + ) + }) + .unwrap() + .unwrap(); + + let output = manager + .get(collection_id.clone(), |collection| { + collection.search(SearchParams { + term: "text".to_string(), + limit: Limit(10), + boost: Default::default(), + properties: Default::default(), + where_filter: vec![("bool".to_string(), Filter::Bool(true))] + .into_iter() + .collect(), + facets: Default::default(), + }) + }) + .unwrap() + .unwrap(); + + assert_eq!(output.count, 50); + assert_eq!(output.hits.len(), 10); + for hit in output.hits.iter() { + let id = hit.id.parse::().unwrap(); + assert_eq!(id % 2, 0); + } + } + + #[test] + fn test_facets_bool() { + let manager = create_manager(); + let collection_id_str = "my-test-collection".to_string(); + + let collection_id = manager + .create_collection(CreateCollectionOptionDTO { + id: collection_id_str.clone(), + description: Some("Collection of songs".to_string()), + language: None, + typed_fields: Default::default(), + }) + .expect("insertion should be successful"); + + manager + .get(collection_id.clone(), |collection| { + collection.insert_batch( + (0..100) + .map(|i| { + json!({ + "id": i.to_string(), + "text": "text ".repeat(i + 1), + "bool": i % 2 == 0, + }) + }) + .collect::>() + .try_into() + .unwrap(), + ) + }) + .unwrap() + .unwrap(); + + let output = manager + .get(collection_id.clone(), |collection| { + collection.search(SearchParams { + term: "text".to_string(), + limit: Limit(10), + boost: Default::default(), + properties: Default::default(), + where_filter: Default::default(), + facets: HashMap::from_iter(vec![("bool".to_string(), FacetDefinition::Bool)]), + }) + }) + .unwrap() + .unwrap(); + + let facets = output.facets.expect("Facet should be there"); + let bool_facet = facets + .get("bool") + .expect("Facet on field 'bool' should be there"); + + assert_eq!(bool_facet.count, 2); + assert_eq!(bool_facet.values.len(), 2); + + assert_eq!( + bool_facet.values, + HashMap::from_iter(vec![("true".to_string(), 50), ("false".to_string(), 50),]) + ); + } + + #[test] + fn test_facets_should_based_on_term() { + let manager = create_manager(); + let collection_id_str = "my-test-collection".to_string(); + + let collection_id = manager + .create_collection(CreateCollectionOptionDTO { + id: collection_id_str.clone(), + description: Some("Collection of songs".to_string()), + language: None, + typed_fields: Default::default(), + }) + .expect("insertion should be successful"); + + manager + .get(collection_id.clone(), |collection| { + collection.insert_batch( + vec![ + json!({ + "id": "1", + "text": "text", + "bool": true, + "number": 1, + }), + json!({ + "id": "2", + "text": "text text", + "bool": false, + "number": 2, + }), + // This document doens't match the term + // so it should not be counted in the facets + json!({ + "id": "3", + "text": "another", + "bool": true, + "number": 1, + }), + ] + .try_into() + .unwrap(), + ) + }) + .unwrap() + .unwrap(); + + let output = manager + .get(collection_id.clone(), |collection| { + collection.search(SearchParams { + term: "text".to_string(), + limit: Limit(10), + boost: Default::default(), + properties: Default::default(), + where_filter: Default::default(), + facets: HashMap::from_iter(vec![ + ("bool".to_string(), FacetDefinition::Bool), + ( + "number".to_string(), + FacetDefinition::Number(NumberFacetDefinition { + ranges: vec![NumberFacetDefinitionRange { + from: Number::from(0), + to: Number::from(10), + }], + }), + ), + ]), + }) + }) + .unwrap() + .unwrap(); + + let facets = output.facets.expect("Facet should be there"); + let bool_facet = facets + .get("bool") + .expect("Facet on field 'bool' should be there"); + + assert_eq!(bool_facet.count, 2); + assert_eq!(bool_facet.values.len(), 2); + + assert_eq!( + bool_facet.values, + HashMap::from_iter(vec![("true".to_string(), 1), ("false".to_string(), 1),]) + ); + + let number_facet = facets + .get("number") + .expect("Facet on field 'number' should be there"); + + assert_eq!(number_facet.count, 1); + assert_eq!(number_facet.values.len(), 1); + + assert_eq!( + number_facet.values, + HashMap::from_iter(vec![("0-10".to_string(), 2),]) + ); + } } diff --git a/rustorama/src/main.rs b/rustorama/src/main.rs index d085c16..874727b 100644 --- a/rustorama/src/main.rs +++ b/rustorama/src/main.rs @@ -37,7 +37,10 @@ async fn start(config: RustoramaConfig) -> Result<()> { let manager = Arc::new(manager); let web_server = WebServer::new(manager); - println!("Starting web server on {}:{}", config.http.host, config.http.port); + println!( + "Starting web server on {}:{}", + config.http.host, config.http.port + ); web_server.start(config.http).await?; Ok(()) @@ -123,4 +126,4 @@ mod tests { }; } } -*/ \ No newline at end of file +*/ diff --git a/web_server/src/lib.rs b/web_server/src/lib.rs index f2499fb..09c4c45 100644 --- a/web_server/src/lib.rs +++ b/web_server/src/lib.rs @@ -29,7 +29,6 @@ impl WebServer { pub async fn start(self, config: HttpConfig) -> Result<()> { let addr = SocketAddr::new(config.host, config.port); - let router = api::api_config().with_state(self.collection_manager.clone()); let router = if config.allow_cors { let cors_layer = CorsLayer::new()