Skip to content

Commit

Permalink
Fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
allevo committed Nov 18, 2024
1 parent 4989ef3 commit aa278fd
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 87 deletions.
3 changes: 1 addition & 2 deletions code_index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::collections::HashSet;
use std::{collections::HashMap, sync::RwLock};

use anyhow::Result;
use code_parser::treesitter::{FunctionDeclaration, ImportedTokens, JsxElement, NewParser};
use code_parser::treesitter::CodeToken;
use code_parser::treesitter::{FunctionDeclaration, ImportedTokens, JsxElement, NewParser};
use nlp::tokenizer::Tokenizer;
use ptrie::Trie;
use regex::Regex;
Expand Down Expand Up @@ -87,7 +87,6 @@ impl CodeIndex {

if let Some(c) = exact_match {
for (doc_id, code_posting) in c {

if let Some(filtered_doc_ids) = filtered_doc_ids {
if !filtered_doc_ids.contains(doc_id) {
continue;
Expand Down
2 changes: 1 addition & 1 deletion code_parser/src/treesitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -857,4 +857,4 @@ const a = <th
assert_eq!(output, vec![]);
}
}
*/
*/
23 changes: 13 additions & 10 deletions collection_manager/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use serde_json::Value;
use storage::Storage;
use string_index::{scorer::bm25::BM25Score, DocumentBatch, StringIndex};
use types::{
CollectionId, Document, DocumentId, DocumentList, FieldId, Number, ScalarType, SearchResult, SearchResultHit, StringParser, TokenScore, ValueType
CollectionId, Document, DocumentId, DocumentList, FieldId, Number, ScalarType, SearchResult,
SearchResultHit, StringParser, TokenScore, ValueType,
};

use crate::dto::{CollectionDTO, Filter, SearchParams, TypedField};
Expand Down Expand Up @@ -161,11 +162,11 @@ impl Collection {
_ => Err(anyhow!("value is not string. This should never happen"))?,
};

let v: Option<Number> = value.as_i64()
let v: Option<Number> = value
.as_i64()
.and_then(|v| v.to_i32())
.map(Number::from)
.or_else(|| value.as_f64().and_then(|v| v.to_f32()).map(Number::from))
;
.or_else(|| value.as_f64().and_then(|v| v.to_f32()).map(Number::from));
let v = match v {
Some(v) => v,
// TODO: handle better the error
Expand Down Expand Up @@ -197,18 +198,20 @@ impl Collection {
let filtered_doc_ids = if search_params.where_filter.is_empty() {
None
} else {
let mut filters: Vec<_> = search_params.where_filter
let mut filters: Vec<_> = search_params
.where_filter
.into_iter()
.map(|(field_name, value)| {
let field_id = self.get_field_id(field_name);
(field_id, value)
}).collect();
let (field_id, filter) = filters.pop().expect("where condition has to not be empty here.");
})
.collect();
let (field_id, filter) = filters
.pop()
.expect("where condition has to not be empty here.");

let mut doc_ids = match filter {
Filter::Number(filter_number) => {
self.number_index.filter(field_id, filter_number)
}
Filter::Number(filter_number) => self.number_index.filter(field_id, filter_number),
};
for (field_id, filter) in filters {
let doc_ids_ = match filter {
Expand Down
55 changes: 33 additions & 22 deletions collection_manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,31 +460,42 @@ export type RowSelectionTableState = {
})
.expect("insertion should be successful");

manager.get(collection_id.clone(), |collection| {
collection.insert_batch(
(0..100)
.map(|i| {
json!({
"id": i.to_string(),
"text": "text ".repeat(i + 1),
"number": i,
manager
.get(collection_id.clone(), |collection| {
collection.insert_batch(
(0..100)
.map(|i| {
json!({
"id": i.to_string(),
"text": "text ".repeat(i + 1),
"number": i,
})
})
})
.collect::<Vec<_>>()
.try_into()
.unwrap(),
)
}).unwrap().unwrap();
.collect::<Vec<_>>()
.try_into()
.unwrap(),
)
})
.unwrap()
.unwrap();

let output = manager.get(collection_id.clone(), |collection| {
collection.search(SearchParams {
term: "text".to_string(),
limit: Limit(10),
boost: Default::default(),
properties: Default::default(),
where_filter: vec![("number".to_string(), Filter::Number(NumberFilter::Equal(50.into())))].into_iter().collect(),
let output = manager
.get(collection_id.clone(), |collection| {
collection.search(SearchParams {
term: "text".to_string(),
limit: Limit(10),
boost: Default::default(),
properties: Default::default(),
where_filter: vec![(
"number".to_string(),
Filter::Number(NumberFilter::Equal(50.into())),
)]
.into_iter()
.collect(),
})
})
}).unwrap().unwrap();
.unwrap()
.unwrap();

assert_eq!(output.count, 1);
assert_eq!(output.hits.len(), 1);
Expand Down
1 change: 0 additions & 1 deletion llm/src/questions_generation/prompts.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

pub const QUESTIONS_GENERATION_SYSTEM_PROMPT: &str = r#"
Pretend you're a user searching on Google, a forum, or a blog. Your task is to generate a list of questions that relates to the the context (### Context).
Expand Down
97 changes: 59 additions & 38 deletions number_index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use std::collections::{BTreeMap, HashSet};
use dashmap::DashMap;
use types::{DocumentId, FieldId, Number, NumberFilter};



#[derive(Debug, Default)]
pub struct NumberIndex {
maps: DashMap<FieldId, BTreeMap<Number, Vec<DocumentId>>>,
Expand Down Expand Up @@ -41,36 +39,30 @@ impl NumberIndex {
HashSet::new()
}
}
NumberFilter::LessThan(value) => {
btree.range((Bound::Unbounded, Bound::Excluded(&value)))
.flat_map(|(_, doc_ids)| doc_ids.iter().cloned())
.collect()
}
NumberFilter::LessThanOrEqual(value) => {
btree.range((Bound::Unbounded, Bound::Included(&value)))
.flat_map(|(_, doc_ids)| doc_ids.iter().cloned())
.collect()
}
NumberFilter::GreaterThan(value) => {
btree.range((Bound::Excluded(&value), Bound::Unbounded))
.flat_map(|(_, doc_ids)| doc_ids.iter().cloned())
.collect()
}
NumberFilter::GreaterThanOrEqual(value) => {
btree.range((Bound::Included(&value), Bound::Unbounded))
.flat_map(|(_, doc_ids)| doc_ids.iter().cloned())
.collect()
}
NumberFilter::Between(min, max) => {
btree.range((Bound::Included(&min), Bound::Included(&max)))
.flat_map(|(_, doc_ids)| doc_ids.iter().cloned())
.collect()
}
NumberFilter::LessThan(value) => btree
.range((Bound::Unbounded, Bound::Excluded(&value)))
.flat_map(|(_, doc_ids)| doc_ids.iter().cloned())
.collect(),
NumberFilter::LessThanOrEqual(value) => btree
.range((Bound::Unbounded, Bound::Included(&value)))
.flat_map(|(_, doc_ids)| doc_ids.iter().cloned())
.collect(),
NumberFilter::GreaterThan(value) => btree
.range((Bound::Excluded(&value), Bound::Unbounded))
.flat_map(|(_, doc_ids)| doc_ids.iter().cloned())
.collect(),
NumberFilter::GreaterThanOrEqual(value) => btree
.range((Bound::Included(&value), Bound::Unbounded))
.flat_map(|(_, doc_ids)| doc_ids.iter().cloned())
.collect(),
NumberFilter::Between(min, max) => btree
.range((Bound::Included(&min), Bound::Included(&max)))
.flat_map(|(_, doc_ids)| doc_ids.iter().cloned())
.collect(),
}
}
}


#[cfg(test)]
mod tests {
use core::f32;
Expand Down Expand Up @@ -176,7 +168,8 @@ mod tests {
assert!(a > Number::from(f32::NEG_INFINITY));
assert!(a == Number::from(f32::NAN));

let v = [Number::from(1),
let v = [
Number::from(1),
Number::from(1.0),
Number::from(2),
Number::from(2.0),
Expand All @@ -186,7 +179,8 @@ mod tests {
Number::from(-2.0),
Number::from(f32::INFINITY),
Number::from(f32::NEG_INFINITY),
Number::from(f32::NAN)];
Number::from(f32::NAN),
];

for i in 0..v.len() {
for j in 0..v.len() {
Expand All @@ -197,7 +191,6 @@ mod tests {
assert_eq!(way.reverse(), other_way);
}
}

}

macro_rules! test_number_filter {
Expand All @@ -217,31 +210,59 @@ mod tests {

a(index);
}
}
};
}

test_number_filter!(test_number_index_filter_eq, |index: NumberIndex| {
let output = index.filter(FieldId(0), NumberFilter::Equal(2.into()));
assert_eq!(output, HashSet::from_iter(vec![DocumentId(2), DocumentId(5)]));
assert_eq!(
output,
HashSet::from_iter(vec![DocumentId(2), DocumentId(5)])
);
});
test_number_filter!(test_number_index_filter_lt, |index: NumberIndex| {
let output = index.filter(FieldId(0), NumberFilter::LessThan(2.into()));
assert_eq!(output, HashSet::from_iter(vec![DocumentId(0), DocumentId(1)]));
assert_eq!(
output,
HashSet::from_iter(vec![DocumentId(0), DocumentId(1)])
);
});
test_number_filter!(test_number_index_filter_lt_equal, |index: NumberIndex| {
let output = index.filter(FieldId(0), NumberFilter::LessThanOrEqual(2.into()));
assert_eq!(output, HashSet::from_iter(vec![DocumentId(0), DocumentId(1), DocumentId(2), DocumentId(5)]));
assert_eq!(
output,
HashSet::from_iter(vec![
DocumentId(0),
DocumentId(1),
DocumentId(2),
DocumentId(5)
])
);
});
test_number_filter!(test_number_index_filter_gt, |index: NumberIndex| {
let output = index.filter(FieldId(0), NumberFilter::GreaterThan(2.into()));
assert_eq!(output, HashSet::from_iter(vec![DocumentId(3), DocumentId(4)]));
assert_eq!(
output,
HashSet::from_iter(vec![DocumentId(3), DocumentId(4)])
);
});
test_number_filter!(test_number_index_filter_gt_equal, |index: NumberIndex| {
let output = index.filter(FieldId(0), NumberFilter::GreaterThanOrEqual(2.into()));
assert_eq!(output, HashSet::from_iter(vec![DocumentId(3), DocumentId(4), DocumentId(2), DocumentId(5)]));
assert_eq!(
output,
HashSet::from_iter(vec![
DocumentId(3),
DocumentId(4),
DocumentId(2),
DocumentId(5)
])
);
});
test_number_filter!(test_number_index_filter_between, |index: NumberIndex| {
let output = index.filter(FieldId(0), NumberFilter::Between(2.into(), 3.into()));
assert_eq!(output, HashSet::from_iter(vec![DocumentId(3), DocumentId(2), DocumentId(5)]));
assert_eq!(
output,
HashSet::from_iter(vec![DocumentId(3), DocumentId(2), DocumentId(5)])
);
});
}
1 change: 0 additions & 1 deletion tanstack_example/src/example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ pub async fn parse_example(path: &str) -> Vec<Value> {
.collect();

// Await all the futures


futures::future::join_all(futures).await
}
Expand Down
17 changes: 5 additions & 12 deletions types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ pub trait StringParser: Send + Sync {
fn tokenize_str_and_stem(&self, input: &str) -> Result<Vec<(String, Vec<String>)>>;
}


#[derive(Debug, Serialize, Deserialize)]
pub enum NumberFilter {
Equal(Number),
Expand Down Expand Up @@ -283,7 +282,7 @@ impl PartialEq for Number {
return true;
}
a == b
},
}
(Number::F32(a), Number::I32(b)) => *a == *b as f32,
}
}
Expand All @@ -303,15 +302,9 @@ impl Ord for Number {
// See `total_cmp` method in f32
match (self, other) {
(Number::I32(a), Number::I32(b)) => a.cmp(b),
(Number::I32(a), Number::F32(b)) => {
(*a as f32).total_cmp(b)
},
(Number::F32(a), Number::F32(b)) => {
a.total_cmp(b)
},
(Number::F32(a), Number::I32(b)) => {
a.total_cmp(&(*b as f32))
},
(Number::I32(a), Number::F32(b)) => (*a as f32).total_cmp(b),
(Number::F32(a), Number::F32(b)) => a.total_cmp(b),
(Number::F32(a), Number::I32(b)) => a.total_cmp(&(*b as f32)),
}
}
}
}

0 comments on commit aa278fd

Please sign in to comment.