Skip to content

Commit

Permalink
feat: support RAG-scoped rag_top_k and rag_reranker_model (#847)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Sep 8, 2024
1 parent 5adaa86 commit 89554e0
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 51 deletions.
104 changes: 71 additions & 33 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,10 @@ impl Config {
.wrap
.clone()
.map_or_else(|| String::from("no"), |v| v.to_string());
let (rag_reranker_model, rag_top_k) = match self.rag.as_ref() {
Some(rag) => rag.get_config(),
None => (self.rag_reranker_model.clone(), self.rag_top_k),
};
let role = self.extract_role();
let mut items = vec![
("model", role.model().id()),
Expand All @@ -535,9 +539,9 @@ impl Config {
("use_tools", format_option_value(&role.use_tools())),
(
"rag_reranker_model",
format_option_value(&self.rag_reranker_model),
format_option_value(&rag_reranker_model),
),
("rag_top_k", self.rag_top_k.to_string()),
("rag_top_k", rag_top_k.to_string()),
("highlight", self.highlight.to_string()),
("light_theme", self.light_theme.to_string()),
("config_file", display_path(&Self::config_file()?)),
Expand All @@ -559,7 +563,7 @@ impl Config {
Ok(output)
}

pub fn update(&mut self, data: &str) -> Result<()> {
pub fn update(config: &GlobalConfig, data: &str) -> Result<()> {
let parts: Vec<&str> = data.split_whitespace().collect();
if parts.len() != 2 {
bail!("Usage: .set <key> <value>. If value is null, unset key.");
Expand All @@ -569,62 +573,58 @@ impl Config {
match key {
"max_output_tokens" => {
let value = parse_value(value)?;
self.set_max_output_tokens(value);
config.write().set_max_output_tokens(value);
}
"temperature" => {
let value = parse_value(value)?;
self.set_temperature(value);
config.write().set_temperature(value);
}
"top_p" => {
let value = parse_value(value)?;
self.set_top_p(value);
config.write().set_top_p(value);
}
"dry_run" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.dry_run = value;
config.write().dry_run = value;
}
"stream" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.stream = value;
config.write().stream = value;
}
"save" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.save = value;
config.write().save = value;
}
"rag_reranker_model" => {
self.rag_reranker_model = if value == "null" {
None
} else {
Some(value.to_string())
}
let value = parse_value(value)?;
Self::set_rag_reranker_model(config, value)?;
}
"rag_top_k" => {
if let Some(value) = parse_value(value)? {
self.rag_top_k = value;
}
let value = value.parse().with_context(|| "Invalid value")?;
Self::set_rag_top_k(config, value)?;
}
"function_calling" => {
let value = value.parse().with_context(|| "Invalid value")?;
if value && self.functions.is_empty() {
if value && config.write().functions.is_empty() {
bail!("Function calling cannot be enabled because no functions are installed.")
}
self.function_calling = value;
config.write().function_calling = value;
}
"use_tools" => {
let value = parse_value(value)?;
self.set_use_tools(value);
config.write().set_use_tools(value);
}
"save_session" => {
let value = parse_value(value)?;
self.set_save_session(value);
config.write().set_save_session(value);
}
"compress_threshold" => {
let value = parse_value(value)?;
self.set_compress_threshold(value);
config.write().set_compress_threshold(value);
}
"highlight" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.highlight = value;
config.write().highlight = value;
}
_ => bail!("Unknown key `{key}`"),
}
Expand Down Expand Up @@ -668,6 +668,33 @@ impl Config {
}
}

pub fn set_rag_reranker_model(config: &GlobalConfig, value: Option<String>) -> Result<()> {
if let Some(id) = &value {
Model::retrieve_reranker(&config.read(), id)?;
}
let has_rag = config.read().rag.is_some();
match has_rag {
true => update_rag(config, |rag| {
rag.set_reranker_model(value)?;
Ok(())
})?,
false => config.write().rag_reranker_model = value,
}
Ok(())
}

pub fn set_rag_top_k(config: &GlobalConfig, value: usize) -> Result<()> {
let has_rag = config.read().rag.is_some();
match has_rag {
true => update_rag(config, |rag| {
rag.set_top_k(value)?;
Ok(())
})?,
false => config.write().rag_top_k = value,
}
Ok(())
}

pub fn set_wrap(&mut self, value: &str) -> Result<()> {
if value == "no" {
self.wrap = None;
Expand Down Expand Up @@ -1090,13 +1117,11 @@ impl Config {
}

pub async fn rebuild_rag(config: &GlobalConfig, abort_signal: AbortSignal) -> Result<()> {
let rag_name = match config.read().rag.clone() {
Some(v) => v.name().to_string(),
let mut rag = match config.read().rag.clone() {
Some(v) => v.as_ref().clone(),
None => bail!("No RAG"),
};
let rag_path = config.read().rag_file(&rag_name)?;
let mut rag = Rag::load(config, &rag_name, &rag_path)?;
rag.rebuild(config, &rag_path, abort_signal).await?;
rag.rebuild(config, abort_signal).await?;
config.write().rag = Some(Arc::new(rag));
Ok(())
}
Expand All @@ -1120,20 +1145,20 @@ impl Config {
text: &str,
abort_signal: AbortSignal,
) -> Result<String> {
let (top_k, min_score_vector_search, min_score_keyword_search) = {
let (reranker_model, top_k) = rag.get_config();
let (min_score_vector_search, min_score_keyword_search, rag_min_score_rerank) = {
let config = config.read();
(
config.rag_top_k,
config.rag_min_score_vector_search,
config.rag_min_score_keyword_search,
config.rag_min_score_rerank,
)
};
let rerank = match config.read().rag_reranker_model.clone() {
let rerank = match reranker_model {
Some(reranker_model_id) => {
let min_score = config.read().rag_min_score_rerank;
let rerank_model = Model::retrieve_reranker(&config.read(), &reranker_model_id)?;
let rerank_client = init_client(config, Some(rerank_model))?;
Some((rerank_client, min_score))
Some((rerank_client, rag_min_score_rerank))
}
None => None,
};
Expand Down Expand Up @@ -2056,3 +2081,16 @@ fn complete_option_bool(value: Option<bool>) -> Vec<String> {
None => vec!["true".to_string(), "false".to_string()],
}
}

fn update_rag<F>(config: &GlobalConfig, f: F) -> Result<()>
where
F: FnOnce(&mut Rag) -> Result<()>,
{
let mut rag = match config.read().rag.clone() {
Some(v) => v.as_ref().clone(),
None => bail!("No RAG"),
};
f(&mut rag)?;
config.write().rag = Some(Arc::new(rag));
Ok(())
}
Loading

0 comments on commit 89554e0

Please sign in to comment.