Skip to content

Commit

Permalink
refactor: abandon config rag_min_score_rerank (#852)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Sep 9, 2024
1 parent e5cc194 commit 84e9515
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 36 deletions.
4 changes: 3 additions & 1 deletion config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ rag_chunk_size: null # Specifies the chunk size
rag_chunk_overlap: null # Specifies the chunk overlap
rag_min_score_vector_search: 0 # Specifies the minimum relevance score for vector-based searching
rag_min_score_keyword_search: 0 # Specifies the minimum relevance score for keyword-based searching
rag_min_score_rerank: 0 # Specifies the minimum relevance score for reranking
# Defines the query structure using variables like __CONTEXT__ and __INPUT__ to tailor searches to specific needs
rag_template: |
Use the following context as your learned knowledge, inside <context></context> XML tags.
Expand Down Expand Up @@ -75,6 +74,9 @@ left_prompt:
right_prompt:
'{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'

# ---- misc ----
serve_addr: 127.0.0.1:8000 # Default serve listening address

# ---- clients ----
clients:
# All clients have the following configuration:
Expand Down
22 changes: 4 additions & 18 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ pub use self::role::{Role, RoleLike, BUILTIN_ROLES, CODE_ROLE, EXPLAIN_SHELL_ROL
use self::session::Session;

use crate::client::{
create_client_config, init_client, list_chat_models, list_client_types, list_reranker_models,
ClientConfig, Model, OPENAI_COMPATIBLE_PLATFORMS,
create_client_config, list_chat_models, list_client_types, list_reranker_models, ClientConfig,
Model, OPENAI_COMPATIBLE_PLATFORMS,
};
use crate::function::{FunctionDeclaration, Functions, ToolResult};
use crate::rag::Rag;
Expand Down Expand Up @@ -117,7 +117,6 @@ pub struct Config {
pub rag_chunk_overlap: Option<usize>,
pub rag_min_score_vector_search: f32,
pub rag_min_score_keyword_search: f32,
pub rag_min_score_rerank: f32,
pub rag_template: Option<String>,

#[serde(default)]
Expand Down Expand Up @@ -185,7 +184,6 @@ impl Default for Config {
rag_chunk_overlap: None,
rag_min_score_vector_search: 0.0,
rag_min_score_keyword_search: 0.0,
rag_min_score_rerank: 0.0,
rag_template: None,

document_loaders: Default::default(),
Expand Down Expand Up @@ -1146,29 +1144,20 @@ impl Config {
abort_signal: AbortSignal,
) -> Result<String> {
let (reranker_model, top_k) = rag.get_config();
let (min_score_vector_search, min_score_keyword_search, rag_min_score_rerank) = {
let (min_score_vector_search, min_score_keyword_search) = {
let config = config.read();
(
config.rag_min_score_vector_search,
config.rag_min_score_keyword_search,
config.rag_min_score_rerank,
)
};
let rerank = match reranker_model {
Some(reranker_model_id) => {
let rerank_model = Model::retrieve_reranker(&config.read(), &reranker_model_id)?;
let rerank_client = init_client(config, Some(rerank_model))?;
Some((rerank_client, rag_min_score_rerank))
}
None => None,
};
let embeddings = rag
.search(
text,
top_k,
min_score_vector_search,
min_score_keyword_search,
rerank,
reranker_model.as_deref(),
abort_signal,
)
.await?;
Expand Down Expand Up @@ -1849,9 +1838,6 @@ impl Config {
if let Some(Some(v)) = read_env_value::<f32>("rag_min_score_keyword_search") {
self.rag_min_score_keyword_search = v;
}
if let Some(Some(v)) = read_env_value::<f32>("rag_min_score_rerank") {
self.rag_min_score_rerank = v;
}
if let Some(v) = read_env_value::<String>("rag_template") {
self.rag_template = v;
}
Expand Down
40 changes: 23 additions & 17 deletions src/rag/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,12 @@ impl Rag {
top_k: usize,
min_score_vector_search: f32,
min_score_keyword_search: f32,
rerank: Option<(Box<dyn Client>, f32)>,
rerank_model: Option<&str>,
abort_signal: AbortSignal,
) -> Result<String> {
let spinner = create_spinner("Searching").await;
let ret = tokio::select! {
ret = self.hybird_search(text, top_k, min_score_vector_search, min_score_keyword_search, rerank) => {
ret = self.hybird_search(text, top_k, min_score_vector_search, min_score_keyword_search, rerank_model) => {
ret
}
_ = watch_abort_signal(abort_signal) => {
Expand Down Expand Up @@ -425,7 +425,7 @@ impl Rag {
top_k: usize,
min_score_vector_search: f32,
min_score_keyword_search: f32,
rerank: Option<(Box<dyn Client>, f32)>,
rerank_model: Option<&str>,
) -> Result<Vec<String>> {
let (vector_search_result, text_search_result) = tokio::join!(
self.vector_search(query, top_k, min_score_vector_search),
Expand All @@ -434,11 +434,14 @@ impl Rag {
let vector_search_ids = vector_search_result?;
let keyword_search_ids = text_search_result?;
debug!(
"vector_search_ids: {vector_search_ids:?}, keyword_search_ids: {keyword_search_ids:?}"
"vector_search_ids: {:?}, keyword_search_ids: {:?}",
pretty_document_ids(&vector_search_ids),
pretty_document_ids(&keyword_search_ids)
);
let ids = match rerank {
Some((client, min_score)) => {
let min_score = min_score as f64;
let ids = match rerank_model {
Some(model_id) => {
let model = Model::retrieve_reranker(&self.config.read(), model_id)?;
let client = init_client(&self.config, Some(model))?;
let ids: IndexSet<DocumentId> = [vector_search_ids, keyword_search_ids]
.concat()
.into_iter()
Expand All @@ -453,18 +456,12 @@ impl Rag {
}
let data = RerankData::new(query.to_string(), documents, top_k);
let list = client.rerank(data).await?;
let ids = list
let ids: Vec<_> = list
.into_iter()
.take(top_k)
.filter_map(|item| {
if item.relevance_score < min_score {
None
} else {
documents_ids.get(item.index).cloned()
}
})
.filter_map(|item| documents_ids.get(item.index).cloned())
.collect();
debug!("rerank_ids: {ids:?}");
debug!("rerank_ids: {:?}", pretty_document_ids(&ids));
ids
}
None => {
Expand All @@ -473,7 +470,7 @@ impl Rag {
vec![1.0, 1.0],
top_k,
);
debug!("rrf_ids: {ids:?}");
debug!("rrf_ids: {:?}", pretty_document_ids(&ids));
ids
}
};
Expand Down Expand Up @@ -713,6 +710,15 @@ pub fn split_document_id(value: DocumentId) -> (usize, usize) {
(high, low)
}

/// Formats document ids as human-readable `"high-low"` strings for debug logging.
///
/// Each packed `DocumentId` is split into its (high, low) halves via
/// [`split_document_id`] and rendered as `"{high}-{low}"`.
fn pretty_document_ids(ids: &[DocumentId]) -> Vec<String> {
    let mut pretty = Vec::with_capacity(ids.len());
    for id in ids {
        let (high, low) = split_document_id(*id);
        pretty.push(format!("{high}-{low}"));
    }
    pretty
}

fn select_embedding_model(models: &[&Model]) -> Result<String> {
let models: Vec<_> = models
.iter()
Expand Down

0 comments on commit 84e9515

Please sign in to comment.