Skip to content

Commit

Permalink
feat: truncate message pairs (#659)
Browse files Browse the repository at this point in the history
  • Loading branch information
laanak08 authored Jan 21, 2025
1 parent 23987f7 commit 6f0e6d5
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 70 deletions.
18 changes: 0 additions & 18 deletions crates/goose-cli/src/test_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,6 @@ pub fn run_with_tmp_dir<F: FnOnce() -> T, T>(func: F) -> T {
)
}

/// Runs `func` against a throwaway home directory prepared for `profile`.
///
/// A fresh temporary directory is created and handed to `setup_profile`
/// together with the profile name. `func` is then executed through
/// `temp_env::with_vars` with `HOME` set to that directory and
/// `DATABRICKS_HOST` set to the placeholder `tmp_host_url`; whatever `func`
/// returns is passed back to the caller.
///
/// # Panics
///
/// Panics if the temporary directory cannot be created.
#[cfg(test)]
pub fn run_profile_with_tmp_dir<F, T>(profile: &str, func: F) -> T
where
    F: FnOnce() -> T,
{
    use std::ffi::OsStr;
    use tempfile::tempdir;

    // The directory handle stays bound until this function returns, i.e.
    // for the whole time `func` is running.
    let scratch_dir = tempdir().unwrap();
    let home = scratch_dir.path().to_path_buf();
    setup_profile(&home, Some(profile));

    let databricks_host = OsStr::new("tmp_host_url");
    let env_overrides = [
        ("HOME", Some(home.as_os_str())),
        ("DATABRICKS_HOST", Some(databricks_host)),
    ];

    temp_env::with_vars(env_overrides, func)
}

