From 6f0e6d5dbed6d0532cc3a9542de65e920f9362bb Mon Sep 17 00:00:00 2001
From: marcelle <1852848+laanak08@users.noreply.github.com>
Date: Mon, 20 Jan 2025 22:29:23 -0500
Subject: [PATCH] feat: truncate message pairs (#659)

---
 crates/goose-cli/src/test_helpers.rs |  18 --
 crates/goose/src/agents/system.rs    |   2 +-
 crates/goose/src/agents/truncate.rs  | 251 +++++++++++++++++++++------
 3 files changed, 201 insertions(+), 70 deletions(-)

diff --git a/crates/goose-cli/src/test_helpers.rs b/crates/goose-cli/src/test_helpers.rs
index e4d9ca80c..011081ec6 100644
--- a/crates/goose-cli/src/test_helpers.rs
+++ b/crates/goose-cli/src/test_helpers.rs
@@ -18,24 +18,6 @@ pub fn run_with_tmp_dir<F: FnOnce() -> T, T>(func: F) -> T {
     )
 }
 
-#[cfg(test)]
-pub fn run_profile_with_tmp_dir<F: FnOnce() -> T, T>(profile: &str, func: F) -> T {
-    use std::ffi::OsStr;
-    use tempfile::tempdir;
-
-    let temp_dir = tempdir().unwrap();
-    let temp_dir_path = temp_dir.path().to_path_buf();
-    setup_profile(&temp_dir_path, Some(profile));
-
-    temp_env::with_vars(
-        [
-            ("HOME", Some(temp_dir_path.as_os_str())),
-            ("DATABRICKS_HOST", Some(OsStr::new("tmp_host_url"))),
-        ],
-        func,
-    )
-}
-
 #[cfg(test)]
 #[allow(dead_code)]
 pub async fn run_with_tmp_dir_async<F, Fut, T>(func: F) -> T
