diff --git a/config.example.yaml b/config.example.yaml index fa97583e..db618177 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -117,6 +117,7 @@ clients: api_base: http://localhost:8080/v1 api_key: xxx # Optional chat_endpoint: /chat/completions # Optional + embeddings_endpoint: /embeddings # Optional models: - name: llama3.1 max_input_tokens: 128000 @@ -181,9 +182,9 @@ clients: api_base: https://api.groq.com/openai/v1 # See https://github.com/jmorganca/ollama - - type: ollama - api_base: http://localhost:11434 - api_auth: Basic xxx # optional + - type: openai-compatible + name: ollama + api_base: http://localhost:11434/v1 models: - name: llama3.1 max_input_tokens: 128000 @@ -252,9 +253,10 @@ clients: secret_key: xxxx # See https://help.aliyun.com/zh/dashscope/ - - type: qianwen + - type: openai-compatible + name: qianwen api_key: sk-xxx - api_base: https://dashscope.aliyuncs.com/api/v1 # Optional + api_base: https://dashscope.aliyuncs.com/compatible-mode/v1 # See https://platform.moonshot.cn/docs/intro - type: openai-compatible diff --git a/src/client/mod.rs b/src/client/mod.rs index 9d6dfb0e..8ab7c783 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -25,7 +25,6 @@ register_client!( (gemini, "gemini", GeminiConfig, GeminiClient), (claude, "claude", ClaudeConfig, ClaudeClient), (cohere, "cohere", CohereConfig, CohereClient), - (ollama, "ollama", OllamaConfig, OllamaClient), ( azure_openai, "azure-openai", @@ -37,10 +36,9 @@ register_client!( (cloudflare, "cloudflare", CloudflareConfig, CloudflareClient), (replicate, "replicate", ReplicateConfig, ReplicateClient), (ernie, "ernie", ErnieConfig, ErnieClient), - (qianwen, "qianwen", QianwenConfig, QianwenClient), ); -pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 16] = [ +pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 18] = [ ("ai21", "https://api.ai21.com/studio/v1"), ("deepinfra", "https://api.deepinfra.com/v1/openai"), ("deepseek", "https://api.deepseek.com"), @@ -53,7 +51,12 @@ pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 16] = [ ("moonshot", "https://api.moonshot.cn/v1"), ("openrouter", "https://openrouter.ai/api/v1"), ("octoai", "https://text.octoai.run/v1"), + ("ollama", "http://localhost:11434/v1"), ("perplexity", "https://api.perplexity.ai"), + ( + "qianwen", + "https://dashscope.aliyuncs.com/compatible-mode/v1", + ), ("together", "https://api.together.xyz/v1"), ("zhipuai", "https://open.bigmodel.cn/api/paas/v4"), ("voyageai", "https://api.voyageai.com/v1"), diff --git a/src/client/ollama.rs b/src/client/ollama.rs deleted file mode 100644 index 4c2f3445..00000000 --- a/src/client/ollama.rs +++ /dev/null @@ -1,291 +0,0 @@ -use super::*; - -use anyhow::{bail, Context, Result}; -use reqwest::RequestBuilder; -use serde::Deserialize; -use serde_json::{json, Value}; - -#[derive(Debug, Clone, Deserialize, Default)] -pub struct OllamaConfig { - pub name: Option, - pub api_base: Option, - pub api_auth: Option, - #[serde(default)] - pub models: Vec, - pub patch: Option, - pub extra: Option, -} - -impl OllamaClient { - config_get_fn!(api_base, get_api_base); - config_get_fn!(api_auth, get_api_auth); - - pub const PROMPTS: [PromptAction<'static>; 4] = [ - ("api_base", "API Base:", true, PromptKind::String), - ("api_auth", "API Auth:", false, PromptKind::String), - ("models[].name", "Model Name:", true, PromptKind::String), - ( - "models[].max_input_tokens", - "Max Input Tokens:", - false, - PromptKind::Integer, - ), - ]; -} - -impl_client_trait!( - OllamaClient, - ( - prepare_chat_completions, - chat_completions, - chat_completions_streaming - ), - (prepare_embeddings, embeddings), - (noop_prepare_rerank, noop_rerank), -); - -fn prepare_chat_completions( - self_: &OllamaClient, - data: ChatCompletionsData, -) -> Result { - let api_base = self_.get_api_base()?; - let api_auth = self_.get_api_auth().ok(); - - let url = format!("{api_base}/api/chat"); - - let body = build_chat_completions_body(data, &self_.model)?; - - let mut request_data = RequestData::new(url, body); - - if let Some(api_auth) = api_auth { - request_data.header("Authorization", api_auth) - } - - Ok(request_data) -} - -fn prepare_embeddings(self_: &OllamaClient, data: EmbeddingsData) -> Result { - let api_base = self_.get_api_base()?; - let api_auth = self_.get_api_auth().ok(); - - let url = format!("{api_base}/api/embed"); - - let body = json!({ - "model": self_.model.name(), - "input": data.texts, - }); - - let mut request_data = RequestData::new(url, body); - - if let Some(api_auth) = api_auth { - request_data.header("Authorization", api_auth) - } - - Ok(request_data) -} - -async fn chat_completions( - builder: RequestBuilder, - _model: &Model, -) -> Result { - let res = builder.send().await?; - let status = res.status(); - let data = res.json().await?; - if !status.is_success() { - catch_error(&data, status.as_u16())?; - } - debug!("non-stream-data: {data}"); - - extract_chat_completions(&data) -} - -async fn chat_completions_streaming( - builder: RequestBuilder, - handler: &mut SseHandler, - _model: &Model, -) -> Result<()> { - let res = builder.send().await?; - let status = res.status(); - if !status.is_success() { - let data = res.json().await?; - catch_error(&data, status.as_u16())?; - } else { - let handle = |message: &str| -> Result<()> { - let data: Value = serde_json::from_str(message)?; - debug!("stream-data: {data}"); - - if data["done"].is_boolean() { - if let Some(text) = data["message"]["content"].as_str() { - handler.text(text)?; - } - } else { - bail!("Invalid response data: {data}") - } - - Ok(()) - }; - - json_stream(res.bytes_stream(), handle).await?; - } - - Ok(()) -} - -async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { - let res = builder.send().await?; - let status = res.status(); - let data = res.json().await?; - if !status.is_success() { - catch_error(&data, status.as_u16())?; - } - let res_body: EmbeddingsResBody = - serde_json::from_value(data).context("Invalid embeddings data")?; - Ok(res_body.embeddings) -} - -#[derive(Deserialize)] -struct EmbeddingsResBody { - embeddings: Vec>, -} - -fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result { - let ChatCompletionsData { - messages, - temperature, - top_p, - functions, - stream, - } = data; - - let mut network_image_urls = vec![]; - - let messages: Vec = messages - .into_iter() - .flat_map(|message| { - let Message { role, content } = message; - match content { - MessageContent::Text(text) => vec![json!({ - "role": role, - "content": text, - })], - MessageContent::Array(list) => { - let mut content = vec![]; - let mut images = vec![]; - for item in list { - match item { - MessageContentPart::Text { text } => { - content.push(text); - } - MessageContentPart::ImageUrl { - image_url: ImageUrl { url }, - } => { - if let Some((_, data)) = url - .strip_prefix("data:") - .and_then(|v| v.split_once(";base64,")) - { - images.push(data.to_string()); - } else { - network_image_urls.push(url.clone()); - } - } - } - } - let content = content.join("\n\n"); - vec![json!({ "role": role, "content": content, "images": images })] - } - MessageContent::ToolResults((tool_results, text)) => { - let tool_calls: Vec<_> = tool_results.iter().map(|tool_result| { - json!({ - "function": { - "name": tool_result.call.name, - "arguments": tool_result.call.arguments, - }, - }) - }).collect(); - let mut messages = vec![ - json!({ "role": MessageRole::Assistant, "content": text, "tool_calls": tool_calls }) - ]; - for tool_result in tool_results { - messages.push( - json!({ - "role": "tool", - "content": tool_result.output.to_string(), - }) - ); - } - messages - }, - } - }) - .collect(); - - if !network_image_urls.is_empty() { - bail!( - "The model does not support network images: {:?}", - network_image_urls - ); - } - - let mut body = json!({ - "model": &model.name(), - "messages": messages, - "stream": stream, - "options": {}, - }); - - if let Some(v) = model.max_tokens_param() { - body["options"]["num_predict"] = v.into(); - } - if let Some(v) = temperature { - body["options"]["temperature"] = v.into(); - } - if let Some(v) = top_p { - body["options"]["top_p"] = v.into(); - } - if let Some(functions) = functions { - body["tools"] = functions - .iter() - .map(|v| { - json!({ - "type": "function", - "function": v, - }) - }) - .collect(); - } - - Ok(body) -} - -fn extract_chat_completions(data: &Value) -> Result { - let text = data["message"]["content"].as_str().unwrap_or_default(); - - let mut tool_calls = vec![]; - if let Some(calls) = data["message"]["tool_calls"].as_array() { - tool_calls = calls - .iter() - .filter_map(|call| { - if let (Some(name), arguments) = ( - call["function"]["name"].as_str(), - call["function"]["arguments"].clone(), - ) { - Some(ToolCall::new(name.to_string(), arguments, None)) - } else { - None - } - }) - .collect() - }; - - if text.is_empty() && tool_calls.is_empty() { - bail!("Invalid response data: {data}"); - } - let output = ChatCompletionsOutput { - text: text.to_string(), - tool_calls, - id: None, - input_tokens: data["prompt_eval_count"].as_u64(), - output_tokens: data["eval_count"].as_u64(), - }; - Ok(output) -} diff --git a/src/client/openai.rs b/src/client/openai.rs index ed577bc1..e47ad4c0 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -111,7 +111,7 @@ pub async fn openai_chat_completions_streaming( handler.tool_call(ToolCall::new( function_name.clone(), json!(function_arguments), - Some(function_id.clone()), + normalize_function_id(&function_id), ))?; } return Ok(true); @@ -131,7 +131,7 @@ pub async fn openai_chat_completions_streaming( handler.tool_call(ToolCall::new( function_name.clone(), json!(function_arguments), - Some(function_id.clone()), + normalize_function_id(&function_id), ))?; } function_name.clear(); @@ -140,7 +140,11 @@ pub async fn openai_chat_completions_streaming( function_index = index; } if let Some(name) = function.get("name").and_then(|v| v.as_str()) { - function_name = name.to_string(); + if name.starts_with(&function_name) { + function_name = name.to_string(); + } else { + function_name.push_str(name); + } } if let Some(arguments) = function.get("arguments").and_then(|v| v.as_str()) { function_arguments.push_str(arguments); @@ -196,30 +200,57 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod let Message { role, content } = message; match content { MessageContent::ToolResults((tool_results, text)) => { - let tool_calls: Vec<_> = tool_results.iter().map(|tool_result| { - json!({ - "id": tool_result.call.id, - "type": "function", - "function": { - "name": tool_result.call.name, - "arguments": tool_result.call.arguments, - }, - }) - }).collect(); - let text = if text.is_empty() { Value::Null } else { text.into() }; - let mut messages = vec![ - json!({ "role": MessageRole::Assistant, "content": text, "tool_calls": tool_calls }) - ]; - for tool_result in tool_results { - messages.push( + if let Some(true) = tool_results.first().map(|v| v.call.id.is_some()) { + let tool_calls: Vec<_> = tool_results.iter().map(|tool_result| { json!({ - "role": "tool", - "content": tool_result.output.to_string(), - "tool_call_id": tool_result.call.id, + "id": tool_result.call.id, + "type": "function", + "function": { + "name": tool_result.call.name, + "arguments": tool_result.call.arguments, + }, }) - ); + }).collect(); + let text = if text.is_empty() { Value::Null } else { text.into() }; + let mut messages = vec![ + json!({ "role": MessageRole::Assistant, "content": text, "tool_calls": tool_calls }) + ]; + for tool_result in tool_results { + messages.push( + json!({ + "role": "tool", + "content": tool_result.output.to_string(), + "tool_call_id": tool_result.call.id, + }) + ); + } + messages + } else { + tool_results.into_iter().flat_map(|tool_result| { + vec![ + json!({ + "role": MessageRole::Assistant, + "content": null, + "tool_calls": [ + { + "id": tool_result.call.id, + "type": "function", + "function": { + "name": tool_result.call.name, + "arguments": tool_result.call.arguments, + }, + } + ] + }), + json!({ + "role": "tool", + "content": tool_result.output.to_string(), + "tool_call_id": tool_result.call.id, + }) + ] + + }).collect() } - messages }, _ => vec![json!({ "role": role, "content": content })] } @@ -303,3 +334,11 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result Option { + if value.is_empty() { + None + } else { + Some(value.to_string()) + } +} diff --git a/src/client/openai_compatible.rs b/src/client/openai_compatible.rs index 2bde884e..da261b3f 100644 --- a/src/client/openai_compatible.rs +++ b/src/client/openai_compatible.rs @@ -12,6 +12,7 @@ pub struct OpenAICompatibleConfig { pub api_base: Option, pub api_key: Option, pub chat_endpoint: Option, + pub embeddings_endpoint: Option, #[serde(default)] pub models: Vec, pub patch: Option, @@ -82,7 +83,18 @@ fn prepare_embeddings(self_: &OpenAICompatibleClient, data: EmbeddingsData) -> R let api_key = self_.get_api_key().ok(); let api_base = get_api_base_ext(self_)?; - let url = format!("{api_base}/embeddings"); + let embeddings_endpoint = match self_.config.embeddings_endpoint.clone() { + Some(v) => { + if v.starts_with('/') { + v + } else { + format!("/{}", v) + } + } + None => "/embeddings".into(), + }; + + let url = format!("{api_base}{embeddings_endpoint}"); let body = openai_build_embeddings_body(data, &self_.model); diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs deleted file mode 100644 index 38534d86..00000000 --- a/src/client/qianwen.rs +++ /dev/null @@ -1,509 +0,0 @@ -use super::*; - -use crate::utils::{base64_decode, sha256}; - -use anyhow::{anyhow, bail, Context, Result}; -use reqwest::{ - multipart::{Form, Part}, - Client as ReqwestClient, RequestBuilder, -}; -use serde::Deserialize; -use serde_json::{json, Value}; -use std::borrow::BorrowMut; - -const API_BASE: &str = "https://dashscope.aliyuncs.com/api/v1"; - -const CHAT_COMPLETIONS_ENDPOINT: &str = "/services/aigc/text-generation/generation"; - -const CHAT_COMPLETIONS_VL_ENDPOINT: &str = "/services/aigc/multimodal-generation/generation"; - -const EMBEDDINGS_ENDPOINT: &str = "/services/embeddings/text-embedding/text-embedding"; - -#[derive(Debug, Clone, Deserialize, Default)] -pub struct QianwenConfig { - pub name: Option, - pub api_key: Option, - pub api_base: Option, - #[serde(default)] - pub models: Vec, - pub patch: Option, - pub extra: Option, -} - -impl QianwenClient { - config_get_fn!(api_key, get_api_key); - config_get_fn!(api_base, get_api_base); - - pub const PROMPTS: [PromptAction<'static>; 1] = - [("api_key", "API Key:", true, PromptKind::String)]; -} - -#[async_trait::async_trait] -impl Client for QianwenClient { - client_common_fns!(); - - async fn chat_completions_inner( - &self, - client: &ReqwestClient, - mut data: ChatCompletionsData, - ) -> Result { - let api_key = self.get_api_key()?; - patch_messages(self.model.name(), &api_key, &mut data.messages).await?; - let request_data = prepare_chat_completions(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - chat_completions(builder, &self.model).await - } - - async fn chat_completions_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut SseHandler, - mut data: ChatCompletionsData, - ) -> Result<()> { - let api_key = self.get_api_key()?; - patch_messages(self.model.name(), &api_key, &mut data.messages).await?; - let request_data = prepare_chat_completions(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - chat_completions_streaming(builder, handler, &self.model).await - } - - async fn embeddings_inner( - &self, - client: &ReqwestClient, - data: EmbeddingsData, - ) -> Result>> { - let request_data = prepare_embeddings(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::Embeddings); - embeddings(builder, &self.model).await - } -} - -fn prepare_chat_completions( - self_: &QianwenClient, - data: ChatCompletionsData, -) -> Result { - let api_key = self_.get_api_key()?; - let api_base = self_ - .get_api_base() - .unwrap_or_else(|_| API_BASE.to_string()); - - let stream = data.stream; - - let url = match self_.model().supports_vision() { - true => format!( - "{}{CHAT_COMPLETIONS_VL_ENDPOINT}", - api_base.trim_end_matches('/'), - ), - false => format!( - "{}{CHAT_COMPLETIONS_ENDPOINT}", - api_base.trim_end_matches('/'), - ), - }; - - let (body, has_upload) = build_chat_completions_body(data, &self_.model)?; - - let mut request_data = RequestData::new(url, body); - - request_data.bearer_auth(api_key); - - if stream { - request_data.header("X-DashScope-SSE", "enable"); - } - if has_upload { - request_data.header("X-DashScope-OssResourceResolve", "enable"); - } - - Ok(request_data) -} - -fn prepare_embeddings(self_: &QianwenClient, data: EmbeddingsData) -> Result { - let api_key = self_.get_api_key()?; - let api_base = self_ - .get_api_base() - .unwrap_or_else(|_| API_BASE.to_string()); - - let url = format!("{}{EMBEDDINGS_ENDPOINT}", api_base.trim_end_matches('/'),); - - let text_type = match data.query { - true => "query", - false => "document", - }; - - let body = json!({ - "model": self_.model.name(), - "input": { - "texts": data.texts, - }, - "parameters": { - "text_type": text_type, - } - }); - - let mut request_data = RequestData::new(url, body); - - request_data.bearer_auth(api_key); - - Ok(request_data) -} - -async fn chat_completions(builder: RequestBuilder, model: &Model) -> Result { - let data: Value = builder.send().await?.json().await?; - maybe_catch_error(&data)?; - - debug!("non-stream-data: {data}"); - extract_chat_completions_text(&data, model) -} - -async fn chat_completions_streaming( - builder: RequestBuilder, - handler: &mut SseHandler, - model: &Model, -) -> Result<()> { - let model_name = model.name(); - let mut prev_text = String::new(); - let handle = |message: SseMmessage| -> Result { - let data: Value = serde_json::from_str(&message.data)?; - maybe_catch_error(&data)?; - debug!("stream-data: {data}"); - if model_name == "qwen-long" { - if let Some(text) = data["output"]["choices"][0]["message"]["content"].as_str() { - handler.text(text)?; - } - } else if model.supports_vision() { - if let Some(text) = - data["output"]["choices"][0]["message"]["content"][0]["text"].as_str() - { - handler.text(text)?; - } - } else if let Some(text) = data["output"]["text"].as_str() { - if let Some(pos) = text.rfind("✿FUNCTION") { - if pos > prev_text.len() { - let delta_text = &text[prev_text.len()..pos]; - if delta_text != ": \n" { - handler.text(delta_text)?; - } - } - prev_text = text.to_string(); - if let Some((name, arguments)) = parse_tool_call(&text[pos..]) { - let arguments: Value = arguments - .parse() - .with_context(|| format!("Invalid function call {name} {arguments}"))?; - handler.tool_call(ToolCall::new(name.to_string(), arguments, None))?; - } - } else { - let mut delta_text = &text[prev_text.len()..]; - if prev_text.is_empty() && delta_text.starts_with(": ") { - delta_text = &delta_text[2..]; - } - prev_text = text.to_string(); - handler.text(delta_text)?; - } - } - Ok(false) - }; - - sse_stream(builder, handle).await -} - -fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<(Value, bool)> { - let ChatCompletionsData { - messages, - temperature, - top_p, - functions, - stream, - } = data; - - let mut has_upload = false; - let input = if model.supports_vision() { - let messages: Vec = messages - .into_iter() - .map(|message| { - let role = message.role; - let content = match message.content { - MessageContent::Text(text) => vec![json!({"text": text})], - MessageContent::Array(list) => list - .into_iter() - .map(|item| match item { - MessageContentPart::Text { text } => json!({"text": text}), - MessageContentPart::ImageUrl { - image_url: ImageUrl { url }, - } => { - if url.starts_with("oss:") { - has_upload = true; - } - json!({"image": url}) - } - }) - .collect(), - MessageContent::ToolResults(_) => { - vec![] - } - }; - json!({ "role": role, "content": content }) - }) - .collect(); - - json!({ - "messages": messages, - }) - } else { - let messages: Vec = - messages - .into_iter() - .flat_map(|message| { - let role = message.role; - match message.content { - MessageContent::Text(text) => vec![json!({ "role": role, "content": text })], - MessageContent::Array(list) => { - let parts: Vec<_> = list - .into_iter() - .map(|item| match item { - MessageContentPart::Text { text } => json!({"text": text}), - MessageContentPart::ImageUrl { - image_url: ImageUrl { url }, - } => { - if url.starts_with("oss:") { - has_upload = true; - } - json!({"image": url}) - } - }) - .collect(); - vec![json!({ "role": role, "content": parts })] - } - MessageContent::ToolResults((tool_results, _)) => { - tool_results.into_iter().flat_map(|tool_result| vec![ - json!({ - "role": MessageRole::Assistant, - "content": "", - "tool_calls": vec![ - json!({ - "type": "function", - "function": { - "name": tool_result.call.name, - "arguments": tool_result.call.arguments.to_string(), - }, - }) - ], - }), - json!({ - "role": "tool", - "content": tool_result.output.to_string(), - "name": tool_result.call.name, - }), - ]).collect() - } - } - }) - .collect(); - json!({ - "messages": messages, - }) - }; - - let mut parameters = json!({}); - - if stream && (model.name() == "qwen-long" || model.supports_vision()) { - parameters["incremental_output"] = true.into(); - } - - if let Some(v) = model.max_tokens_param() { - parameters["max_tokens"] = v.into(); - } - if let Some(v) = temperature { - parameters["temperature"] = v.into(); - } - if let Some(v) = top_p { - parameters["top_p"] = v.into(); - } - - if let Some(functions) = functions { - parameters["tools"] = functions - .iter() - .map(|v| { - json!({ - "type": "function", - "function": v, - }) - }) - .collect(); - } - - let body = json!({ - "model": &model.name(), - "input": input, - "parameters": parameters - }); - - Ok((body, has_upload)) -} - -async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { - let data: Value = builder.send().await?.json().await?; - maybe_catch_error(&data)?; - let res_body: EmbeddingsResBody = - serde_json::from_value(data).context("Invalid embeddings data")?; - let output = res_body - .output - .embeddings - .into_iter() - .map(|v| v.embedding) - .collect(); - Ok(output) -} - -#[derive(Deserialize)] -struct EmbeddingsResBody { - output: EmbeddingsResBodyOutput, -} - -#[derive(Deserialize)] -struct EmbeddingsResBodyOutput { - embeddings: Vec, -} - -#[derive(Deserialize)] -struct EmbeddingsResBodyOutputEmbedding { - embedding: Vec, -} - -fn extract_chat_completions_text(data: &Value, model: &Model) -> Result { - let err = || anyhow!("Invalid response data: {data}"); - let mut tool_calls = vec![]; - let text = if model.name() == "qwen-long" { - data["output"]["choices"][0]["message"]["content"] - .as_str() - .ok_or_else(err)? - } else if model.supports_vision() { - data["output"]["choices"][0]["message"]["content"][0]["text"] - .as_str() - .ok_or_else(err)? - } else { - let text = data["output"]["text"].as_str().ok_or_else(err)?; - match parse_tool_call(text) { - Some((name, arguments)) => { - let arguments: Value = arguments - .parse() - .with_context(|| format!("Invalid function call {name} {arguments}"))?; - tool_calls.push(ToolCall::new(name.to_string(), arguments, None)); - "" - } - None => text, - } - }; - let output = ChatCompletionsOutput { - text: text.to_string(), - tool_calls, - id: data["request_id"].as_str().map(|v| v.to_string()), - input_tokens: data["usage"]["input_tokens"].as_u64(), - output_tokens: data["usage"]["output_tokens"].as_u64(), - }; - - Ok(output) -} - -/// Patch messages, upload embedded images to oss -async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec) -> Result<()> { - for message in messages { - if let MessageContent::Array(list) = message.content.borrow_mut() { - for item in list { - if let MessageContentPart::ImageUrl { - image_url: ImageUrl { url }, - } = item - { - if url.starts_with("data:") { - *url = upload(model, api_key, url) - .await - .with_context(|| "Failed to upload embedded image to oss")?; - } - } - } - } - } - Ok(()) -} - -#[derive(Debug, Deserialize)] -struct Policy { - data: PolicyData, -} - -#[derive(Debug, Deserialize)] -struct PolicyData { - policy: String, - signature: String, - upload_dir: String, - upload_host: String, - oss_access_key_id: String, - x_oss_object_acl: String, - x_oss_forbid_overwrite: String, -} - -/// Upload image to dashscope -async fn upload(model: &str, api_key: &str, url: &str) -> Result { - let (mime_type, data) = url - .strip_prefix("data:") - .and_then(|v| v.split_once(";base64,")) - .ok_or_else(|| anyhow!("Invalid image url"))?; - let mut name = sha256(data); - if let Some(ext) = mime_type.strip_prefix("image/") { - name.push('.'); - name.push_str(ext); - } - let data = base64_decode(data)?; - - let client = reqwest::Client::new(); - let policy: Policy = client - .get(format!( - "https://dashscope.aliyuncs.com/api/v1/uploads?action=getPolicy&model={model}" - )) - .header("Authorization", format!("Bearer {api_key}")) - .send() - .await? - .json() - .await?; - let PolicyData { - policy, - signature, - upload_dir, - upload_host, - oss_access_key_id, - x_oss_object_acl, - x_oss_forbid_overwrite, - .. - } = policy.data; - - let key = format!("{upload_dir}/{name}"); - let file = Part::bytes(data).file_name(name).mime_str(mime_type)?; - let form = Form::new() - .text("OSSAccessKeyId", oss_access_key_id) - .text("Signature", signature) - .text("policy", policy) - .text("key", key.clone()) - .text("x-oss-object-acl", x_oss_object_acl) - .text("x-oss-forbid-overwrite", x_oss_forbid_overwrite) - .text("success_action_status", "200") - .text("x-oss-content-type", mime_type.to_string()) - .part("file", file); - - let res = client.post(upload_host).multipart(form).send().await?; - - let status = res.status(); - if !status.is_success() { - let text = res.text().await?; - bail!("Invalid response data: {text} (status: {status})") - } - Ok(format!("oss://{key}")) -} - -fn parse_tool_call(text: &str) -> Option<(&str, &str)> { - let function_symbol = "✿FUNCTION✿: "; - let result_symbol = "\n✿RESULT✿: "; - let args_symbol = "\n✿ARGS✿: "; - let start = text.find(function_symbol)? + function_symbol.len(); - let text = &text[start..]; - let end = text.find(result_symbol)?; - let text = &text[..end]; - text.split_once(args_symbol) -}