feat: deepseek r1 alternative tool calling format #975

Closed · wants to merge 6 commits into from

Changes from all commits
193 changes: 146 additions & 47 deletions crates/goose/src/providers/openrouter.rs
@@ -7,13 +7,14 @@ use std::time::Duration;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat};
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use mcp_core::tool::{Tool, ToolCall};

pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet";
pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic";
pub const OPENROUTER_MODEL_PREFIX_DEEPSEEK: &str = "deepseek-r1";

// OpenRouter can run many models; we suggest the default
pub const OPENROUTER_KNOWN_MODELS: &[&str] = &[OPENROUTER_DEFAULT_MODEL];
@@ -77,58 +78,44 @@ impl OpenRouterProvider {
}

/// Update the request when using an anthropic model.
/// For older anthropic models we enabled prompt caching to save cost, but newer ones (Claude-3)
/// don't support it, so the cache_control fields are only added for non-Claude-3 models.
fn update_request_for_anthropic(original_payload: &Value) -> Value {
    let mut payload = original_payload.clone();

    // Only add cache control for non-Claude-3 models
    if !payload
        .get("model")
        .and_then(|m| m.as_str())
        .unwrap_or("")
        .contains("claude-3")
    {
        if let Some(messages_spec) = payload
            .as_object_mut()
            .and_then(|obj| obj.get_mut("messages"))
            .and_then(|messages| messages.as_array_mut())
        {
            // Add "cache_control" to the last and second-to-last "user" messages.
            // During each turn, we mark the final message with cache_control so the conversation
            // can be incrementally cached. The second-to-last user message is also marked for
            // caching with the cache_control parameter, so that this checkpoint can read from
            // the previous cache.
            let mut user_count = 0;
            for message in messages_spec.iter_mut().rev() {
                if message.get("role") == Some(&json!("user")) {
                    if let Some(content) = message.get_mut("content") {
                        if let Some(content_str) = content.as_str() {
                            *content = json!([{
                                "type": "text",
                                "text": content_str,
                                "cache_control": { "type": "ephemeral" }
                            }]);
                        }
                    }
                    user_count += 1;
                    if user_count >= 2 {
                        break;
                    }
                }
            }

            // Update the system message to have the cache_control field.
            if let Some(system_message) = messages_spec
                .iter_mut()
                .find(|msg| msg.get("role") == Some(&json!("system")))
            {
                if let Some(content) = system_message.get_mut("content") {
                    if let Some(content_str) = content.as_str() {
                        *system_message = json!({
                            "role": "system",
                            "content": [{
                                "type": "text",
                                "text": content_str,
                                "cache_control": { "type": "ephemeral" }
                            }]
                        });
                    }
                }
            }
        }
    }
    payload
}

fn update_request_for_deepseek(original_payload: &Value) -> Value {
    let mut payload = original_payload.clone();

    // Extract tools before removing them from the payload
    let tools = payload.get("tools").cloned();

    // Remove any tools/function calling capabilities from the request
    if let Some(obj) = payload.as_object_mut() {
        obj.remove("tools");
        obj.remove("tool_choice");
    }

    payload
}
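A minimal usage sketch of `update_request_for_deepseek`, assuming a typical OpenAI-style payload. The captured `tools` value's downstream use is not visible in this hunk, and the model and tool names below are hypothetical examples, not fixtures from the PR:

#[test]
fn deepseek_rewrite_strips_native_tool_fields() {
    // Hypothetical request shape; "ls" is an illustrative tool name.
    let request = json!({
        "model": "deepseek/deepseek-r1",
        "messages": [{ "role": "user", "content": "list the files" }],
        "tools": [{ "type": "function", "function": { "name": "ls" } }],
        "tool_choice": "auto"
    });
    let rewritten = update_request_for_deepseek(&request);
    // Native tool-calling fields should be gone; the rest of the payload stays intact.
    assert!(rewritten.get("tools").is_none());
    assert!(rewritten.get("tool_choice").is_none());
    assert_eq!(rewritten["model"], "deepseek/deepseek-r1");
}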

@@ -146,13 +133,22 @@ fn create_request_based_on_model(
&super::utils::ImageFormat::OpenAi,
)?;

// Check for Anthropic models
if model_config
.model_name
.starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC)
{
payload = update_request_for_anthropic(&payload);
}

// Check for DeepSeek models
if model_config
.model_name
.contains(OPENROUTER_MODEL_PREFIX_DEEPSEEK)
{
payload = update_request_for_deepseek(&payload);
}

Ok(payload)
}
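As a quick illustration of the dispatch above, with example model names (not an exhaustive mapping):

// "anthropic/claude-3.5-sonnet" -> update_request_for_anthropic (starts_with "anthropic")
// "deepseek/deepseek-r1"        -> update_request_for_deepseek  (contains "deepseek-r1")
// "openai/gpt-4o"               -> payload passes through unchanged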

@@ -201,8 +197,111 @@ impl Provider for OpenRouterProvider {
// Make request
let response = self.post(payload.clone()).await?;

        // Debug log the response structure
        println!(
            "OpenRouter response: {}",
            serde_json::to_string_pretty(&response).unwrap_or_default()
        );

        // First try to parse as OpenAI format
        let mut message = response_to_message(response.clone())?;

// If no tool calls were found in OpenAI format, check for XML format
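        // Illustrative shape of the XML fallback this block parses (the tool name and
        // parameter below are hypothetical examples, not fixed identifiers):
        //
        //   <function_calls>
        //   <invoke name="read_file">
        //   <parameter name="path">src/main.rs</parameter>
        //   </invoke>
        //   </function_calls>
        //
        // which should become ToolCall { name: "read_file", arguments: { "path": "src/main.rs" } }.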
        if !message.is_tool_call() {
            if let Some(MessageContent::Text(text_content)) = message.content.first() {
                let content = &text_content.text;
                if let (Some(calls_start), Some(calls_end)) = (
                    content.find("<function_calls>"),
                    content.find("</function_calls>"),
                ) {
                    // Include the full "</function_calls>" closing tag (17 bytes).
                    let calls_text = &content[calls_start..calls_end + 17];

                    // Extract the first <invoke ...>...</invoke> block.
                    if let Some(invoke_start) = calls_text.find("<invoke") {
                        if let Some(invoke_end) = calls_text[invoke_start..].find("</invoke>") {
                            let invoke_text =
                                &calls_text[invoke_start..invoke_start + invoke_end + 9];

                            // Parse the tool name out of the name="..." attribute.
                            if let Some(name_start) = invoke_text.find("name=\"") {
                                if let Some(name_end) = invoke_text[name_start + 6..].find('"') {
                                    let name = invoke_text
                                        [name_start + 6..name_start + 6 + name_end]
                                        .to_string();

                                    // Build the parameters map from each
                                    // <parameter name="...">value</parameter> element.
                                    let mut parameters = serde_json::Map::new();
                                    let mut param_pos = 0;
                                    while let Some(param_start) =
                                        invoke_text[param_pos..].find("<parameter")
                                    {
                                        let abs_start = param_pos + param_start;
                                        let Some(param_end) =
                                            invoke_text[abs_start..].find("</parameter>")
                                        else {
                                            break;
                                        };
                                        // Include the full "</parameter>" closing tag (12 bytes).
                                        let param_text =
                                            &invoke_text[abs_start..abs_start + param_end + 12];

                                        if let Some(pn_start) = param_text.find("name=\"") {
                                            if let Some(pn_end) =
                                                param_text[pn_start + 6..].find('"')
                                            {
                                                let param_name = &param_text
                                                    [pn_start + 6..pn_start + 6 + pn_end];

                                                // The value sits between the end of the opening
                                                // tag and the start of the closing tag.
                                                if let Some(value_start) = param_text.find('>') {
                                                    if let Some(value_end) =
                                                        param_text[value_start + 1..].find('<')
                                                    {
                                                        let param_value = &param_text[value_start
                                                            + 1
                                                            ..value_start + 1 + value_end];
                                                        parameters.insert(
                                                            param_name.to_string(),
                                                            Value::String(
                                                                param_value.to_string(),
                                                            ),
                                                        );
                                                    }
                                                }
                                            }
                                        }
                                        param_pos = abs_start + param_end + 12;
                                    }

                                    // Replace the raw text with a structured tool request.
                                    message.content.clear();
                                    message.content.push(MessageContent::tool_request(
                                        "1",
                                        Ok(ToolCall {
                                            name,
                                            arguments: serde_json::to_value(parameters)
                                                .unwrap_or_default(),
                                        }),
                                    ));
                                }
                            }
                        }
                    }
                }
            }
        }

let usage = match get_usage(&response) {
Ok(usage) => usage,
Err(ProviderError::UsageError(e)) => {