diff --git a/crates/goose/src/agents/system.rs b/crates/goose/src/agents/system.rs
index ab53457d5..630c4ffad 100644
--- a/crates/goose/src/agents/system.rs
+++ b/crates/goose/src/agents/system.rs
@@ -11,7 +11,7 @@ pub enum SystemError {
     Initialization(SystemConfig),
     #[error("Failed a client call to an MCP server: {0}")]
     Client(#[from] ClientError),
-    #[error("Messages exceeded context-limit and could not be truncated to fit.")]
+    #[error("User message exceeded context-limit. History could not be truncated to accommodate.")]
     ContextLimit,
     #[error("Transport error: {0}")]
     Transport(#[from] mcp_client::transport::Error),
diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs
index 40b9bfdbe..1b31811b4 100644
--- a/crates/goose/src/agents/truncate.rs
+++ b/crates/goose/src/agents/truncate.rs
@@ -12,13 +12,13 @@ use crate::providers::base::Provider;
 use crate::providers::base::ProviderUsage;
 use crate::register_agent;
 use crate::token_counter::TokenCounter;
-use mcp_core::Tool;
+use mcp_core::{Role, Tool};
 use serde_json::Value;
 
 /// Agent impl. that truncates oldest messages when payload over LLM ctx-limit
 pub struct TruncateAgent {
     capabilities: Mutex<Capabilities>,
-    _token_counter: TokenCounter,
+    token_counter: TokenCounter,
 }
 
 impl TruncateAgent {
@@ -26,11 +26,11 @@ impl TruncateAgent {
         let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
         Self {
             capabilities: Mutex::new(Capabilities::new(provider)),
-            _token_counter: token_counter,
+            token_counter,
         }
     }
 
-    async fn prepare_inference(
+    async fn enforce_ctx_limit(
         &self,
         system_prompt: &str,
         tools: &[Tool],
@@ -45,22 +45,12 @@ impl TruncateAgent {
             .collect();
 
         let approx_count =
-            self._token_counter
+            self.token_counter
                 .count_everything(system_prompt, messages, tools, &resources);
 
         let mut new_messages = messages.to_vec();
         if approx_count > target_limit {
-            let user_msg_size = self.text_content_size(new_messages.last());
-            if user_msg_size > target_limit {
-                debug!(
-                    "[WARNING] User message {} exceeds token budget {}.",
-                    user_msg_size,
-                    user_msg_size - target_limit
-                );
-                return Err(SystemError::ContextLimit);
-            }
-
-            new_messages = self.chop_front_messages(messages, approx_count, target_limit);
+            new_messages = self.drop_messages(messages, approx_count, target_limit);
             if new_messages.is_empty() {
                 return Err(SystemError::ContextLimit);
             }
@@ -70,19 +60,20 @@ impl TruncateAgent {
     }
 
     fn text_content_size(&self, message: Option<&Message>) -> usize {
-        let text = message
-            .and_then(|msg| msg.content.first())
-            .and_then(|c| c.as_text());
-
-        if let Some(txt) = text {
-            let count = self._token_counter.count_tokens(txt);
-            return count;
+        if let Some(msg) = message {
+            let mut approx_count = 0;
+            for content in msg.content.iter() {
+                if let Some(content_text) = content.as_text() {
+                    approx_count += self.token_counter.count_tokens(content_text);
+                }
+            }
+            return approx_count;
         }
 
         0
     }
 
-    fn chop_front_messages(
+    fn drop_messages(
         &self,
         messages: &[Message],
         approx_count: usize,
@@ -95,23 +86,29 @@ impl TruncateAgent {
             approx_count - target_limit
         );
 
-        let mut trimmed_items: VecDeque<Message> = VecDeque::from(messages.to_vec());
+        let user_msg_size = self.text_content_size(messages.last());
+        if messages.last().unwrap().role == Role::User && user_msg_size > target_limit {
+            debug!(
+                "[WARNING] User message {} exceeds token budget {}.",
+                user_msg_size,
+                user_msg_size - target_limit
+            );
+            return Vec::new();
+        }
+
+        let mut truncated_conv: VecDeque<Message> = VecDeque::from(messages.to_vec());
         let mut current_tokens = approx_count;
 
-        // Remove messages until we're under target limit
-        for msg in messages.iter() {
-            if current_tokens < target_limit || trimmed_items.is_empty() {
-                break;
-            }
-            let count = self.text_content_size(Some(msg));
-            let _ = trimmed_items.pop_front().unwrap();
-            // Subtract removed message’s token_count
-            current_tokens = current_tokens.saturating_sub(count);
-        }
+        while current_tokens > target_limit && truncated_conv.len() > 1 {
+            let user_msg = truncated_conv.pop_front().unwrap();
+            let user_msg_size = self.text_content_size(Some(&user_msg));
+            let assistant_msg = truncated_conv.pop_front().unwrap();
+            let assistant_msg_size = self.text_content_size(Some(&assistant_msg));
 
-        // use trimmed message-history
+            current_tokens = current_tokens.saturating_sub(user_msg_size + assistant_msg_size);
+        }
 
-        Vec::from(trimmed_items)
+        Vec::from(truncated_conv)
     }
 }
 
@@ -126,12 +123,11 @@ impl Agent for TruncateAgent {
         let mut capabilities = self.capabilities.lock().await;
         let tools = capabilities.get_prefixed_tools().await?;
         let system_prompt = capabilities.get_system_prompt().await;
-        let _estimated_limit = capabilities
+        let estimated_limit = capabilities
             .provider()
             .get_model_config()
             .get_estimated_limit();
 
-        // Set the user_message field in the span instead of creating a new event
         if let Some(content) = messages
             .last()
             .and_then(|msg| msg.content.first())
@@ -140,20 +136,30 @@ impl Agent for TruncateAgent {
             debug!("user_message" = &content);
         }
 
-        // Update conversation history for the start of the reply
         let mut messages = self
-            .prepare_inference(
+            .enforce_ctx_limit(
                 &system_prompt,
                 &tools,
                 messages,
-                _estimated_limit,
+                estimated_limit,
                 &mut capabilities.get_resources().await?,
             )
             .await?;
 
         Ok(Box::pin(async_stream::try_stream! {
             let _reply_guard = reply_span.enter();
+
             loop {
+                messages = self
+                    .enforce_ctx_limit(
+                        &system_prompt,
+                        &tools,
+                        &messages,
+                        estimated_limit,
+                        &mut capabilities.get_resources().await?,
+                    )
+                    .await?;
+
                 // Get completion from provider
                 let (response, usage) = capabilities.provider().complete(
                     &system_prompt,
@@ -197,16 +203,17 @@ impl Agent for TruncateAgent {
                     );
                 }
 
-                yield message_tool_response.clone();
-
-                messages = self.prepare_inference(
-                    &system_prompt,
-                    &tools,
-                    &messages,
-                    _estimated_limit,
-                    &mut capabilities.get_resources().await?
-                ).await?;
+                let tool_resp_size = self.text_content_size(
+                    Some(&message_tool_response),
+                );
+                if tool_resp_size > estimated_limit {
+                    // don't push assistant response or tool_response into history
+                    // last message is `user message => tool call`, remove it from history too
+                    messages.pop();
+                    continue;
+                }
 
+                yield message_tool_response.clone();
                 messages.push(response);
                 messages.push(message_tool_response);
             }
@@ -246,3 +253,145 @@ impl Agent for TruncateAgent {
 }
 
 register_agent!("truncate", TruncateAgent);
+
+#[cfg(test)]
+mod tests {
+    use crate::agents::truncate::TruncateAgent;
+    use crate::message::Message;
+    use crate::providers::base::{Provider, ProviderUsage, Usage};
+    use crate::providers::configs::ModelConfig;
+    use mcp_core::{Content, Tool};
+    use std::iter;
+
+    // Mock Provider implementation for testing
+    #[derive(Clone)]
+    struct MockProvider {
+        model_config: ModelConfig,
+    }
+
+    #[async_trait::async_trait]
+    impl Provider for MockProvider {
+        fn get_model_config(&self) -> &ModelConfig {
+            &self.model_config
+        }
+
+        async fn complete(
+            &self,
+            _system: &str,
+            _messages: &[Message],
+            _tools: &[Tool],
+        ) -> anyhow::Result<(Message, ProviderUsage)> {
+            Ok((
+                Message::assistant().with_text("Mock response"),
+                ProviderUsage::new("mock".to_string(), Usage::default()),
+            ))
+        }
+
+        fn get_usage(&self, _data: &serde_json::Value) -> anyhow::Result<Usage> {
+            Ok(Usage::new(None, None, None))
+        }
+    }
+
+    const SMALL_MESSAGE: &str = "This is a test, this is just a test, this is only a test.\n";
+
+    async fn call_enforce_ctx_limit(conversation: &[Message]) -> anyhow::Result<Vec<Message>> {
+        let mock_model_config =
+            ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
+        let provider = Box::new(MockProvider {
+            model_config: mock_model_config,
+        });
+        let agent = TruncateAgent::new(provider);
+
+        let mut capabilities = agent.capabilities.lock().await;
+        let tools = capabilities.get_prefixed_tools().await?;
+        let system_prompt = capabilities.get_system_prompt().await;
+        let estimated_limit = capabilities
+            .provider()
+            .get_model_config()
+            .get_estimated_limit();
+
+        let messages = agent
+            .enforce_ctx_limit(
+                &system_prompt,
+                &tools,
+                conversation,
+                estimated_limit,
+                &mut capabilities.get_resources().await?,
+            )
+            .await?;
+
+        Ok(messages)
+    }
+
+    fn create_basic_valid_conversation(
+        interactions_count: usize,
+        is_tool_use: bool,
+    ) -> Vec<Message> {
+        let mut conversation = Vec::<Message>::new();
+
+        if is_tool_use {
+            (0..interactions_count).for_each(|i| {
+                let tool_output = format!("{:?}{}", SMALL_MESSAGE, i);
+                conversation.push(
+                    Message::user()
+                        .with_tool_response("id:0", Ok(vec![Content::text(tool_output)])),
+                );
+                conversation.push(Message::assistant().with_text(format!(
+                    "{:?}{}",
+                    SMALL_MESSAGE,
+                    i + 1
+                )));
+            });
+        } else {
+            (0..interactions_count).for_each(|i| {
+                conversation.push(Message::user().with_text(format!("{:?}{}", SMALL_MESSAGE, i)));
+                conversation.push(Message::assistant().with_text(format!(
+                    "{:?}{}",
+                    SMALL_MESSAGE,
+                    i + 1
+                )));
+            });
+        }
+
+        conversation
+    }
+    #[tokio::test]
+    async fn test_simple_conversation_no_truncation() -> anyhow::Result<()> {
+        let conversation = create_basic_valid_conversation(1, false);
+        let messages = call_enforce_ctx_limit(&conversation).await?;
+        assert_eq!(messages.len(), conversation.len());
+        Ok(())
+    }
+    #[tokio::test]
+    async fn test_truncation_when_conversation_history_too_big() -> anyhow::Result<()> {
+        let conversation = create_basic_valid_conversation(5000, false);
+        let messages = call_enforce_ctx_limit(&*conversation).await?;
+        assert_eq!(conversation.len() > messages.len(), true);
+        assert_eq!(messages.len() > 0, true);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_truncation_when_single_user_message_too_big() -> anyhow::Result<()> {
+        let oversized_message: String = iter::repeat(SMALL_MESSAGE)
+            .take(10000)
+            .collect::<Vec<&str>>()
+            .join("");
+        let mut conversation = create_basic_valid_conversation(3, false);
+        conversation.push(Message::user().with_text(oversized_message));
+
+        let messages = call_enforce_ctx_limit(&*conversation).await;
+
+        assert!(matches!(messages, Err(_, ..)));
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_truncation_when_tool_response_set_too_big() -> anyhow::Result<()> {
+        let conversation = create_basic_valid_conversation(5000, true);
+        let messages = call_enforce_ctx_limit(&*conversation).await?;
+        assert_eq!(conversation.len() > messages.len(), true);
+        assert_eq!(messages.len() > 0, true);
+        Ok(())
+    }
+}