Skip to content

Commit

Permalink
feat: add AICHAT_EMBEDDINGS_RETRY_LIMIT (#882)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Sep 24, 2024
1 parent 3f2b946 commit 4686c47
Showing 1 changed file with 9 additions and 26 deletions.
35 changes: 9 additions & 26 deletions src/rag/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@ use parking_lot::RwLock;
use path_absolutize::Absolutize;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{collections::HashMap, fmt::Debug, fs, path::Path, time::Duration};
use std::{collections::HashMap, env, fmt::Debug, fs, path::Path, time::Duration};
use tokio::time::sleep;

const EMBEDDING_RETRY_LIMIT: usize = 3;
const RERANK_RETRY_LIMIT: usize = 2;

pub struct Rag {
config: GlobalConfig,
name: String,
Expand Down Expand Up @@ -487,23 +484,7 @@ impl Rag {
}
}
let data = RerankData::new(query.to_string(), documents, top_k);
let mut retry = 0;
let list = loop {
retry += 1;
match client.rerank(&data).await {
Ok(result) => break result,
Err(e) if retry < RERANK_RETRY_LIMIT => {
debug!("retry {} failed: {}", retry, e);
sleep(Duration::from_secs(retry as _)).await;
continue;
}
Err(e) => {
return Err(e).with_context(|| {
format!("Failed to rerank after {RERANK_RETRY_LIMIT} attempts")
})?
}
}
};
let list = client.rerank(&data).await.context("Failed to rerank")?;
let ids: Vec<_> = list
.into_iter()
.take(top_k)
Expand Down Expand Up @@ -598,6 +579,10 @@ impl Rag {
let mut output = vec![];
let batch_chunks = texts.chunks(batch_size.max(1));
let batch_chunks_len = batch_chunks.len();
let retry_limit = env::var(get_env_name("embeddings_retry_limit"))
.ok()
.and_then(|v| v.parse::<u32>().ok())
.unwrap_or(2);
for (index, texts) in batch_chunks.enumerate() {
progress(
&spinner,
Expand All @@ -612,16 +597,14 @@ impl Rag {
retry += 1;
match embedding_client.embeddings(&chunk_data).await {
Ok(v) => break v,
Err(e) if retry < EMBEDDING_RETRY_LIMIT => {
Err(e) if retry < retry_limit => {
debug!("retry {} failed: {}", retry, e);
sleep(Duration::from_secs(retry as _)).await;
sleep(Duration::from_secs(2u64.pow(retry - 1))).await;
continue;
}
Err(e) => {
return Err(e).with_context(|| {
format!(
"Failed to create embedding after {EMBEDDING_RETRY_LIMIT} attempts"
)
format!("Failed to create embedding after {retry_limit} attempts")
})?
}
}
Expand Down

0 comments on commit 4686c47

Please sign in to comment.