diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 1750776b8..35a66478e 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -7,13 +7,14 @@ use std::time::Duration; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat}; -use crate::message::Message; +use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; -use mcp_core::tool::Tool; +use mcp_core::tool::{Tool, ToolCall}; pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet"; pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic"; +pub const OPENROUTER_MODEL_PREFIX_DEEPSEEK: &str = "deepseek-r1"; // OpenRouter can run many models, we suggest the default pub const OPENROUTER_KNOWN_MODELS: &[&str] = &[OPENROUTER_DEFAULT_MODEL]; @@ -77,58 +78,44 @@ impl OpenRouterProvider { } /// Update the request when using anthropic model. -/// For anthropic model, we can enable prompt caching to save cost. Since openrouter is the OpenAI compatible -/// endpoint, we need to modify the open ai request to have anthropic cache control field. +/// For older anthropic models we enabled prompt caching, but newer ones (Claude-3) don't support it. fn update_request_for_anthropic(original_payload: &Value) -> Value { let mut payload = original_payload.clone(); + // Only add cache control for non-Claude-3 models + if !payload + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or("") + .contains("claude-3") + { + if let Some(messages_spec) = payload + .as_object_mut() + .and_then(|obj| obj.get_mut("messages")) + .and_then(|messages| messages.as_array_mut()) + {} + } + payload +} + +fn update_request_for_deepseek(original_payload: &Value) -> Value { + let mut payload = original_payload.clone(); + + // Extract tools before removing them from the payload + let tools = payload.get("tools").cloned(); + if let Some(messages_spec) = payload .as_object_mut() .and_then(|obj| obj.get_mut("messages")) .and_then(|messages| messages.as_array_mut()) - { - // Add "cache_control" to the last and second-to-last "user" messages. - // During each turn, we mark the final message with cache_control so the conversation can be - // incrementally cached. The second-to-last user message is also marked for caching with the - // cache_control parameter, so that this checkpoint can read from the previous cache. - let mut user_count = 0; - for message in messages_spec.iter_mut().rev() { - if message.get("role") == Some(&json!("user")) { - if let Some(content) = message.get_mut("content") { - if let Some(content_str) = content.as_str() { - *content = json!([{ - "type": "text", - "text": content_str, - "cache_control": { "type": "ephemeral" } - }]); - } - } - user_count += 1; - if user_count >= 2 { - break; - } - } - } + {} - // Update the system message to have cache_control field. - if let Some(system_message) = messages_spec - .iter_mut() - .find(|msg| msg.get("role") == Some(&json!("system"))) - { - if let Some(content) = system_message.get_mut("content") { - if let Some(content_str) = content.as_str() { - *system_message = json!({ - "role": "system", - "content": [{ - "type": "text", - "text": content_str, - "cache_control": { "type": "ephemeral" } - }] - }); - } - } - } + // Remove any tools/function calling capabilities from the request + if let Some(obj) = payload.as_object_mut() { + obj.remove("tools"); + obj.remove("tool_choice"); } + payload } @@ -146,6 +133,7 @@ fn create_request_based_on_model( &super::utils::ImageFormat::OpenAi, )?; + // Check for Anthropic models if model_config .model_name .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC) @@ -153,6 +141,14 @@ fn create_request_based_on_model( payload = update_request_for_anthropic(&payload); } + // Check for DeepSeek models + if model_config + .model_name + .contains(OPENROUTER_MODEL_PREFIX_DEEPSEEK) + { + payload = update_request_for_deepseek(&payload); + } + Ok(payload) } @@ -201,8 +197,111 @@ impl Provider for OpenRouterProvider { // Make request let response = self.post(payload.clone()).await?; - // Parse response - let message = response_to_message(response.clone())?; + // Debug log the response structure + println!( + "OpenRouter response: {}", + serde_json::to_string_pretty(&response).unwrap_or_default() + ); + + // First try to parse as OpenAI format + let mut message = response_to_message(response.clone())?; + + // If no tool calls were found in OpenAI format, check for XML format + if !message.is_tool_call() { + if let Some(MessageContent::Text(text_content)) = message.content.first() { + let content = &text_content.text; + if let Some(calls_start) = content.find("") { + if let Some(calls_end) = content.find("") { + let calls_text = &content[calls_start..=calls_end + 15]; + + // Extract the invoke block + if let Some(invoke_start) = calls_text.find("") { + let invoke_text = + &calls_text[invoke_start..invoke_start + invoke_end + 9]; + + // Parse name and parameters + if let Some(name_start) = invoke_text.find("name=\"") { + if let Some(name_end) = invoke_text[name_start + 6..].find("\"") + { + let name = invoke_text + [name_start + 6..name_start + 6 + name_end] + .to_string(); + + // Build parameters map + let mut parameters = serde_json::Map::new(); + let mut param_pos = 0; + while let Some(param_start) = + invoke_text[param_pos..].find("") + { + let param_text = &invoke_text[param_pos + + param_start + ..param_pos + param_start + param_end + 11]; + + if let Some(param_name_start) = + param_text.find("name=\"") + { + if let Some(param_name_end) = param_text + [param_name_start + 6..] + .find("\"") + { + let param_name = ¶m_text + [param_name_start + 6 + ..param_name_start + + 6 + + param_name_end]; + + if let Some(value_start) = + param_text.find(">") + { + if let Some(value_end) = param_text + [value_start + 1..] + .find("<") + { + let param_value = ¶m_text + [value_start + 1 + ..value_start + + 1 + + value_end]; + parameters.insert( + param_name.to_string(), + Value::String( + param_value.to_string(), + ), + ); + } + } + } + } + param_pos += param_start + param_end + 11; + } else { + break; + } + } + + // Create tool request + message.content.clear(); + message.content.push(MessageContent::tool_request( + "1", + Ok(ToolCall { + name, + arguments: serde_json::to_value(parameters) + .unwrap_or_default(), + }), + )); + } + } + } + } + } + } + } + } + let usage = match get_usage(&response) { Ok(usage) => usage, Err(ProviderError::UsageError(e)) => {