#[cfg(test)]
#[allow(dead_code)]
pub async fn run_with_tmp_dir_async<F, Fut, T>(func: F) -> T
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/agents/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub enum SystemError {
Initialization(SystemConfig),
#[error("Failed a client call to an MCP server: {0}")]
Client(#[from] ClientError),
#[error("Messages exceeded context-limit and could not be truncated to fit.")]
#[error("User Message exceeded context-limit. History could not be truncated to accomodate.")]
ContextLimit,
#[error("Transport error: {0}")]
Transport(#[from] mcp_client::transport::Error),
Expand Down
251 changes: 200 additions & 51 deletions crates/goose/src/agents/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@ use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
use crate::register_agent;
use crate::token_counter::TokenCounter;
use mcp_core::Tool;
use mcp_core::{Role, Tool};
use serde_json::Value;

/// Agent implementation that truncates the oldest messages when the payload exceeds the LLM
/// context limit.
pub struct TruncateAgent {
capabilities: Mutex<Capabilities>,
_token_counter: TokenCounter,
token_counter: TokenCounter,
}

impl TruncateAgent {
pub fn new(provider: Box<dyn Provider>) -> Self {
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
Self {
capabilities: Mutex::new(Capabilities::new(provider)),
_token_counter: token_counter,
token_counter,
}
}

async fn prepare_inference(
async fn enforce_ctx_limit(
&self,
system_prompt: &str,
tools: &[Tool],
Expand All @@ -45,22 +45,12 @@ impl TruncateAgent {
.collect();

let approx_count =
self._token_counter
self.token_counter
.count_everything(system_prompt, messages, tools, &resources);

let mut new_messages = messages.to_vec();
if approx_count > target_limit {
let user_msg_size = self.text_content_size(new_messages.last());
if user_msg_size > target_limit {
debug!(
"[WARNING] User message {} exceeds token budget {}.",
user_msg_size,
user_msg_size - target_limit
);
return Err(SystemError::ContextLimit);
}

new_messages = self.chop_front_messages(messages, approx_count, target_limit);
new_messages = self.drop_messages(messages, approx_count, target_limit);
if new_messages.is_empty() {
return Err(SystemError::ContextLimit);
}
Expand All @@ -70,19 +60,20 @@ impl TruncateAgent {
}

fn text_content_size(&self, message: Option<&Message>) -> usize {
let text = message
.and_then(|msg| msg.content.first())
.and_then(|c| c.as_text());

if let Some(txt) = text {
let count = self._token_counter.count_tokens(txt);
return count;
if let Some(msg) = message {
let mut approx_count = 0;
for content in msg.content.iter() {
if let Some(content_text) = content.as_text() {
approx_count += self.token_counter.count_tokens(content_text);
}
}
return approx_count;
}

0
}

fn chop_front_messages(
fn drop_messages(
&self,
messages: &[Message],
approx_count: usize,
Expand All @@ -95,23 +86,29 @@ impl TruncateAgent {
approx_count - target_limit
);

let mut trimmed_items: VecDeque<Message> = VecDeque::from(messages.to_vec());
let user_msg_size = self.text_content_size(messages.last());
if messages.last().unwrap().role == Role::User && user_msg_size > target_limit {
debug!(
"[WARNING] User message {} exceeds token budget {}.",
user_msg_size,
user_msg_size - target_limit
);
return Vec::new();
}

let mut truncated_conv: VecDeque<Message> = VecDeque::from(messages.to_vec());
let mut current_tokens = approx_count;

// Remove messages until we're under target limit
for msg in messages.iter() {
if current_tokens < target_limit || trimmed_items.is_empty() {
break;
}
let count = self.text_content_size(Some(msg));
let _ = trimmed_items.pop_front().unwrap();
// Subtract removed message’s token_count
current_tokens = current_tokens.saturating_sub(count);
}
while current_tokens > target_limit && truncated_conv.len() > 1 {
let user_msg = truncated_conv.pop_front().unwrap();
let user_msg_size = self.text_content_size(Some(&user_msg));
let assistant_msg = truncated_conv.pop_front().unwrap();
let assistant_msg_size = self.text_content_size(Some(&assistant_msg));

// use trimmed message-history
current_tokens = current_tokens.saturating_sub(user_msg_size + assistant_msg_size);
}

Vec::from(trimmed_items)
Vec::from(truncated_conv)
}
}

Expand All @@ -126,12 +123,11 @@ impl Agent for TruncateAgent {
let mut capabilities = self.capabilities.lock().await;
let tools = capabilities.get_prefixed_tools().await?;
let system_prompt = capabilities.get_system_prompt().await;
let _estimated_limit = capabilities
let estimated_limit = capabilities
.provider()
.get_model_config()
.get_estimated_limit();

// Set the user_message field in the span instead of creating a new event
if let Some(content) = messages
.last()
.and_then(|msg| msg.content.first())
Expand All @@ -140,20 +136,30 @@ impl Agent for TruncateAgent {
debug!("user_message" = &content);
}

// Update conversation history for the start of the reply
let mut messages = self
.prepare_inference(
.enforce_ctx_limit(
&system_prompt,
&tools,
messages,
_estimated_limit,
estimated_limit,
&mut capabilities.get_resources().await?,
)
.await?;

Ok(Box::pin(async_stream::try_stream! {
let _reply_guard = reply_span.enter();

loop {
messages = self
.enforce_ctx_limit(
&system_prompt,
&tools,
&messages,
estimated_limit,
&mut capabilities.get_resources().await?,
)
.await?;

// Get completion from provider
let (response, usage) = capabilities.provider().complete(
&system_prompt,
Expand Down Expand Up @@ -197,16 +203,17 @@ impl Agent for TruncateAgent {
);
}

yield message_tool_response.clone();

messages = self.prepare_inference(
&system_prompt,
&tools,
&messages,
_estimated_limit,
&mut capabilities.get_resources().await?
).await?;
let tool_resp_size = self.text_content_size(
Some(&message_tool_response),
);
if tool_resp_size > estimated_limit {
// don't push assistant response or tool_response into history
// last message is `user message => tool call`, remove it from history too
messages.pop();
continue;
}

yield message_tool_response.clone();
messages.push(response);
messages.push(message_tool_response);
}
Expand Down Expand Up @@ -246,3 +253,145 @@ impl Agent for TruncateAgent {
}

register_agent!("truncate", TruncateAgent);

#[cfg(test)]
mod tests {
    use crate::agents::truncate::TruncateAgent;
    use crate::message::Message;
    use crate::providers::base::{Provider, ProviderUsage, Usage};
    use crate::providers::configs::ModelConfig;
    use mcp_core::{Content, Tool};

    /// Provider stub for the tests: it only carries a model configuration and
    /// answers every completion request with a fixed assistant message.
    #[derive(Clone)]
    struct MockProvider {
        model_config: ModelConfig,
    }

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        fn get_model_config(&self) -> &ModelConfig {
            &self.model_config
        }

        async fn complete(
            &self,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> anyhow::Result<(Message, ProviderUsage)> {
            Ok((
                Message::assistant().with_text("Mock response"),
                ProviderUsage::new("mock".to_string(), Usage::default()),
            ))
        }

        fn get_usage(&self, _data: &serde_json::Value) -> anyhow::Result<Usage> {
            Ok(Usage::new(None, None, None))
        }
    }

    /// Text used as the body of every regular-sized message in these tests.
    const SMALL_MESSAGE: &str = "This is a test, this is just a test, this is only a test.\n";

    /// Runs `TruncateAgent::enforce_ctx_limit` over `conversation`.
    ///
    /// The agent is backed by a `MockProvider` whose model config is created
    /// with a context limit of 200_000. The tools, system prompt, resources
    /// and estimated limit are all read from the agent's own capabilities.
    async fn call_enforce_ctx_limit(conversation: &[Message]) -> anyhow::Result<Vec<Message>> {
        let model_config =
            ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
        let agent = TruncateAgent::new(Box::new(MockProvider { model_config }));

        let mut capabilities = agent.capabilities.lock().await;
        let tools = capabilities.get_prefixed_tools().await?;
        let system_prompt = capabilities.get_system_prompt().await;
        let estimated_limit = capabilities
            .provider()
            .get_model_config()
            .get_estimated_limit();

        let messages = agent
            .enforce_ctx_limit(
                &system_prompt,
                &tools,
                conversation,
                estimated_limit,
                &mut capabilities.get_resources().await?,
            )
            .await?;

        Ok(messages)
    }

    /// Builds a conversation of `interactions_count` user/assistant pairs.
    ///
    /// With `is_tool_use` set, each user turn is a tool response (id `"id:0"`)
    /// wrapping the text; otherwise it is a plain text message. Assistant
    /// turns are always plain text.
    fn create_basic_valid_conversation(
        interactions_count: usize,
        is_tool_use: bool,
    ) -> Vec<Message> {
        let mut conversation = Vec::with_capacity(interactions_count * 2);

        for i in 0..interactions_count {
            // `{:?}` is kept on purpose: it Debug-formats the constant
            // (quoted and escaped), exactly as the message bodies were built
            // before this helper was deduplicated.
            let user_text = format!("{:?}{}", SMALL_MESSAGE, i);
            let user_message = if is_tool_use {
                Message::user().with_tool_response("id:0", Ok(vec![Content::text(user_text)]))
            } else {
                Message::user().with_text(user_text)
            };
            conversation.push(user_message);
            conversation.push(Message::assistant().with_text(format!(
                "{:?}{}",
                SMALL_MESSAGE,
                i + 1
            )));
        }

        conversation
    }

    /// A single short exchange must come back with every message intact.
    #[tokio::test]
    async fn test_simple_conversation_no_truncation() -> anyhow::Result<()> {
        let conversation = create_basic_valid_conversation(1, false);
        let messages = call_enforce_ctx_limit(&conversation).await?;
        assert_eq!(messages.len(), conversation.len());
        Ok(())
    }

    /// 5000 plain-text exchanges must be shortened, but not emptied.
    #[tokio::test]
    async fn test_truncation_when_conversation_history_too_big() -> anyhow::Result<()> {
        let conversation = create_basic_valid_conversation(5000, false);
        let messages = call_enforce_ctx_limit(&conversation).await?;
        assert!(conversation.len() > messages.len());
        assert!(!messages.is_empty());
        Ok(())
    }

    /// A trailing user message made of 10000 copies of `SMALL_MESSAGE` must
    /// make the call fail instead of returning a truncated history.
    #[tokio::test]
    async fn test_truncation_when_single_user_message_too_big() -> anyhow::Result<()> {
        let oversized_message = SMALL_MESSAGE.repeat(10000);
        let mut conversation = create_basic_valid_conversation(3, false);
        conversation.push(Message::user().with_text(oversized_message));

        let result = call_enforce_ctx_limit(&conversation).await;

        assert!(result.is_err());
        Ok(())
    }

    /// 5000 tool-response exchanges must be shortened, but not emptied.
    #[tokio::test]
    async fn test_truncation_when_tool_response_set_too_big() -> anyhow::Result<()> {
        let conversation = create_basic_valid_conversation(5000, true);
        let messages = call_enforce_ctx_limit(&conversation).await?;
        assert!(conversation.len() > messages.len());
        assert!(!messages.is_empty());
        Ok(())
    }
}

0 comments on commit 6f0e6d5

Please sign in to comment.