Skip to content

Commit

Permalink
Merge pull request #21 from oramasearch/feat/bool-filter
Browse files Browse the repository at this point in the history
Implement bool index as filter & facets
  • Loading branch information
allevo authored Nov 19, 2024
2 parents 9c6ee92 + b890db7 commit 0d88d14
Show file tree
Hide file tree
Showing 10 changed files with 404 additions and 94 deletions.
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ members = [
"utils",
"llm",
"number_index",
"bool_index",
]
14 changes: 14 additions & 0 deletions bool_index/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "bool_index"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = "1.0.90"
bincode = "1.3.3"
itertools = "0.13.0"
rayon = "1.10.0"
serde = { version = "1.0.210", features = ["derive"] }
types = { path = "../types" }
roaring = "0.10"
dashmap = "6"
98 changes: 98 additions & 0 deletions bool_index/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use std::collections::HashSet;

use dashmap::DashMap;
use types::{DocumentId, FieldId};

#[derive(Debug, Default)]
struct BoolIndexPerField {
true_docs: HashSet<DocumentId>,
false_docs: HashSet<DocumentId>,
}

#[derive(Debug, Default)]
pub struct BoolIndex {
maps: DashMap<FieldId, BoolIndexPerField>,
}

impl BoolIndex {
pub fn new() -> Self {
Self {
maps: Default::default(),
}
}

pub fn add(&self, doc_id: DocumentId, field_id: FieldId, value: bool) {
let mut btree = self.maps.entry(field_id).or_default();
if value {
btree.true_docs.insert(doc_id);
} else {
btree.false_docs.insert(doc_id);
}
}

pub fn filter(&self, field_id: FieldId, val: bool) -> HashSet<DocumentId> {
let btree = match self.maps.get(&field_id) {
Some(btree) => btree,
// This should never happen: if the field is not in the index, it means that the field
// was not indexed, and the filter should not have been created in the first place.
None => return HashSet::new(),
};

if val {
btree.true_docs.clone()
} else {
btree.false_docs.clone()
}
}
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;

use types::{DocumentId, FieldId};

use crate::BoolIndex;

#[test]
fn test_bool_index_filter() {
let index = BoolIndex::new();

index.add(DocumentId(0), FieldId(0), true);
index.add(DocumentId(1), FieldId(0), false);
index.add(DocumentId(2), FieldId(0), true);
index.add(DocumentId(3), FieldId(0), false);
index.add(DocumentId(4), FieldId(0), true);
index.add(DocumentId(5), FieldId(0), false);

let true_docs = index.filter(FieldId(0), true);
assert_eq!(
true_docs,
HashSet::from([DocumentId(0), DocumentId(2), DocumentId(4)])
);

let false_docs = index.filter(FieldId(0), false);
assert_eq!(
false_docs,
HashSet::from([DocumentId(1), DocumentId(3), DocumentId(5)])
);
}

#[test]
fn test_bool_index_filter_unknown_field() {
let index = BoolIndex::new();

index.add(DocumentId(0), FieldId(0), true);
index.add(DocumentId(1), FieldId(0), false);
index.add(DocumentId(2), FieldId(0), true);
index.add(DocumentId(3), FieldId(0), false);
index.add(DocumentId(4), FieldId(0), true);
index.add(DocumentId(5), FieldId(0), false);

let true_docs = index.filter(FieldId(1), true);
assert_eq!(true_docs, HashSet::from([]));

let false_docs = index.filter(FieldId(1), false);
assert_eq!(false_docs, HashSet::from([]));
}
}
1 change: 1 addition & 0 deletions collection_manager/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ code_index = { path = "../code_index" }
ordered-float = "4.4.0"
num-traits = "0.2"
number_index = { path = "../number_index" }
bool_index = { path = "../bool_index" }

[dev-dependencies]
tempdir = "0.3.7"
58 changes: 51 additions & 7 deletions collection_manager/src/collection.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::{
cmp::Reverse,
collections::{BinaryHeap, HashMap},
collections::{BinaryHeap, HashMap, HashSet},
sync::{atomic::AtomicU16, Arc},
};

