From 51566431bc166f5b54ccdc8c5618e831a4161143 Mon Sep 17 00:00:00 2001
From: sigoden
Date: Mon, 9 Sep 2024 19:50:31 +0800
Subject: [PATCH] feat: proxy rerank api

---
 src/serve.rs | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 64 insertions(+), 2 deletions(-)

diff --git a/src/serve.rs b/src/serve.rs
index 9e157b60..ab8e654d 100644
--- a/src/serve.rs
+++ b/src/serve.rs
@@ -56,6 +56,7 @@ pub async fn run(config: GlobalConfig, addr: Option) -> Result<()> {
     let stop_server = server.run(listener).await?;
     println!("Chat Completions API: http://{addr}/v1/chat/completions");
     println!("Embeddings API: http://{addr}/v1/embeddings");
+    println!("Rerank API: http://{addr}/v1/rerank");
     println!("LLM Playground: http://{addr}/playground");
     println!("LLM Arena: http://{addr}/arena?num=2");
     shutdown_signal().await;
@@ -158,6 +159,8 @@ impl Server {
             self.chat_completions(req).await
         } else if path == "/v1/embeddings" {
             self.embeddings(req).await
+        } else if path == "/v1/rerank" {
+            self.rerank(req).await
         } else if path == "/v1/models" {
             self.list_models()
         } else if path == "/v1/roles" {
@@ -498,6 +501,57 @@ impl Server {
             .body(Full::new(Bytes::from(output.to_string())).boxed())?;
         Ok(res)
     }
+
+    async fn rerank(&self, req: hyper::Request) -> Result {
+        let req_body = req.collect().await?.to_bytes();
+        let req_body: Value = serde_json::from_slice(&req_body)
+            .map_err(|err| anyhow!("Invalid request json, {err}"))?;
+
+        debug!("rerank request: {req_body}");
+        let req_body = serde_json::from_value(req_body)
+            .map_err(|err| anyhow!("Invalid request body, {err}"))?;
+
+        let RerankReqBody {
+            model: reranker_model_id,
+            documents,
+            query,
+            top_n,
+        } = req_body;
+
+        let top_n = top_n.unwrap_or(documents.len());
+
+        let config = Arc::new(RwLock::new(self.config.clone()));
+
+        let reranker_model = Model::retrieve_embedding(&config.read(), &reranker_model_id)?;
+
+        let client = init_client(&config, Some(reranker_model))?;
+        let data = client
+            .rerank(RerankData {
+                query,
+                documents: documents.clone(),
+                top_n,
+            })
+            .await?;
+
+        let results: Vec<_> = data
+            .into_iter()
+            .map(|v| {
+                json!({
+                    "index": v.index,
+                    "relevance_score": v.relevance_score,
+                    "document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(),
+                })
+            })
+            .collect();
+        let output = json!({
+            "id": uuid::Uuid::new_v4().to_string(),
+            "results": results,
+        });
+        let res = Response::builder()
+            .header("Content-Type", "application/json")
+            .body(Full::new(Bytes::from(output.to_string())).boxed())?;
+        Ok(res)
+    }
 }
 
 #[derive(Debug, Deserialize)]
@@ -520,8 +574,8 @@ struct ChatCompletionsReqBody {
 
 #[derive(Debug, Deserialize)]
 struct EmbeddingsReqBody {
-    pub input: EmbeddingsReqBodyInput,
-    pub model: String,
+    input: EmbeddingsReqBodyInput,
+    model: String,
 }
 
 #[derive(Debug, Deserialize)]
@@ -531,6 +585,14 @@ enum EmbeddingsReqBodyInput {
     Multiple(Vec),
 }
 
+#[derive(Debug, Deserialize)]
+struct RerankReqBody {
+    documents: Vec,
+    query: String,
+    model: String,
+    top_n: Option,
+}
+
 #[derive(Debug)]
 enum ResEvent {
     First(Option),