use anyhow::{anyhow, Result};
use bool_index::BoolIndex;
use code_index::CodeIndex;
use dashmap::DashMap;
use document_storage::DocumentStorage;
Expand All @@ -13,7 +14,7 @@ use nlp::TextParser;
use num_traits::ToPrimitive;
use number_index::NumberIndex;
use ordered_float::NotNan;
use serde_json::{json, Value};
use serde_json::Value;
use storage::Storage;
use string_index::{scorer::bm25::BM25Score, DocumentBatch, StringIndex};
use types::{
Expand Down Expand Up @@ -41,6 +42,8 @@ pub struct Collection {
code_fields: DashMap<String, FieldId>,
// Number
number_index: NumberIndex,
// Bool
bool_index: BoolIndex,
}

impl Collection {
Expand All @@ -66,6 +69,7 @@ impl Collection {
code_index: CodeIndex::new(),
code_fields: Default::default(),
number_index: Default::default(),
bool_index: Default::default(),
};

for (field_name, field_type) in typed_fields {
Expand Down Expand Up @@ -108,6 +112,7 @@ impl Collection {
let mut strings: DocumentBatch = HashMap::with_capacity(document_list.len());
let mut codes: HashMap<_, Vec<_>> = HashMap::with_capacity(document_list.len());
let mut numbers = Vec::new();
let mut bools = Vec::new();
let mut documents = Vec::with_capacity(document_list.len());
for doc in document_list {
let mut flatten = doc.into_flatten();
Expand Down Expand Up @@ -159,7 +164,7 @@ impl Collection {
} else if field_type == ValueType::Scalar(ScalarType::Number) {
let value = match flatten.remove(&key) {
Some(Value::Number(value)) => value,
_ => Err(anyhow!("value is not string. This should never happen"))?,
_ => Err(anyhow!("value is not number. This should never happen"))?,
};

let v: Option<Number> = value
Expand All @@ -175,6 +180,14 @@ impl Collection {

let field_id = self.get_field_id(key.clone());
numbers.push((internal_document_id, field_id, v));
} else if field_type == ValueType::Scalar(ScalarType::Boolean) {
let value = match flatten.remove(&key) {
Some(Value::Bool(value)) => value,
_ => Err(anyhow!("value is not bool. This should never happen"))?,
};

let field_id = self.get_field_id(key.clone());
bools.push((internal_document_id, field_id, value));
}
}

Expand All @@ -190,6 +203,9 @@ impl Collection {
for (doc_id, field_id, value) in numbers {
self.number_index.add(doc_id, field_id, value);
}
for (doc_id, field_id, value) in bools {
self.bool_index.add(doc_id, field_id, value);
}

Ok(())
}
Expand All @@ -212,18 +228,18 @@ impl Collection {

let mut doc_ids = match filter {
Filter::Number(filter_number) => self.number_index.filter(field_id, filter_number),
Filter::Bool(filter_bool) => self.bool_index.filter(field_id, filter_bool),
};
for (field_id, filter) in filters {
let doc_ids_ = match filter {
Filter::Number(filter_number) => {
self.number_index.filter(field_id, filter_number)
}
Filter::Bool(filter_bool) => self.bool_index.filter(field_id, filter_bool),
};
doc_ids = doc_ids.intersection(&doc_ids_).copied().collect();
}

println!("doc_ids: {doc_ids:?}");

Some(doc_ids)
};

Expand Down Expand Up @@ -333,9 +349,12 @@ impl Collection {
let mut values = HashMap::new();

for range in facet.ranges {
let facet = self
let facet: HashSet<_> = self
.number_index
.filter(field_id, NumberFilter::Between(range.from, range.to));
.filter(field_id, NumberFilter::Between(range.from, range.to))
.into_iter()
.filter(|doc_id| token_scores.contains_key(doc_id))
.collect();

values.insert(format!("{}-{}", range.from, range.to), facet.len());
}
Expand All @@ -348,6 +367,31 @@ impl Collection {
},
);
}
FacetDefinition::Bool => {
let true_facet: HashSet<DocumentId> = self
.bool_index
.filter(field_id, true)
.into_iter()
.filter(|doc_id| token_scores.contains_key(doc_id))
.collect();
let false_facet: HashSet<DocumentId> = self
.bool_index
.filter(field_id, false)
.into_iter()
.filter(|doc_id| token_scores.contains_key(doc_id))
.collect();

facets.insert(
field_name,
FacetResult {
count: 2,
values: HashMap::from_iter([
("true".to_string(), true_facet.len()),
("false".to_string(), false_facet.len()),
]),
},
);
}
}
}
Some(facets)
Expand Down
2 changes: 2 additions & 0 deletions collection_manager/src/dto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ impl Default for Limit {
#[derive(Debug, Serialize, Deserialize)]
pub enum Filter {
Number(NumberFilter),
Bool(bool),
}

#[derive(Debug, Serialize, Deserialize)]
Expand All @@ -77,6 +78,7 @@ pub struct NumberFacetDefinition {
#[derive(Debug, Serialize, Deserialize)]
pub enum FacetDefinition {
Number(NumberFacetDefinition),
Bool,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
Loading

0 comments on commit 0d88d14

Please sign in to comment.