From 6220ef054f55f044b9bf6923c38783faed8b7023 Mon Sep 17 00:00:00 2001 From: Bradley Axen Date: Tue, 11 Feb 2025 20:03:32 -0800 Subject: [PATCH 1/5] feat: Support extending the system prompt (#1167) --- crates/goose-cli/src/cli_prompt.rs | 16 +++++++++++ crates/goose-cli/src/commands/session.rs | 5 ++++ crates/goose-cli/src/main.rs | 1 + crates/goose-server/src/routes/agent.rs | 35 ++++++++++++++++++++++++ crates/goose/src/agents/agent.rs | 3 ++ crates/goose/src/agents/capabilities.rs | 19 ++++++++++++- crates/goose/src/agents/reference.rs | 5 ++++ crates/goose/src/agents/truncate.rs | 5 ++++ 8 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 crates/goose-cli/src/cli_prompt.rs diff --git a/crates/goose-cli/src/cli_prompt.rs b/crates/goose-cli/src/cli_prompt.rs new file mode 100644 index 000000000..9c0db7bc1 --- /dev/null +++ b/crates/goose-cli/src/cli_prompt.rs @@ -0,0 +1,16 @@ +/// Returns a system prompt extension that explains CLI-specific functionality +pub fn get_cli_prompt() -> String { + String::from( + "You are being accessed through a command-line interface. The following slash commands are available +- you can let the user know about them if they need help: + +- /exit or /quit - Exit the session +- /t - Toggle between Light/Dark/Ansi themes +- /? or /help - Display help message + +Additional keyboard shortcuts: +- Ctrl+C - Interrupt the current interaction (resets to before the interrupted request) +- Ctrl+J - Add a newline +- Up/Down arrows - Navigate command history" + ) +} diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index 0273260a6..0fb07028d 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -161,6 +161,11 @@ pub async fn build_session( let prompt = Box::new(RustylinePrompt::new()); + // Add CLI-specific system prompt extension + agent + .extend_system_prompt(crate::cli_prompt::get_cli_prompt()) + .await; + display_session_info(resume, &provider_name, &model, &session_file); Session::new(agent, prompt, session_file) } diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index bfe469777..8af76d919 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -9,6 +9,7 @@ pub static APP_STRATEGY: Lazy = Lazy::new(|| AppStrategyArgs { app_name: "goose".to_string(), }); +mod cli_prompt; mod commands; mod log_usage; mod logging; diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index dfd683075..46bdb1bd4 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -17,6 +17,16 @@ struct VersionsResponse { default_version: String, } +#[derive(Deserialize)] +struct ExtendPromptRequest { + extension: String, +} + +#[derive(Serialize)] +struct ExtendPromptResponse { + success: bool, +} + #[derive(Deserialize)] struct CreateAgentRequest { version: Option, @@ -61,6 +71,30 @@ async fn get_versions() -> Json { }) } +async fn extend_prompt( + State(state): State, + headers: HeaderMap, + Json(payload): Json, +) -> Result, StatusCode> { + // Verify secret key + let secret_key = headers + .get("X-Secret-Key") + .and_then(|value| value.to_str().ok()) + .ok_or(StatusCode::UNAUTHORIZED)?; + + if secret_key != state.secret_key { + return Err(StatusCode::UNAUTHORIZED); + } + + let mut agent = state.agent.lock().await; + if let Some(ref mut agent) = *agent { + agent.extend_system_prompt(payload.extension).await; + Ok(Json(ExtendPromptResponse { success: true })) + } else { + Err(StatusCode::NOT_FOUND) + } +} + async fn create_agent( State(state): State, headers: HeaderMap, @@ -132,6 +166,7 @@ pub fn routes(state: AppState) -> Router { Router::new() .route("/agent/versions", get(get_versions)) .route("/agent/providers", get(list_providers)) + .route("/agent/prompt", post(extend_prompt)) .route("/agent", post(create_agent)) .with_state(state) } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 6a595ffb6..ca681de4d 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -28,4 +28,7 @@ pub trait Agent: Send + Sync { /// Get the total usage of the agent async fn usage(&self) -> Vec; + + /// Add custom text to be included in the system prompt + async fn extend_system_prompt(&mut self, extension: String); } diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index 7fcd6d20e..e9a876468 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -30,6 +30,7 @@ pub struct Capabilities { resource_capable_extensions: HashSet, provider: Box, provider_usage: Mutex>, + system_prompt_extensions: Vec, } /// A flattened representation of a resource used by the agent to prepare inference @@ -88,6 +89,7 @@ impl Capabilities { resource_capable_extensions: HashSet::new(), provider, provider_usage: Mutex::new(Vec::new()), + system_prompt_extensions: Vec::new(), } } @@ -164,6 +166,11 @@ impl Capabilities { Ok(()) } + /// Add a system prompt extension + pub fn add_system_prompt_extension(&mut self, extension: String) { + self.system_prompt_extensions.push(extension); + } + /// Get a reference to the provider pub fn provider(&self) -> &dyn Provider { &*self.provider @@ -303,7 +310,17 @@ impl Capabilities { context.insert("extensions", serde_json::to_value(extensions_info).unwrap()); context.insert("current_date_time", Value::String(current_date_time)); - load_prompt_file("system.md", &context).expect("Prompt should render") + let base_prompt = load_prompt_file("system.md", &context).expect("Prompt should render"); + + if self.system_prompt_extensions.is_empty() { + base_prompt + } else { + format!( + "{}\n\n# Additional Instructions:\n\n{}", + base_prompt, + self.system_prompt_extensions.join("\n\n") + ) + } } /// Find and return a reference to the appropriate client for a tool call diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index 5d4cf86bb..0cb2f67e9 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -184,6 +184,11 @@ impl Agent for ReferenceAgent { let capabilities = self.capabilities.lock().await; capabilities.get_usage().await } + + async fn extend_system_prompt(&mut self, extension: String) { + let mut capabilities = self.capabilities.lock().await; + capabilities.add_system_prompt_extension(extension); + } } register_agent!("reference", ReferenceAgent); diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index 5f80325dc..d6df60cb7 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -292,6 +292,11 @@ impl Agent for TruncateAgent { let capabilities = self.capabilities.lock().await; capabilities.get_usage().await } + + async fn extend_system_prompt(&mut self, extension: String) { + let mut capabilities = self.capabilities.lock().await; + capabilities.add_system_prompt_extension(extension); + } } register_agent!("truncate", TruncateAgent); From 7a8552ed9efba8b1645fa3dcaf37b889785c492e Mon Sep 17 00:00:00 2001 From: Bradley Axen Date: Tue, 11 Feb 2025 21:16:58 -0800 Subject: [PATCH 2/5] feat: simplify CLI sessions (#1168) --- crates/goose-cli/src/commands/mod.rs | 1 - crates/goose-cli/src/commands/session.rs | 192 -------- crates/goose-cli/src/lib.rs | 15 + crates/goose-cli/src/log_usage.rs | 15 +- crates/goose-cli/src/main.rs | 30 +- crates/goose-cli/src/prompt.rs | 39 -- crates/goose-cli/src/prompt/renderer.rs | 408 ---------------- crates/goose-cli/src/prompt/rustyline.rs | 176 ------- crates/goose-cli/src/session.rs | 362 -------------- crates/goose-cli/src/session/builder.rs | 140 ++++++ crates/goose-cli/src/session/input.rs | 152 ++++++ crates/goose-cli/src/session/mod.rs | 302 ++++++++++++ crates/goose-cli/src/session/output.rs | 457 ++++++++++++++++++ .../src/{cli_prompt.rs => session/prompt.rs} | 0 crates/goose-cli/src/session/storage.rs | 165 +++++++ .../src/{prompt => session}/thinking.rs | 6 +- crates/goose-cli/src/test_helpers.rs | 70 --- 17 files changed, 1248 insertions(+), 1282 deletions(-) delete mode 100644 crates/goose-cli/src/commands/session.rs create mode 100644 crates/goose-cli/src/lib.rs delete mode 100644 crates/goose-cli/src/prompt.rs delete mode 100644 crates/goose-cli/src/prompt/renderer.rs delete mode 100644 crates/goose-cli/src/prompt/rustyline.rs delete mode 100644 crates/goose-cli/src/session.rs create mode 100644 crates/goose-cli/src/session/builder.rs create mode 100644 crates/goose-cli/src/session/input.rs create mode 100644 crates/goose-cli/src/session/mod.rs create mode 100644 crates/goose-cli/src/session/output.rs rename crates/goose-cli/src/{cli_prompt.rs => session/prompt.rs} (100%) create mode 100644 crates/goose-cli/src/session/storage.rs rename crates/goose-cli/src/{prompt => session}/thinking.rs (98%) delete mode 100644 crates/goose-cli/src/test_helpers.rs diff --git a/crates/goose-cli/src/commands/mod.rs b/crates/goose-cli/src/commands/mod.rs index 6c3e29df1..e9ed50ce5 100644 --- a/crates/goose-cli/src/commands/mod.rs +++ b/crates/goose-cli/src/commands/mod.rs @@ -1,4 +1,3 @@ pub mod agent_version; pub mod configure; pub mod mcp; -pub mod session; diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs deleted file mode 100644 index 0fb07028d..000000000 --- a/crates/goose-cli/src/commands/session.rs +++ /dev/null @@ -1,192 +0,0 @@ -use rand::{distributions::Alphanumeric, Rng}; -use std::process; - -use crate::prompt::rustyline::RustylinePrompt; -use crate::session::{ensure_session_dir, get_most_recent_session, legacy_session_dir, Session}; -use console::style; -use goose::agents::extension::{Envs, ExtensionError}; -use goose::agents::AgentFactory; -use goose::config::{Config, ExtensionConfig, ExtensionManager}; -use goose::providers::create; -use std::path::Path; - -use mcp_client::transport::Error as McpClientError; - -pub async fn build_session( - name: Option, - resume: bool, - extensions: Vec, - builtins: Vec, -) -> Session<'static> { - // Load config and get provider/model - let config = Config::global(); - - let provider_name: String = config - .get("GOOSE_PROVIDER") - .expect("No provider configured. Run 'goose configure' first"); - let session_dir = ensure_session_dir().expect("Failed to create session directory"); - - let model: String = config - .get("GOOSE_MODEL") - .expect("No model configured. Run 'goose configure' first"); - let model_config = goose::model::ModelConfig::new(model.clone()); - let provider = create(&provider_name, model_config).expect("Failed to create provider"); - - // Create the agent - let agent_version: Option = config.get("GOOSE_AGENT").ok(); - let mut agent = match agent_version { - Some(version) => AgentFactory::create(&version, provider), - None => AgentFactory::create(AgentFactory::default_version(), provider), - } - .expect("Failed to create agent"); - - // Setup extensions for the agent - for extension in ExtensionManager::get_all().expect("should load extensions") { - if extension.enabled { - let config = extension.config.clone(); - agent - .add_extension(config.clone()) - .await - .unwrap_or_else(|e| { - let err = match e { - ExtensionError::Transport(McpClientError::StdioProcessError(inner)) => { - inner - } - _ => e.to_string(), - }; - println!("Failed to start extension: {}, {:?}", config.name(), err); - println!( - "Please check extension configuration for {}.", - config.name() - ); - process::exit(1); - }); - } - } - - // Add extensions if provided - for extension_str in extensions { - let mut parts: Vec<&str> = extension_str.split_whitespace().collect(); - let mut envs = std::collections::HashMap::new(); - - // Parse environment variables (format: KEY=value) - while let Some(part) = parts.first() { - if !part.contains('=') { - break; - } - let env_part = parts.remove(0); - let (key, value) = env_part.split_once('=').unwrap(); - envs.insert(key.to_string(), value.to_string()); - } - - if parts.is_empty() { - eprintln!("No command provided in extension string"); - process::exit(1); - } - - let cmd = parts.remove(0).to_string(); - //this is an ephemeral extension so name does not matter - let name = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(8) - .map(char::from) - .collect(); - let config = ExtensionConfig::Stdio { - name, - cmd, - args: parts.iter().map(|s| s.to_string()).collect(), - envs: Envs::new(envs), - }; - - agent.add_extension(config).await.unwrap_or_else(|e| { - eprintln!("Failed to start extension: {}", e); - process::exit(1); - }); - } - - // Add builtin extensions - for name in builtins { - let config = ExtensionConfig::Builtin { name }; - agent.add_extension(config).await.unwrap_or_else(|e| { - eprintln!("Failed to start builtin extension: {}", e); - process::exit(1); - }); - } - - // If resuming, try to find the session - if resume { - if let Some(ref session_name) = name { - // Try to resume specific session - let session_file = session_dir.join(format!("{}.jsonl", session_name)); - if session_file.exists() { - let prompt = Box::new(RustylinePrompt::new()); - return Session::new(agent, prompt, session_file); - } - - // LEGACY NOTE: remove this once old paths are no longer needed. - if let Some(legacy_dir) = legacy_session_dir() { - let legacy_file = legacy_dir.join(format!("{}.jsonl", session_name)); - if legacy_file.exists() { - let prompt = Box::new(RustylinePrompt::new()); - return Session::new(agent, prompt, legacy_file); - } - } - - eprintln!("Session '{}' not found, starting new session", session_name); - } else { - // Try to resume most recent session - if let Ok(session_file) = get_most_recent_session() { - let prompt = Box::new(RustylinePrompt::new()); - return Session::new(agent, prompt, session_file); - } else { - eprintln!("No previous sessions found, starting new session"); - } - } - } - - // Generate session name if not provided - let name = name.unwrap_or_else(|| { - rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(8) - .map(char::from) - .collect() - }); - - let session_file = session_dir.join(format!("{}.jsonl", name)); - if session_file.exists() { - eprintln!("Session '{}' already exists", name); - process::exit(1); - } - - let prompt = Box::new(RustylinePrompt::new()); - - // Add CLI-specific system prompt extension - agent - .extend_system_prompt(crate::cli_prompt::get_cli_prompt()) - .await; - - display_session_info(resume, &provider_name, &model, &session_file); - Session::new(agent, prompt, session_file) -} - -fn display_session_info(resume: bool, provider: &str, model: &str, session_file: &Path) { - let start_session_msg = if resume { - "resuming session |" - } else { - "starting session |" - }; - println!( - "{} {} {} {} {}", - style(start_session_msg).dim(), - style("provider:").dim(), - style(provider).cyan().dim(), - style("model:").dim(), - style(model).cyan().dim(), - ); - println!( - " {} {}", - style("logging to").dim(), - style(session_file.display()).dim().cyan(), - ); -} diff --git a/crates/goose-cli/src/lib.rs b/crates/goose-cli/src/lib.rs new file mode 100644 index 000000000..207af8179 --- /dev/null +++ b/crates/goose-cli/src/lib.rs @@ -0,0 +1,15 @@ +use etcetera::AppStrategyArgs; +use once_cell::sync::Lazy; +pub mod commands; +pub mod log_usage; +pub mod logging; +pub mod session; + +// Re-export commonly used types +pub use session::Session; + +pub static APP_STRATEGY: Lazy = Lazy::new(|| AppStrategyArgs { + top_level_domain: "Block".to_string(), + author: "Block".to_string(), + app_name: "goose".to_string(), +}); diff --git a/crates/goose-cli/src/log_usage.rs b/crates/goose-cli/src/log_usage.rs index d38b9bbd4..7be55cd9c 100644 --- a/crates/goose-cli/src/log_usage.rs +++ b/crates/goose-cli/src/log_usage.rs @@ -60,16 +60,21 @@ mod tests { use etcetera::{choose_app_strategy, AppStrategy}; use goose::providers::base::{ProviderUsage, Usage}; - use crate::{ - log_usage::{log_usage, SessionLog}, - test_helpers::run_with_tmp_dir, - }; + use crate::log_usage::{log_usage, SessionLog}; + + pub fn run_with_tmp_dir T, T>(func: F) -> T { + use tempfile::tempdir; + + let temp_dir = tempdir().unwrap(); + let temp_dir_path = temp_dir.path().to_path_buf(); + + temp_env::with_vars([("HOME", Some(temp_dir_path.as_os_str()))], func) + } #[test] fn test_session_logging() { run_with_tmp_dir(|| { let home_dir = choose_app_strategy(crate::APP_STRATEGY.clone()).unwrap(); - let log_file = home_dir .in_state_dir("logs") .unwrap_or_else(|| home_dir.in_data_dir("logs")) diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index 8af76d919..792c2fbba 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -1,33 +1,15 @@ use anyhow::Result; use clap::{CommandFactory, Parser, Subcommand}; -use etcetera::AppStrategyArgs; -use once_cell::sync::Lazy; - -pub static APP_STRATEGY: Lazy = Lazy::new(|| AppStrategyArgs { - top_level_domain: "Block".to_string(), - author: "Block".to_string(), - app_name: "goose".to_string(), -}); - -mod cli_prompt; -mod commands; -mod log_usage; -mod logging; -mod prompt; -mod session; - -use commands::agent_version::AgentCommand; -use commands::configure::handle_configure; -use commands::mcp::run_server; -use commands::session::build_session; + use console::style; use goose::config::Config; -use logging::setup_logging; +use goose_cli::commands::agent_version::AgentCommand; +use goose_cli::commands::configure::handle_configure; +use goose_cli::commands::mcp::run_server; +use goose_cli::logging::setup_logging; +use goose_cli::session::build_session; use std::io::{self, Read}; -#[cfg(test)] -mod test_helpers; - #[derive(Parser)] #[command(author, version, display_name = "", about, long_about = None)] struct Cli { diff --git a/crates/goose-cli/src/prompt.rs b/crates/goose-cli/src/prompt.rs deleted file mode 100644 index 79d7ae58c..000000000 --- a/crates/goose-cli/src/prompt.rs +++ /dev/null @@ -1,39 +0,0 @@ -use anyhow::Result; -use goose::message::Message; - -pub mod renderer; -pub mod rustyline; -pub mod thinking; - -pub trait Prompt { - fn render(&mut self, message: Box); - fn get_input(&mut self) -> Result; - fn show_busy(&mut self); - fn hide_busy(&self); - fn close(&self); - /// Load the user's message history into the prompt for command history navigation. First message is the oldest message. - /// When history is supported by the prompt. - fn load_user_message_history(&mut self, _messages: Vec) {} - fn goose_ready(&self) { - println!("\n"); - println!("Goose is running! Enter your instructions, or try asking what goose can do."); - println!("\n"); - } -} - -pub struct Input { - pub input_type: InputType, - pub content: Option, // Optional content as sometimes the user may be issuing a command eg. (Exit) -} - -pub enum InputType { - AskAgain, // Ask the user for input again. Control flow command. - Message, // User sent a message - Exit, // User wants to exit the session -} - -pub enum Theme { - Light, - Dark, - Ansi, // Use terminal's ANSI/base16 colors directly. -} diff --git a/crates/goose-cli/src/prompt/renderer.rs b/crates/goose-cli/src/prompt/renderer.rs deleted file mode 100644 index 9d569fa91..000000000 --- a/crates/goose-cli/src/prompt/renderer.rs +++ /dev/null @@ -1,408 +0,0 @@ -use std::collections::HashMap; -use std::io::{self, Write}; -use std::path::PathBuf; - -use bat::WrappingMode; -use console::style; -use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; -use mcp_core::role::Role; -use mcp_core::{content::Content, tool::ToolCall}; -use serde_json::Value; - -use super::Theme; - -const MAX_STRING_LENGTH: usize = 40; -const MAX_PATH_LENGTH: usize = 60; -const INDENT: &str = " "; - -/// Shortens a path string by abbreviating directory names while keeping the last two components intact. -/// If the path starts with the user's home directory, it will be replaced with ~. -/// -/// # Examples -/// ``` -/// let path = "/Users/alice/Development/very/long/path/to/file.txt"; -/// assert_eq!( -/// shorten_path(path), -/// "~/D/v/l/p/to/file.txt" -/// ); -/// ``` -fn shorten_path(path: &str) -> String { - let path = PathBuf::from(path); - - // First try to convert to ~ if it's in home directory - let home = etcetera::home_dir(); - let path_str = if let Ok(home) = home { - if let Ok(stripped) = path.strip_prefix(home) { - format!("~/{}", stripped.display()) - } else { - path.display().to_string() - } - } else { - path.display().to_string() - }; - - // If path is already short enough, return as is - if path_str.len() <= MAX_PATH_LENGTH { - return path_str; - } - - let parts: Vec<_> = path_str.split('/').collect(); - - // If we have 3 or fewer parts, return as is - if parts.len() <= 3 { - return path_str; - } - - // Keep the first component (empty string before root / or ~) and last two components intact - let mut shortened = vec![parts[0].to_string()]; - - // Shorten middle components to their first letter - for component in &parts[1..parts.len() - 2] { - if !component.is_empty() { - shortened.push(component.chars().next().unwrap_or('?').to_string()); - } - } - - // Add the last two components - shortened.push(parts[parts.len() - 2].to_string()); - shortened.push(parts[parts.len() - 1].to_string()); - - shortened.join("/") -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_shorten_path() { - // Test a long path without home directory - let long_path = "/Users/test/Development/this/is/a/very/long/nested/deeply/example.txt"; - let shortened = shorten_path(long_path); - assert!( - shortened.len() < long_path.len(), - "Shortened path '{}' should be shorter than original '{}'", - shortened, - long_path - ); - assert!( - shortened.ends_with("deeply/example.txt"), - "Shortened path '{}' should end with 'deeply/example.txt'", - shortened - ); - - // Test a short path (shouldn't be modified) - assert_eq!(shorten_path("/usr/local/bin"), "/usr/local/bin"); - - // Test path with less than 3 components - assert_eq!(shorten_path("/usr/local"), "/usr/local"); - } -} - -/// Implement the ToolRenderer trait for each tool that you want to render in the prompt. -pub trait ToolRenderer: ToolRendererClone { - fn tool_name(&self) -> String; - fn request(&self, tool_request: &ToolRequest, theme: &str); - fn response(&self, tool_response: &ToolResponse, theme: &str); -} - -// Helper trait for cloning boxed ToolRenderer objects -pub trait ToolRendererClone { - fn clone_box(&self) -> Box; -} - -// Implement the helper trait for any type that implements ToolRenderer and Clone -impl ToolRendererClone for T -where - T: 'static + ToolRenderer + Clone, -{ - fn clone_box(&self) -> Box { - Box::new(self.clone()) - } -} - -// Make Box clonable -impl Clone for Box { - fn clone(&self) -> Box { - self.clone_box() - } -} - -#[derive(Clone)] -pub struct DefaultRenderer; - -impl ToolRenderer for DefaultRenderer { - fn tool_name(&self) -> String { - "default".to_string() - } - - fn request(&self, tool_request: &ToolRequest, theme: &str) { - match &tool_request.tool_call { - Ok(call) => { - default_print_request_header(call); - - // Format and print the parameters - print_params(&call.arguments, 0); - print_newline(); - } - Err(e) => print_markdown(&e.to_string(), theme), - } - } - - fn response(&self, tool_response: &ToolResponse, theme: &str) { - default_response_renderer(tool_response, theme); - } -} - -#[derive(Clone)] -pub struct TextEditorRenderer; - -impl ToolRenderer for TextEditorRenderer { - fn tool_name(&self) -> String { - "developer__text_editor".to_string() - } - - fn request(&self, tool_request: &ToolRequest, theme: &str) { - match &tool_request.tool_call { - Ok(call) => { - default_print_request_header(call); - - // Print path first with special formatting - if let Some(Value::String(path)) = call.arguments.get("path") { - println!( - "{}: {}", - style("path").dim(), - style(shorten_path(path)).green() - ); - } - - // Print other arguments normally, excluding path - if let Some(args) = call.arguments.as_object() { - let mut other_args = serde_json::Map::new(); - for (k, v) in args { - if k != "path" { - other_args.insert(k.clone(), v.clone()); - } - } - print_params(&Value::Object(other_args), 0); - } - print_newline(); - } - Err(e) => print_markdown(&e.to_string(), theme), - } - } - - fn response(&self, tool_response: &ToolResponse, theme: &str) { - default_response_renderer(tool_response, theme); - } -} - -#[derive(Clone)] -pub struct BashDeveloperExtensionRenderer; - -impl ToolRenderer for BashDeveloperExtensionRenderer { - fn tool_name(&self) -> String { - "developer__shell".to_string() - } - - fn request(&self, tool_request: &ToolRequest, theme: &str) { - match &tool_request.tool_call { - Ok(call) => { - default_print_request_header(call); - - match call.arguments.get("command") { - Some(Value::String(s)) => { - println!("{}: {}", style("command").dim(), style(s).green()); - } - _ => print_params(&call.arguments, 0), - } - print_newline(); - } - Err(e) => print_markdown(&e.to_string(), theme), - } - } - - fn response(&self, tool_response: &ToolResponse, theme: &str) { - default_response_renderer(tool_response, theme); - } -} - -pub fn render(message: &Message, theme: &Theme, renderers: HashMap>) { - let theme = match theme { - Theme::Light => "GitHub", - Theme::Dark => "zenburn", - Theme::Ansi => "base16", - }; - - let mut last_tool_name: &str = "default"; - for message_content in &message.content { - match message_content { - MessageContent::Text(text) => print_markdown(&text.text, theme), - MessageContent::ToolRequest(tool_request) => match &tool_request.tool_call { - Ok(call) => { - last_tool_name = &call.name; - renderers - .get(&call.name) - .or_else(|| renderers.get("default")) - .unwrap() - .request(tool_request, theme); - } - Err(_) => renderers - .get("default") - .unwrap() - .request(tool_request, theme), - }, - MessageContent::ToolResponse(tool_response) => renderers - .get(last_tool_name) - .or_else(|| renderers.get("default")) - .unwrap() - .response(tool_response, theme), - MessageContent::Image(image) => { - println!("Image: [data: {}, type: {}]", image.data, image.mime_type); - } - } - } - - print_newline(); - io::stdout().flush().expect("Failed to flush stdout"); -} - -pub fn default_response_renderer(tool_response: &ToolResponse, theme: &str) { - match &tool_response.tool_result { - Ok(contents) => { - for content in contents { - if content - .audience() - .is_some_and(|audience| !audience.contains(&Role::User)) - { - continue; - } - - let min_priority = std::env::var("GOOSE_CLI_MIN_PRIORITY") - .ok() - .and_then(|val| val.parse::().ok()) - .unwrap_or(0.0); - - // if priority is not set OR less than or equal to min_priority, do not render - if content - .priority() - .is_some_and(|priority| priority <= min_priority) - || content.priority().is_none() - { - continue; - } - - if let Content::Text(text) = content { - print_markdown(&text.text, theme); - } - } - } - Err(e) => print_markdown(&e.to_string(), theme), - } -} - -pub fn default_print_request_header(call: &ToolCall) { - // Print the tool name with an emoji - - // use rsplit to handle any prefixed tools with more underscores - // unicode gets converted to underscores during sanitization - let parts: Vec<_> = call.name.rsplit("__").collect(); - - let tool_header = format!( - "─── {} | {} ──────────────────────────", - style(parts.first().unwrap_or(&"unknown")), - style( - parts - .split_first() - // client name is the rest of the split, reversed - // reverse the iterator and re-join on __ - .map(|(_, s)| s.iter().rev().copied().collect::>().join("__")) - .unwrap_or_else(|| "unknown".to_string()) - ) - .magenta() - .dim(), - ); - print_newline(); - println!("{}", tool_header); -} - -pub fn print_markdown(content: &str, theme: &str) { - bat::PrettyPrinter::new() - .input(bat::Input::from_bytes(content.as_bytes())) - .theme(theme) - .language("Markdown") - .wrapping_mode(WrappingMode::Character) - .print() - .unwrap(); -} - -/// Format and print parameters recursively with proper indentation and colors -pub fn print_params(value: &Value, depth: usize) { - let indent = INDENT.repeat(depth); - - match value { - Value::Object(map) => { - for (key, val) in map { - match val { - Value::Object(_) => { - println!("{}{}:", indent, style(key).dim()); - print_params(val, depth + 1); - } - Value::Array(arr) => { - println!("{}{}:", indent, style(key).dim()); - for item in arr.iter() { - println!("{}{}- ", indent, INDENT); - print_params(item, depth + 2); - } - } - Value::String(s) => { - if s.len() > MAX_STRING_LENGTH { - println!("{}{}: {}", indent, style(key).dim(), style("...").dim()); - } else { - println!("{}{}: {}", indent, style(key).dim(), style(s).green()); - } - } - Value::Number(n) => { - println!("{}{}: {}", indent, style(key).dim(), style(n).blue()); - } - Value::Bool(b) => { - println!("{}{}: {}", indent, style(key).dim(), style(b).blue()); - } - Value::Null => { - println!("{}{}: {}", indent, style(key).dim(), style("null").dim()); - } - } - } - } - Value::Array(arr) => { - for (i, item) in arr.iter().enumerate() { - println!("{}{}.", indent, i + 1); - print_params(item, depth + 1); - } - } - Value::String(s) => { - if s.len() > MAX_STRING_LENGTH { - println!( - "{}{}", - indent, - style(format!("[REDACTED: {} chars]", s.len())).yellow() - ); - } else { - println!("{}{}", indent, style(s).green()); - } - } - Value::Number(n) => { - println!("{}{}", indent, style(n).yellow()); - } - Value::Bool(b) => { - println!("{}{}", indent, style(b).yellow()); - } - Value::Null => { - println!("{}{}", indent, style("null").dim()); - } - } -} - -pub fn print_newline() { - println!(); -} diff --git a/crates/goose-cli/src/prompt/rustyline.rs b/crates/goose-cli/src/prompt/rustyline.rs deleted file mode 100644 index 19e6c9713..000000000 --- a/crates/goose-cli/src/prompt/rustyline.rs +++ /dev/null @@ -1,176 +0,0 @@ -use std::collections::HashMap; - -use super::{ - renderer::{ - render, BashDeveloperExtensionRenderer, DefaultRenderer, TextEditorRenderer, ToolRenderer, - }, - thinking::get_random_thinking_message, - Input, InputType, Prompt, Theme, -}; - -use anyhow::Result; -use cliclack::spinner; -use console::style; -use goose::message::Message; -use mcp_core::Role; -use rustyline::{DefaultEditor, EventHandler, KeyCode, KeyEvent, Modifiers}; - -fn get_prompt() -> String { - format!("{} ", style("( O)>").cyan().bold()) -} - -pub struct RustylinePrompt { - spinner: cliclack::ProgressBar, - theme: Theme, - renderers: HashMap>, - editor: DefaultEditor, -} - -impl RustylinePrompt { - pub fn new() -> Self { - let mut renderers: HashMap> = HashMap::new(); - let default_renderer = DefaultRenderer; - renderers.insert(default_renderer.tool_name(), Box::new(default_renderer)); - - let bash_dev_extension_renderer = BashDeveloperExtensionRenderer; - renderers.insert( - bash_dev_extension_renderer.tool_name(), - Box::new(bash_dev_extension_renderer), - ); - - let text_editor_renderer = TextEditorRenderer; - renderers.insert( - text_editor_renderer.tool_name(), - Box::new(text_editor_renderer), - ); - - let mut editor = DefaultEditor::new().expect("Failed to create editor"); - editor.bind_sequence( - KeyEvent(KeyCode::Char('j'), Modifiers::CTRL), - EventHandler::Simple(rustyline::Cmd::Newline), - ); - - RustylinePrompt { - spinner: spinner(), - theme: std::env::var("GOOSE_CLI_THEME") - .ok() - .map(|val| { - if val.eq_ignore_ascii_case("light") { - Theme::Light - } else if val.eq_ignore_ascii_case("ansi") { - Theme::Ansi - } else { - Theme::Dark - } - }) - .unwrap_or(Theme::Dark), - renderers, - editor, - } - } -} - -impl Prompt for RustylinePrompt { - fn render(&mut self, message: Box) { - render(&message, &self.theme, self.renderers.clone()); - } - - fn show_busy(&mut self) { - self.spinner = spinner(); - self.spinner - .start(format!("{}...", get_random_thinking_message())); - } - - fn hide_busy(&self) { - self.spinner.stop(""); - } - - fn get_input(&mut self) -> Result { - let input = self.editor.readline(&get_prompt()); - let mut message_text = match input { - Ok(text) => { - // Add valid input to history - if let Err(e) = self.editor.add_history_entry(text.as_str()) { - eprintln!("Failed to add to history: {}", e); - } - text - } - Err(e) => { - match e { - rustyline::error::ReadlineError::Interrupted => (), - _ => eprintln!("Input error: {}", e), - } - return Ok(Input { - input_type: InputType::Exit, - content: None, - }); - } - }; - message_text = message_text.trim().to_string(); - - if message_text.eq_ignore_ascii_case("/exit") - || message_text.eq_ignore_ascii_case("/quit") - || message_text.eq_ignore_ascii_case("exit") - || message_text.eq_ignore_ascii_case("quit") - { - Ok(Input { - input_type: InputType::Exit, - content: None, - }) - } else if message_text.eq_ignore_ascii_case("/t") { - self.theme = match self.theme { - Theme::Light => { - println!("Switching to Dark theme"); - Theme::Dark - } - Theme::Dark => { - println!("Switching to Ansi theme"); - Theme::Ansi - } - Theme::Ansi => { - println!("Switching to Light theme"); - Theme::Light - } - }; - return Ok(Input { - input_type: InputType::AskAgain, - content: None, - }); - } else if message_text.eq_ignore_ascii_case("/?") - || message_text.eq_ignore_ascii_case("/help") - { - println!("Commands:"); - println!("/exit - Exit the session"); - println!("/t - Toggle Light/Dark/Ansi theme"); - println!("/? | /help - Display this help message"); - println!("Ctrl+C - Interrupt goose (resets the interaction to before the interrupted user request)"); - println!("Ctrl+j - Adds a newline"); - println!("Use Up/Down arrow keys to navigate through command history"); - return Ok(Input { - input_type: InputType::AskAgain, - content: None, - }); - } else { - return Ok(Input { - input_type: InputType::Message, - content: Some(message_text.to_string()), - }); - } - } - - fn load_user_message_history(&mut self, messages: Vec) { - for message in messages.into_iter().filter(|m| m.role == Role::User) { - for content in message.content { - if let Some(text) = content.as_text() { - if let Err(e) = self.editor.add_history_entry(text) { - eprintln!("Failed to add to history: {}", e); - } - } - } - } - } - - fn close(&self) { - // No cleanup required - } -} diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs deleted file mode 100644 index c48e67be3..000000000 --- a/crates/goose-cli/src/session.rs +++ /dev/null @@ -1,362 +0,0 @@ -use anyhow::Result; -use core::panic; -use etcetera::{choose_app_strategy, AppStrategy}; -use futures::StreamExt; -use std::fs::{self, File}; -use std::io::{self, BufRead, Write}; -use std::path::PathBuf; - -use crate::log_usage::log_usage; -use crate::prompt::{InputType, Prompt}; -use goose::agents::Agent; -use goose::message::{Message, MessageContent}; -use mcp_core::handler::ToolError; -use mcp_core::role::Role; - -// File management functions -pub fn ensure_session_dir() -> Result { - // choose_app_strategy().data_dir() - // - macOS/Linux: ~/.local/share/goose/sessions/ - // - Windows: ~\AppData\Roaming\Block\goose\data\sessions - let config_dir = choose_app_strategy(crate::APP_STRATEGY.clone()) - .expect("goose requires a home dir") - .in_data_dir("sessions"); - - if !config_dir.exists() { - fs::create_dir_all(&config_dir)?; - } - - Ok(config_dir) -} - -/// LEGACY NOTE: remove this once old paths are no longer needed. -pub fn legacy_session_dir() -> Option { - // legacy path was in the config dir ~/.config/goose/sessions/ - // ignore errors if we can't re-create the legacy session dir - choose_app_strategy(crate::APP_STRATEGY.clone()) - .map(|strategy| strategy.in_config_dir("sessions")) - .ok() -} - -pub fn get_most_recent_session() -> Result { - let session_dir = ensure_session_dir()?; - let mut entries = fs::read_dir(&session_dir)? - .filter_map(|entry| entry.ok()) - .filter(|entry| entry.path().extension().is_some_and(|ext| ext == "jsonl")) - .collect::>(); - - // LEGACY NOTE: remove this once old paths are no longer needed. - if entries.is_empty() { - if let Some(old_dir) = legacy_session_dir() { - // okay to return the error via ?, since that means we have no sessions in the - // new location, and this old location doesn't exist, so a new session will be created - let old_entries = fs::read_dir(&old_dir)? - .filter_map(|entry| entry.ok()) - .filter(|entry| entry.path().extension().is_some_and(|ext| ext == "jsonl")) - .collect::>(); - entries.extend(old_entries); - } - } - - if entries.is_empty() { - return Err(anyhow::anyhow!("No session files found")); - } - - // Sort by modification time, most recent first - entries.sort_by(|a, b| { - b.metadata() - .and_then(|m| m.modified()) - .unwrap_or(std::time::SystemTime::UNIX_EPOCH) - .cmp( - &a.metadata() - .and_then(|m| m.modified()) - .unwrap_or(std::time::SystemTime::UNIX_EPOCH), - ) - }); - - Ok(entries[0].path()) -} - -pub fn readable_session_file(session_file: &PathBuf) -> Result { - match fs::OpenOptions::new() - .read(true) - .write(true) - .create(true) - .truncate(false) - .open(session_file) - { - Ok(file) => Ok(file), - Err(e) => Err(anyhow::anyhow!("Failed to open session file: {}", e)), - } -} - -pub fn persist_messages(session_file: &PathBuf, messages: &[Message]) -> Result<()> { - let file = fs::File::create(session_file)?; // Create or truncate the file - persist_messages_internal(file, messages) -} - -fn persist_messages_internal(session_file: File, messages: &[Message]) -> Result<()> { - let mut writer = std::io::BufWriter::new(session_file); - - for message in messages { - serde_json::to_writer(&mut writer, &message)?; - writeln!(writer)?; - } - - writer.flush()?; - Ok(()) -} - -pub fn deserialize_messages(file: File) -> Result> { - let reader = io::BufReader::new(file); - let mut messages = Vec::new(); - - for line in reader.lines() { - messages.push(serde_json::from_str::(&line?)?); - } - - Ok(messages) -} - -// Session management -pub struct Session<'a> { - agent: Box, - prompt: Box, - session_file: PathBuf, - messages: Vec, -} - -#[allow(dead_code)] -impl<'a> Session<'a> { - pub fn new( - agent: Box, - mut prompt: Box, - session_file: PathBuf, - ) -> Self { - let messages = match readable_session_file(&session_file) { - Ok(file) => deserialize_messages(file).unwrap_or_else(|e| { - eprintln!( - "Failed to read messages from session file. Starting fresh.\n{}", - e - ); - Vec::::new() - }), - Err(e) => { - eprintln!("Failed to load session file. Starting fresh.\n{}", e); - Vec::::new() - } - }; - - prompt.load_user_message_history(messages.clone()); - - Session { - agent, - prompt, - session_file, - messages, - } - } - - pub async fn start(&mut self) -> Result<(), Box> { - self.prompt.goose_ready(); - - loop { - let input = self.prompt.get_input().unwrap(); - match input.input_type { - InputType::Message => { - if let Some(content) = &input.content { - if content.is_empty() { - continue; - } - self.messages.push(Message::user().with_text(content)); - persist_messages(&self.session_file, &self.messages)?; - } - } - InputType::Exit => break, - InputType::AskAgain => continue, - } - - self.prompt.show_busy(); - self.agent_process_messages().await; - self.prompt.hide_busy(); - } - self.close_session().await; - Ok(()) - } - - pub async fn headless_start( - &mut self, - initial_message: String, - ) -> Result<(), Box> { - self.messages - .push(Message::user().with_text(initial_message.as_str())); - persist_messages(&self.session_file, &self.messages)?; - - self.agent_process_messages().await; - - self.close_session().await; - Ok(()) - } - - async fn agent_process_messages(&mut self) { - let mut stream = match self.agent.reply(&self.messages).await { - Ok(stream) => stream, - Err(e) => { - eprintln!("Error starting reply stream: {}", e); - return; - } - }; - loop { - tokio::select! { - response = stream.next() => { - match response { - Some(Ok(message)) => { - self.messages.push(message.clone()); - persist_messages(&self.session_file, &self.messages).unwrap_or_else(|e| eprintln!("Failed to persist messages: {}", e)); - self.prompt.hide_busy(); - self.prompt.render(Box::new(message.clone())); - self.prompt.show_busy(); - } - Some(Err(e)) => { - eprintln!("Error: {}", e); - drop(stream); - self.rewind_messages(); - self.prompt.render(raw_message(r#" -The error above was an exception we were not able to handle.\n\n -These errors are often related to connection or authentication\n -We've removed the conversation up to the most recent user message - - depending on the error you may be able to continue"#)); - break; - } - None => break, - } - } - _ = tokio::signal::ctrl_c() => { - // Kill any running processes when the client disconnects - // TODO is this used? I suspect post MCP this is on the server instead - // goose::process_store::kill_processes(); - drop(stream); - self.handle_interrupted_messages(); - break; - } - } - } - } - - /// Rewind the messages to before the last user message (they have cancelled it). - fn rewind_messages(&mut self) { - if self.messages.is_empty() { - return; - } - - // Remove messages until we find the last user 'Text' message (not a tool response). - while let Some(message) = self.messages.last() { - if message.role == Role::User - && message - .content - .iter() - .any(|c| matches!(c, MessageContent::Text(_))) - { - break; - } - self.messages.pop(); - } - - // Remove the last user text message we found. - if !self.messages.is_empty() { - self.messages.pop(); - } - } - - fn handle_interrupted_messages(&mut self) { - // First, get any tool requests from the last message if it exists - let tool_requests = self - .messages - .last() - .filter(|msg| msg.role == Role::Assistant) - .map_or(Vec::new(), |msg| { - msg.content - .iter() - .filter_map(|content| { - if let MessageContent::ToolRequest(req) = content { - Some((req.id.clone(), req.tool_call.clone())) - } else { - None - } - }) - .collect() - }); - - if !tool_requests.is_empty() { - // Interrupted during a tool request - // Create tool responses for all interrupted tool requests - let mut response_message = Message::user(); - let last_tool_name = tool_requests - .last() - .and_then(|(_, tool_call)| tool_call.as_ref().ok().map(|tool| tool.name.clone())) - .unwrap_or_else(|| "tool".to_string()); - - for (req_id, _) in &tool_requests { - response_message.content.push(MessageContent::tool_response( - req_id.clone(), - Err(ToolError::ExecutionError( - "Interrupted by the user to make a correction".to_string(), - )), - )); - } - self.messages.push(response_message); - - let prompt_response = &format!( - "We interrupted the existing call to {}. How would you like to proceed?", - last_tool_name - ); - self.messages - .push(Message::assistant().with_text(prompt_response)); - self.prompt.render(raw_message(prompt_response)); - } else { - // An interruption occurred outside of a tool request-response. - if let Some(last_msg) = self.messages.last() { - if last_msg.role == Role::User { - match last_msg.content.first() { - Some(MessageContent::ToolResponse(_)) => { - // Interruption occurred after a tool had completed but not assistant reply - let prompt_response = "We interrupted the existing calls to tools. How would you like to proceed?"; - self.messages - .push(Message::assistant().with_text(prompt_response)); - self.prompt.render(raw_message(prompt_response)); - } - Some(_) => { - // A real users message - self.messages.pop(); - let prompt_response = "We interrupted before the model replied and removed the last message."; - self.prompt.render(raw_message(prompt_response)); - } - None => panic!("No content in last message"), - } - } - } - } - } - - async fn close_session(&mut self) { - let usage = self.agent.usage().await; - log_usage(self.session_file.to_string_lossy().to_string(), usage); - - self.prompt.render(raw_message( - format!( - "Closing session. Recorded to {}\n", - self.session_file.display() - ) - .as_str(), - )); - self.prompt.close(); - } - - pub fn session_file(&self) -> PathBuf { - self.session_file.clone() - } -} - -fn raw_message(content: &str) -> Box { - Box::new(Message::assistant().with_text(content)) -} diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs new file mode 100644 index 000000000..1e61f1ce4 --- /dev/null +++ b/crates/goose-cli/src/session/builder.rs @@ -0,0 +1,140 @@ +use console::style; +use goose::agents::extension::ExtensionError; +use goose::agents::AgentFactory; +use goose::config::{Config, ExtensionManager}; +use mcp_client::transport::Error as McpClientError; +use std::path::PathBuf; +use std::process; + +use super::output; +use super::storage; +use super::Session; + +pub async fn build_session( + name: Option, + resume: bool, + extensions: Vec, + builtins: Vec, +) -> Session { + // Load config and get provider/model + let config = Config::global(); + + let provider_name: String = config + .get("GOOSE_PROVIDER") + .expect("No provider configured. Run 'goose configure' first"); + let session_dir = storage::ensure_session_dir().expect("Failed to create session directory"); + + let model: String = config + .get("GOOSE_MODEL") + .expect("No model configured. Run 'goose configure' first"); + let model_config = goose::model::ModelConfig::new(model.clone()); + let provider = + goose::providers::create(&provider_name, model_config).expect("Failed to create provider"); + + // Create the agent + let agent_version: Option = config.get("GOOSE_AGENT").ok(); + let mut agent = match agent_version { + Some(version) => AgentFactory::create(&version, provider), + None => AgentFactory::create(AgentFactory::default_version(), provider), + } + .expect("Failed to create agent"); + + // Setup extensions for the agent + for extension in ExtensionManager::get_all().expect("should load extensions") { + if extension.enabled { + let config = extension.config.clone(); + agent + .add_extension(config.clone()) + .await + .unwrap_or_else(|e| { + let err = match e { + ExtensionError::Transport(McpClientError::StdioProcessError(inner)) => { + inner + } + _ => e.to_string(), + }; + println!("Failed to start extension: {}, {:?}", config.name(), err); + println!( + "Please check extension configuration for {}.", + config.name() + ); + process::exit(1); + }); + } + } + + // Handle session file resolution and resuming + let session_file = if resume { + if let Some(ref session_name) = name { + // Try to resume specific named session + let session_file = session_dir.join(format!("{}.jsonl", session_name)); + if !session_file.exists() { + output::render_error(&format!( + "Cannot resume session {} - no such session exists", + style(session_name).cyan() + )); + process::exit(1); + } + session_file + } else { + // Try to resume most recent session + match storage::get_most_recent_session() { + Ok(file) => file, + Err(_) => { + output::render_error("Cannot resume - no previous sessions found"); + process::exit(1); + } + } + } + } else { + // Create new session with provided or generated name + let session_name = name.unwrap_or_else(generate_session_name); + create_new_session_file(&session_dir, &session_name) + }; + + // Create new session + let mut session = Session::new(agent, session_file.clone()); + + // Add extensions if provided + for extension_str in extensions { + if let Err(e) = session.add_extension(extension_str).await { + eprintln!("Failed to start extension: {}", e); + process::exit(1); + } + } + + // Add builtin extensions + for builtin in builtins { + if let Err(e) = session.add_builtin(builtin).await { + eprintln!("Failed to start builtin extension: {}", e); + process::exit(1); + } + } + + // Add CLI-specific system prompt extension + session + .agent + .extend_system_prompt(super::prompt::get_cli_prompt()) + .await; + + output::display_session_info(resume, &provider_name, &model, &session_file); + session +} + +fn generate_session_name() -> String { + use rand::{distributions::Alphanumeric, Rng}; + rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(8) + .map(char::from) + .collect() +} + +fn create_new_session_file(session_dir: &std::path::Path, name: &str) -> PathBuf { + let session_file = session_dir.join(format!("{}.jsonl", name)); + if session_file.exists() { + eprintln!("Session '{}' already exists", name); + process::exit(1); + } + session_file +} diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs new file mode 100644 index 000000000..7cfa94d35 --- /dev/null +++ b/crates/goose-cli/src/session/input.rs @@ -0,0 +1,152 @@ +use anyhow::Result; +use rustyline::Editor; + +#[derive(Debug)] +pub enum InputResult { + Message(String), + Exit, + AddExtension(String), + AddBuiltin(String), + ToggleTheme, + Retry, +} + +pub fn get_input( + editor: &mut Editor<(), rustyline::history::DefaultHistory>, +) -> Result { + // Ensure Ctrl-J binding is set for newlines + editor.bind_sequence( + rustyline::KeyEvent(rustyline::KeyCode::Char('j'), rustyline::Modifiers::CTRL), + rustyline::EventHandler::Simple(rustyline::Cmd::Newline), + ); + + let prompt = format!("{} ", console::style("( O)>").cyan().bold()); + let input = match editor.readline(&prompt) { + Ok(text) => text, + Err(e) => match e { + rustyline::error::ReadlineError::Interrupted => return Ok(InputResult::Exit), + _ => return Err(e.into()), + }, + }; + + // Add valid input to history + if !input.trim().is_empty() { + editor.add_history_entry(input.as_str())?; + } + + // Handle non-slash commands first + if !input.starts_with('/') { + if input.eq_ignore_ascii_case("exit") || input.eq_ignore_ascii_case("quit") { + return Ok(InputResult::Exit); + } + return Ok(InputResult::Message(input.trim().to_string())); + } + + // Handle slash commands + match handle_slash_command(&input) { + Some(result) => Ok(result), + None => Ok(InputResult::Message(input.trim().to_string())), + } +} + +fn handle_slash_command(input: &str) -> Option { + let input = input.trim(); + + match input { + "/exit" | "/quit" => Some(InputResult::Exit), + "/?" | "/help" => { + print_help(); + Some(InputResult::Retry) + } + "/t" => Some(InputResult::ToggleTheme), + s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())), + s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), + _ => None, + } +} + +fn print_help() { + println!( + "Available commands: +/exit or /quit - Exit the session +/t - Toggle Light/Dark/Ansi theme +/extension - Add a stdio extension (format: ENV1=val1 command args...) +/builtin - Add builtin extensions by name (comma-separated) +/? or /help - Display this help message + +Navigation: +Ctrl+C - Interrupt goose (resets the interaction to before the interrupted user request) +Ctrl+J - Add a newline +Up/Down arrows - Navigate through command history" + ); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_handle_slash_command() { + // Test exit commands + assert!(matches!( + handle_slash_command("/exit"), + Some(InputResult::Exit) + )); + assert!(matches!( + handle_slash_command("/quit"), + Some(InputResult::Exit) + )); + + // Test help commands + assert!(matches!( + handle_slash_command("/help"), + Some(InputResult::Retry) + )); + assert!(matches!( + handle_slash_command("/?"), + Some(InputResult::Retry) + )); + + // Test theme toggle + assert!(matches!( + handle_slash_command("/t"), + Some(InputResult::ToggleTheme) + )); + + // Test extension command + if let Some(InputResult::AddExtension(cmd)) = handle_slash_command("/extension foo bar") { + assert_eq!(cmd, "foo bar"); + } else { + panic!("Expected AddExtension"); + } + + // Test builtin command + if let Some(InputResult::AddBuiltin(names)) = handle_slash_command("/builtin dev,git") { + assert_eq!(names, "dev,git"); + } else { + panic!("Expected AddBuiltin"); + } + + // Test unknown commands + assert!(handle_slash_command("/unknown").is_none()); + } + + // Test whitespace handling + #[test] + fn test_whitespace_handling() { + // Leading/trailing whitespace in extension command + if let Some(InputResult::AddExtension(cmd)) = handle_slash_command(" /extension foo bar ") + { + assert_eq!(cmd, "foo bar"); + } else { + panic!("Expected AddExtension"); + } + + // Leading/trailing whitespace in builtin command + if let Some(InputResult::AddBuiltin(names)) = handle_slash_command(" /builtin dev,git ") { + assert_eq!(names, "dev,git"); + } else { + panic!("Expected AddBuiltin"); + } + } +} diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs new file mode 100644 index 000000000..2a65ef058 --- /dev/null +++ b/crates/goose-cli/src/session/mod.rs @@ -0,0 +1,302 @@ +mod builder; +mod input; +mod output; +mod prompt; +mod storage; +mod thinking; + +pub use builder::build_session; + +use anyhow::Result; +use goose::agents::extension::{Envs, ExtensionConfig}; +use goose::agents::Agent; +use goose::message::{Message, MessageContent}; +use mcp_core::handler::ToolError; +use rand::{distributions::Alphanumeric, Rng}; +use std::path::PathBuf; +use tokio; + +use crate::log_usage::log_usage; + +pub struct Session { + agent: Box, + messages: Vec, + session_file: PathBuf, +} + +impl Session { + pub fn new(agent: Box, session_file: PathBuf) -> Self { + let messages = match storage::read_messages(&session_file) { + Ok(msgs) => msgs, + Err(e) => { + eprintln!("Warning: Failed to load message history: {}", e); + Vec::new() + } + }; + + Session { + agent, + messages, + session_file, + } + } + + /// Add a stdio extension to the session + /// + /// # Arguments + /// * `extension_command` - Full command string including environment variables + /// Format: "ENV1=val1 ENV2=val2 command args..." + pub async fn add_extension(&mut self, extension_command: String) -> Result<()> { + let mut parts: Vec<&str> = extension_command.split_whitespace().collect(); + let mut envs = std::collections::HashMap::new(); + + // Parse environment variables (format: KEY=value) + while let Some(part) = parts.first() { + if !part.contains('=') { + break; + } + let env_part = parts.remove(0); + let (key, value) = env_part.split_once('=').unwrap(); + envs.insert(key.to_string(), value.to_string()); + } + + if parts.is_empty() { + return Err(anyhow::anyhow!("No command provided in extension string")); + } + + let cmd = parts.remove(0).to_string(); + // Generate a random name for the ephemeral extension + let name: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(8) + .map(char::from) + .collect(); + + let config = ExtensionConfig::Stdio { + name, + cmd, + args: parts.iter().map(|s| s.to_string()).collect(), + envs: Envs::new(envs), + }; + + self.agent + .add_extension(config) + .await + .map_err(|e| anyhow::anyhow!("Failed to start extension: {}", e)) + } + + /// Add a builtin extension to the session + /// + /// # Arguments + /// * `builtin_name` - Name of the builtin extension(s), comma separated + pub async fn add_builtin(&mut self, builtin_name: String) -> Result<()> { + for name in builtin_name.split(',') { + let config = ExtensionConfig::Builtin { + name: name.trim().to_string(), + }; + self.agent + .add_extension(config) + .await + .map_err(|e| anyhow::anyhow!("Failed to start builtin extension: {}", e))?; + } + Ok(()) + } + + pub async fn start(&mut self) -> Result<()> { + let mut editor = rustyline::Editor::<(), rustyline::history::DefaultHistory>::new()?; + + // Load history from messages + for msg in self + .messages + .iter() + .filter(|m| m.role == mcp_core::role::Role::User) + { + for content in msg.content.iter() { + if let Some(text) = content.as_text() { + if let Err(e) = editor.add_history_entry(text) { + eprintln!("Warning: Failed to add history entry: {}", e); + } + } + } + } + + output::display_greeting(); + loop { + match input::get_input(&mut editor)? { + input::InputResult::Message(content) => { + self.messages.push(Message::user().with_text(&content)); + storage::persist_messages(&self.session_file, &self.messages)?; + + output::show_thinking(); + self.process_agent_response().await?; + output::hide_thinking(); + } + input::InputResult::Exit => break, + input::InputResult::AddExtension(cmd) => { + match self.add_extension(cmd.clone()).await { + Ok(_) => output::render_extension_success(&cmd), + Err(e) => output::render_extension_error(&cmd, &e.to_string()), + } + } + input::InputResult::AddBuiltin(names) => { + match self.add_builtin(names.clone()).await { + Ok(_) => output::render_builtin_success(&names), + Err(e) => output::render_builtin_error(&names, &e.to_string()), + } + } + input::InputResult::ToggleTheme => { + let current = output::get_theme(); + let new_theme = match current { + output::Theme::Light => { + println!("Switching to Dark theme"); + output::Theme::Dark + } + output::Theme::Dark => { + println!("Switching to Ansi theme"); + output::Theme::Ansi + } + output::Theme::Ansi => { + println!("Switching to Light theme"); + output::Theme::Light + } + }; + output::set_theme(new_theme); + continue; + } + input::InputResult::Retry => continue, + } + } + + // Log usage and cleanup + let usage = self.agent.usage().await; + log_usage(self.session_file.to_string_lossy().to_string(), usage); + println!( + "\nClosing session. Recorded to {}", + self.session_file.display() + ); + Ok(()) + } + + pub async fn headless_start(&mut self, initial_message: String) -> Result<()> { + self.messages + .push(Message::user().with_text(&initial_message)); + storage::persist_messages(&self.session_file, &self.messages)?; + self.process_agent_response().await?; + Ok(()) + } + + async fn process_agent_response(&mut self) -> Result<()> { + let mut stream = self.agent.reply(&self.messages).await?; + + use futures::StreamExt; + loop { + tokio::select! { + result = stream.next() => { + match result { + Some(Ok(message)) => { + self.messages.push(message.clone()); + storage::persist_messages(&self.session_file, &self.messages)?; + output::hide_thinking(); + output::render_message(&message); + output::show_thinking(); + } + Some(Err(e)) => { + eprintln!("Error: {}", e); + drop(stream); + self.handle_interrupted_messages(false); + output::render_error( + "The error above was an exception we were not able to handle.\n\ + These errors are often related to connection or authentication\n\ + We've removed the conversation up to the most recent user message\n\ + - depending on the error you may be able to continue", + ); + break; + } + None => break, + } + } + _ = tokio::signal::ctrl_c() => { + drop(stream); + self.handle_interrupted_messages(true); + break; + } + } + } + Ok(()) + } + + fn handle_interrupted_messages(&mut self, interrupt: bool) { + // First, get any tool requests from the last message if it exists + let tool_requests = self + .messages + .last() + .filter(|msg| msg.role == mcp_core::role::Role::Assistant) + .map_or(Vec::new(), |msg| { + msg.content + .iter() + .filter_map(|content| { + if let MessageContent::ToolRequest(req) = content { + Some((req.id.clone(), req.tool_call.clone())) + } else { + None + } + }) + .collect() + }); + + if !tool_requests.is_empty() { + // Interrupted during a tool request + // Create tool responses for all interrupted tool requests + let mut response_message = Message::user(); + let last_tool_name = tool_requests + .last() + .and_then(|(_, tool_call)| tool_call.as_ref().ok().map(|tool| tool.name.clone())) + .unwrap_or_else(|| "tool".to_string()); + + let notification = if interrupt { + "Interrupted by the user to make a correction".to_string() + } else { + "An uncaught error happened during tool use".to_string() + }; + for (req_id, _) in &tool_requests { + response_message.content.push(MessageContent::tool_response( + req_id.clone(), + Err(ToolError::ExecutionError(notification.clone())), + )); + } + self.messages.push(response_message); + + let prompt = format!( + "The existing call to {} was interrupted. How would you like to proceed?", + last_tool_name + ); + self.messages.push(Message::assistant().with_text(&prompt)); + output::render_message(&Message::assistant().with_text(&prompt)); + } else { + // An interruption occurred outside of a tool request-response. + if let Some(last_msg) = self.messages.last() { + if last_msg.role == mcp_core::role::Role::User { + match last_msg.content.first() { + Some(MessageContent::ToolResponse(_)) => { + // Interruption occurred after a tool had completed but not assistant reply + let prompt = "The tool calling loop was interrupted. How would you like to proceed?"; + self.messages.push(Message::assistant().with_text(prompt)); + output::render_message(&Message::assistant().with_text(prompt)); + } + Some(_) => { + // A real users message + self.messages.pop(); + let prompt = "Interrupted before the model replied and removed the last message."; + output::render_message(&Message::assistant().with_text(prompt)); + } + None => panic!("No content in last message"), + } + } + } + } + } + + pub fn session_file(&self) -> PathBuf { + self.session_file.clone() + } +} diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs new file mode 100644 index 000000000..0ac010efb --- /dev/null +++ b/crates/goose-cli/src/session/output.rs @@ -0,0 +1,457 @@ +use bat::WrappingMode; +use console::style; +use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; +use mcp_core::tool::ToolCall; +use serde_json::Value; +use std::cell::RefCell; +use std::path::Path; + +// Re-export theme for use in main +#[derive(Clone, Copy)] +pub enum Theme { + Light, + Dark, + Ansi, +} + +impl Theme { + fn as_str(&self) -> &'static str { + match self { + Theme::Light => "GitHub", + Theme::Dark => "zenburn", + Theme::Ansi => "base16", + } + } +} + +thread_local! { + static CURRENT_THEME: RefCell = RefCell::new( + std::env::var("GOOSE_CLI_THEME") + .ok() + .map(|val| { + if val.eq_ignore_ascii_case("light") { + Theme::Light + } else if val.eq_ignore_ascii_case("ansi") { + Theme::Ansi + } else { + Theme::Dark + } + }) + .unwrap_or(Theme::Dark) + ); +} + +pub fn set_theme(theme: Theme) { + CURRENT_THEME.with(|t| *t.borrow_mut() = theme); +} + +pub fn get_theme() -> Theme { + CURRENT_THEME.with(|t| *t.borrow()) +} + +// Simple wrapper around spinner to manage its state +#[derive(Default)] +pub struct ThinkingIndicator { + spinner: Option, +} + +impl ThinkingIndicator { + pub fn show(&mut self) { + let spinner = cliclack::spinner(); + spinner.start(format!( + "{}...", + super::thinking::get_random_thinking_message() + )); + self.spinner = Some(spinner); + } + + pub fn hide(&mut self) { + if let Some(spinner) = self.spinner.take() { + spinner.stop(""); + } + } +} + +// Global thinking indicator +thread_local! { + static THINKING: RefCell = RefCell::new(ThinkingIndicator::default()); +} + +pub fn show_thinking() { + THINKING.with(|t| t.borrow_mut().show()); +} + +pub fn hide_thinking() { + THINKING.with(|t| t.borrow_mut().hide()); +} + +pub fn render_message(message: &Message) { + let theme = get_theme(); + + for content in &message.content { + match content { + MessageContent::Text(text) => print_markdown(&text.text, theme), + MessageContent::ToolRequest(req) => render_tool_request(req, theme), + MessageContent::ToolResponse(resp) => render_tool_response(resp, theme), + MessageContent::Image(image) => { + println!("Image: [data: {}, type: {}]", image.data, image.mime_type); + } + } + } + println!(); +} + +fn render_tool_request(req: &ToolRequest, theme: Theme) { + match &req.tool_call { + Ok(call) => match call.name.as_str() { + "developer__text_editor" => render_text_editor_request(call), + "developer__shell" => render_shell_request(call), + _ => render_default_request(call), + }, + Err(e) => print_markdown(&e.to_string(), theme), + } +} + +fn render_tool_response(resp: &ToolResponse, theme: Theme) { + match &resp.tool_result { + Ok(contents) => { + for content in contents { + if let Some(audience) = content.audience() { + if !audience.contains(&mcp_core::role::Role::User) { + continue; + } + } + + let min_priority = std::env::var("GOOSE_CLI_MIN_PRIORITY") + .ok() + .and_then(|val| val.parse::().ok()) + .unwrap_or(0.0); + + if content + .priority() + .is_some_and(|priority| priority <= min_priority) + || content.priority().is_none() + { + continue; + } + + if let mcp_core::content::Content::Text(text) = content { + print_markdown(&text.text, theme); + } + } + } + Err(e) => print_markdown(&e.to_string(), theme), + } +} + +pub fn render_error(message: &str) { + println!("\n {} {}\n", style("error:").red().bold(), message); +} + +pub fn render_extension_success(name: &str) { + println!(); + println!( + " {} extension `{}`", + style("added").green(), + style(name).cyan(), + ); + println!(); +} + +pub fn render_extension_error(name: &str, error: &str) { + println!(); + println!( + " {} to add extension {}", + style("failed").red(), + style(name).red() + ); + println!(); + println!("{}", style(error).dim()); + println!(); +} + +pub fn render_builtin_success(names: &str) { + println!(); + println!( + " {} builtin{}: {}", + style("added").green(), + if names.contains(',') { "s" } else { "" }, + style(names).cyan() + ); + println!(); +} + +pub fn render_builtin_error(names: &str, error: &str) { + println!(); + println!( + " {} to add builtin{}: {}", + style("failed").red(), + if names.contains(',') { "s" } else { "" }, + style(names).red() + ); + println!(); + println!("{}", style(error).dim()); + println!(); +} + +fn render_text_editor_request(call: &ToolCall) { + print_tool_header(call); + + // Print path first with special formatting + if let Some(Value::String(path)) = call.arguments.get("path") { + println!( + "{}: {}", + style("path").dim(), + style(shorten_path(path)).green() + ); + } + + // Print other arguments normally, excluding path + if let Some(args) = call.arguments.as_object() { + let mut other_args = serde_json::Map::new(); + for (k, v) in args { + if k != "path" { + other_args.insert(k.clone(), v.clone()); + } + } + print_params(&Value::Object(other_args), 0); + } + println!(); +} + +fn render_shell_request(call: &ToolCall) { + print_tool_header(call); + + match call.arguments.get("command") { + Some(Value::String(s)) => { + println!("{}: {}", style("command").dim(), style(s).green()); + } + _ => print_params(&call.arguments, 0), + } + println!(); +} + +fn render_default_request(call: &ToolCall) { + print_tool_header(call); + print_params(&call.arguments, 0); + println!(); +} + +// Helper functions + +fn print_tool_header(call: &ToolCall) { + let parts: Vec<_> = call.name.rsplit("__").collect(); + let tool_header = format!( + "─── {} | {} ──────────────────────────", + style(parts.first().unwrap_or(&"unknown")), + style( + parts + .split_first() + .map(|(_, s)| s.iter().rev().copied().collect::>().join("__")) + .unwrap_or_else(|| "unknown".to_string()) + ) + .magenta() + .dim(), + ); + println!(); + println!("{}", tool_header); +} + +fn print_markdown(content: &str, theme: Theme) { + bat::PrettyPrinter::new() + .input(bat::Input::from_bytes(content.as_bytes())) + .theme(theme.as_str()) + .language("Markdown") + .wrapping_mode(WrappingMode::Character) + .print() + .unwrap(); +} + +const MAX_STRING_LENGTH: usize = 40; +const INDENT: &str = " "; + +fn print_params(value: &Value, depth: usize) { + let indent = INDENT.repeat(depth); + + match value { + Value::Object(map) => { + for (key, val) in map { + match val { + Value::Object(_) => { + println!("{}{}:", indent, style(key).dim()); + print_params(val, depth + 1); + } + Value::Array(arr) => { + println!("{}{}:", indent, style(key).dim()); + for item in arr.iter() { + println!("{}{}- ", indent, INDENT); + print_params(item, depth + 2); + } + } + Value::String(s) => { + if s.len() > MAX_STRING_LENGTH { + println!("{}{}: {}", indent, style(key).dim(), style("...").dim()); + } else { + println!("{}{}: {}", indent, style(key).dim(), style(s).green()); + } + } + Value::Number(n) => { + println!("{}{}: {}", indent, style(key).dim(), style(n).blue()); + } + Value::Bool(b) => { + println!("{}{}: {}", indent, style(key).dim(), style(b).blue()); + } + Value::Null => { + println!("{}{}: {}", indent, style(key).dim(), style("null").dim()); + } + } + } + } + Value::Array(arr) => { + for (i, item) in arr.iter().enumerate() { + println!("{}{}.", indent, i + 1); + print_params(item, depth + 1); + } + } + Value::String(s) => { + if s.len() > MAX_STRING_LENGTH { + println!( + "{}{}", + indent, + style(format!("[REDACTED: {} chars]", s.len())).yellow() + ); + } else { + println!("{}{}", indent, style(s).green()); + } + } + Value::Number(n) => { + println!("{}{}", indent, style(n).yellow()); + } + Value::Bool(b) => { + println!("{}{}", indent, style(b).yellow()); + } + Value::Null => { + println!("{}{}", indent, style("null").dim()); + } + } +} + +fn shorten_path(path: &str) -> String { + let path = Path::new(path); + + // First try to convert to ~ if it's in home directory + let home = etcetera::home_dir().ok(); + let path_str = if let Some(home) = home { + if let Ok(stripped) = path.strip_prefix(home) { + format!("~/{}", stripped.display()) + } else { + path.display().to_string() + } + } else { + path.display().to_string() + }; + + // If path is already short enough, return as is + if path_str.len() <= 60 { + return path_str; + } + + let parts: Vec<_> = path_str.split('/').collect(); + + // If we have 3 or fewer parts, return as is + if parts.len() <= 3 { + return path_str; + } + + // Keep the first component (empty string before root / or ~) and last two components intact + let mut shortened = vec![parts[0].to_string()]; + + // Shorten middle components to their first letter + for component in &parts[1..parts.len() - 2] { + if !component.is_empty() { + shortened.push(component.chars().next().unwrap_or('?').to_string()); + } + } + + // Add the last two components + shortened.push(parts[parts.len() - 2].to_string()); + shortened.push(parts[parts.len() - 1].to_string()); + + shortened.join("/") +} + +// Session display functions +pub fn display_session_info(resume: bool, provider: &str, model: &str, session_file: &Path) { + let start_session_msg = if resume { + "resuming session |" + } else { + "starting session |" + }; + println!( + "{} {} {} {} {}", + style(start_session_msg).dim(), + style("provider:").dim(), + style(provider).cyan().dim(), + style("model:").dim(), + style(model).cyan().dim(), + ); + println!( + " {} {}", + style("logging to").dim(), + style(session_file.display()).dim().cyan(), + ); +} + +pub fn display_greeting() { + println!("\nGoose is running! Enter your instructions, or try asking what goose can do.\n"); +} + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + + #[test] + fn test_short_paths_unchanged() { + assert_eq!(shorten_path("/usr/bin"), "/usr/bin"); + assert_eq!(shorten_path("/a/b/c"), "/a/b/c"); + assert_eq!(shorten_path("file.txt"), "file.txt"); + } + + #[test] + fn test_home_directory_conversion() { + // Save the current home dir + let original_home = env::var("HOME").ok(); + + // Set a test home directory + env::set_var("HOME", "/Users/testuser"); + + assert_eq!( + shorten_path("/Users/testuser/documents/file.txt"), + "~/documents/file.txt" + ); + + // A path that starts similarly to home but isn't in home + assert_eq!( + shorten_path("/Users/testuser2/documents/file.txt"), + "/Users/testuser2/documents/file.txt" + ); + + // Restore the original home dir + if let Some(home) = original_home { + env::set_var("HOME", home); + } else { + env::remove_var("HOME"); + } + } + + #[test] + fn test_long_path_shortening() { + assert_eq!( + shorten_path( + "/vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/long/path/with/many/components/file.txt" + ), + "/v/l/p/w/m/components/file.txt" + ); + } +} diff --git a/crates/goose-cli/src/cli_prompt.rs b/crates/goose-cli/src/session/prompt.rs similarity index 100% rename from crates/goose-cli/src/cli_prompt.rs rename to crates/goose-cli/src/session/prompt.rs diff --git a/crates/goose-cli/src/session/storage.rs b/crates/goose-cli/src/session/storage.rs new file mode 100644 index 000000000..baf502ca3 --- /dev/null +++ b/crates/goose-cli/src/session/storage.rs @@ -0,0 +1,165 @@ +use anyhow::Result; +use etcetera::{choose_app_strategy, AppStrategy}; +use goose::message::Message; +use std::fs::{self, File}; +use std::io::{self, BufRead, Write}; +use std::path::{Path, PathBuf}; + +/// Ensure the session directory exists and return its path +pub fn ensure_session_dir() -> Result { + let data_dir = choose_app_strategy(crate::APP_STRATEGY.clone()) + .expect("goose requires a home dir") + .data_dir() + .join("sessions"); + + if !data_dir.exists() { + fs::create_dir_all(&data_dir)?; + } + + Ok(data_dir) +} + +/// Get the path to the most recently modified session file +pub fn get_most_recent_session() -> Result { + let session_dir = ensure_session_dir()?; + let mut entries = fs::read_dir(&session_dir)? + .filter_map(|entry| entry.ok()) + .filter(|entry| entry.path().extension().is_some_and(|ext| ext == "jsonl")) + .collect::>(); + + if entries.is_empty() { + return Err(anyhow::anyhow!("No session files found")); + } + + // Sort by modification time, most recent first + entries.sort_by(|a, b| { + b.metadata() + .and_then(|m| m.modified()) + .unwrap_or(std::time::SystemTime::UNIX_EPOCH) + .cmp( + &a.metadata() + .and_then(|m| m.modified()) + .unwrap_or(std::time::SystemTime::UNIX_EPOCH), + ) + }); + + Ok(entries[0].path()) +} + +/// Read messages from a session file +/// +/// Creates the file if it doesn't exist, reads and deserializes all messages if it does. +pub fn read_messages(session_file: &Path) -> Result> { + let file = fs::OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(session_file)?; + + let reader = io::BufReader::new(file); + let mut messages = Vec::new(); + + for line in reader.lines() { + messages.push(serde_json::from_str::(&line?)?); + } + + Ok(messages) +} + +/// Write messages to a session file +/// +/// Overwrites the file with all messages in JSONL format. +pub fn persist_messages(session_file: &Path, messages: &[Message]) -> Result<()> { + let file = File::create(session_file)?; + let mut writer = io::BufWriter::new(file); + + for message in messages { + serde_json::to_writer(&mut writer, &message)?; + writeln!(writer)?; + } + + writer.flush()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use goose::message::MessageContent; + use tempfile::tempdir; + + #[test] + fn test_read_write_messages() -> Result<()> { + let dir = tempdir()?; + let file_path = dir.path().join("test.jsonl"); + + // Create some test messages + let messages = vec![ + Message::user().with_text("Hello"), + Message::assistant().with_text("Hi there"), + ]; + + // Write messages + persist_messages(&file_path, &messages)?; + + // Read them back + let read_messages = read_messages(&file_path)?; + + // Compare + assert_eq!(messages.len(), read_messages.len()); + for (orig, read) in messages.iter().zip(read_messages.iter()) { + assert_eq!(orig.role, read.role); + assert_eq!(orig.content.len(), read.content.len()); + + // Compare first text content + if let (Some(MessageContent::Text(orig_text)), Some(MessageContent::Text(read_text))) = + (orig.content.first(), read.content.first()) + { + assert_eq!(orig_text.text, read_text.text); + } else { + panic!("Messages don't match expected structure"); + } + } + + Ok(()) + } + + #[test] + fn test_empty_file() -> Result<()> { + let dir = tempdir()?; + let file_path = dir.path().join("empty.jsonl"); + + // Reading an empty file should return empty vec + let messages = read_messages(&file_path)?; + assert!(messages.is_empty()); + + Ok(()) + } + + #[test] + fn test_get_most_recent() -> Result<()> { + let dir = tempdir()?; + let base_path = dir.path().join("sessions"); + fs::create_dir_all(&base_path)?; + + // Create a few session files with different timestamps + let old_file = base_path.join("old.jsonl"); + let new_file = base_path.join("new.jsonl"); + + // Create files with some delay to ensure different timestamps + fs::write(&old_file, "dummy content")?; + std::thread::sleep(std::time::Duration::from_secs(1)); + fs::write(&new_file, "dummy content")?; + + // Override the home directory for testing + // This is a bit hacky but works for testing + std::env::set_var("HOME", dir.path()); + + if let Ok(most_recent) = get_most_recent_session() { + assert_eq!(most_recent.file_name().unwrap(), "new.jsonl"); + } + + Ok(()) + } +} diff --git a/crates/goose-cli/src/prompt/thinking.rs b/crates/goose-cli/src/session/thinking.rs similarity index 98% rename from crates/goose-cli/src/prompt/thinking.rs rename to crates/goose-cli/src/session/thinking.rs index 9be90b20f..0368218e8 100644 --- a/crates/goose-cli/src/prompt/thinking.rs +++ b/crates/goose-cli/src/session/thinking.rs @@ -1,10 +1,7 @@ use rand::seq::SliceRandom; /// Extended list of playful thinking messages including both goose and general AI actions -pub const THINKING_MESSAGES: &[&str] = &[ - "Thinking", - "Thinking hard", - // Include all goose actions +const THINKING_MESSAGES: &[&str] = &[ "Spreading wings", "Honking thoughtfully", "Waddling to conclusions", @@ -45,7 +42,6 @@ pub const THINKING_MESSAGES: &[&str] = &[ "Honking success signals", "Waddling through workflows", "Nesting in neural networks", - // AI thinking actions "Consulting the digital oracle", "Summoning binary spirits", "Reticulating splines", diff --git a/crates/goose-cli/src/test_helpers.rs b/crates/goose-cli/src/test_helpers.rs deleted file mode 100644 index fa1804ccf..000000000 --- a/crates/goose-cli/src/test_helpers.rs +++ /dev/null @@ -1,70 +0,0 @@ -/// Helper function to set up a temporary home directory for testing, returns path of that temp dir. -/// Also creates a default profiles.json to avoid obscure test failures when there are no profiles. -#[cfg(test)] -pub fn run_with_tmp_dir T, T>(func: F) -> T { - use std::ffi::OsStr; - use tempfile::tempdir; - - let temp_dir = tempdir().unwrap(); - let temp_dir_path = temp_dir.path().to_path_buf(); - setup_profile(&temp_dir_path, None); - - temp_env::with_vars( - [ - ("HOME", Some(temp_dir_path.as_os_str())), - ("DATABRICKS_HOST", Some(OsStr::new("tmp_host_url"))), - ], - func, - ) -} - -#[cfg(test)] -#[allow(dead_code)] -pub async fn run_with_tmp_dir_async(func: F) -> T -where - F: FnOnce() -> Fut, - Fut: std::future::Future, -{ - use std::ffi::OsStr; - use tempfile::tempdir; - - let temp_dir = tempdir().unwrap(); - let temp_dir_path = temp_dir.path().to_path_buf(); - setup_profile(&temp_dir_path, None); - - temp_env::async_with_vars( - [ - ("HOME", Some(temp_dir_path.as_os_str())), - ("DATABRICKS_HOST", Some(OsStr::new("tmp_host_url"))), - ], - func(), - ) - .await -} - -#[cfg(test)] -use std::path::Path; - -#[cfg(test)] -/// Setup a goose profile for testing, and an optional profile string -fn setup_profile(temp_dir_path: &Path, profile_string: Option<&str>) { - use std::fs; - - let profile_path = temp_dir_path - .join(".config") - .join("goose") - .join("profiles.json"); - fs::create_dir_all(profile_path.parent().unwrap()).unwrap(); - let default_profile = r#" -{ - "profile_items": { - "default": { - "provider": "databricks", - "model": "goose", - "additional_extensions": [] - } - } -}"#; - - fs::write(&profile_path, profile_string.unwrap_or(default_profile)).unwrap(); -} From 255c849e5ebc89999310b6677741416cd565070a Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 11 Feb 2025 22:42:57 -0800 Subject: [PATCH 3/5] Improve rg tool use. (#1188) --- crates/goose-mcp/src/developer/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index bbc7c5f8c..492f4b56a 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -139,8 +139,8 @@ impl DeveloperRouter { **Important**: Use ripgrep - `rg` - when you need to locate a file or a code reference, other solutions may show ignored or hidden files. For example *do not* use `find` or `ls -r` - - To locate a file by name: `rg --files | rg example.py` - - To locate content inside files: `rg 'class Example'` + - List files by name: `rg --files | rg ` + - List files that contain a regex: `rg '' -l` "#}, }; From 5e8a8bae3d19d7defc139cc41c027c730a2c6b61 Mon Sep 17 00:00:00 2001 From: TechnoHouse <13776377+deephbz@users.noreply.github.com> Date: Wed, 12 Feb 2025 15:01:32 +0800 Subject: [PATCH 4/5] Support modifying AZURE_OPENAI_API_VERSION (#1042) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jean-FrançoisMillet --- crates/goose/src/providers/azure.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index db826a9d4..a35259fbb 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -15,7 +15,7 @@ use mcp_core::tool::Tool; pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o"; pub const AZURE_DOC_URL: &str = "https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models"; -pub const AZURE_API_VERSION: &str = "2024-10-21"; +pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21"; pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"]; #[derive(Debug, serde::Serialize)] @@ -25,6 +25,7 @@ pub struct AzureProvider { endpoint: String, api_key: String, deployment_name: String, + api_version: String, model: ModelConfig, } @@ -41,6 +42,9 @@ impl AzureProvider { let api_key: String = config.get_secret("AZURE_OPENAI_API_KEY")?; let endpoint: String = config.get("AZURE_OPENAI_ENDPOINT")?; let deployment_name: String = config.get("AZURE_OPENAI_DEPLOYMENT_NAME")?; + let api_version: String = config + .get("AZURE_OPENAI_API_VERSION") + .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string()); let client = Client::builder() .timeout(Duration::from_secs(600)) @@ -51,6 +55,7 @@ impl AzureProvider { endpoint, api_key, deployment_name, + api_version, model, }) } @@ -63,7 +68,7 @@ impl AzureProvider { "openai/deployments/{}/chat/completions", self.deployment_name )); - base_url.set_query(Some(&format!("api-version={}", AZURE_API_VERSION))); + base_url.set_query(Some(&format!("api-version={}", self.api_version))); let response: reqwest::Response = self .client @@ -99,6 +104,12 @@ impl Provider for AzureProvider { false, Some("Name of your Azure OpenAI deployment"), ), + ConfigKey::new( + "AZURE_OPENAI_API_VERSION", + false, + false, + Some("Azure OpenAI API version, default: 2024-10-21"), + ), ], ) } From 80b694d6c609175e061fa02faa71262a0cfd20a7 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 11 Feb 2025 23:20:27 -0800 Subject: [PATCH 5/5] fix: validate function call json schemas for openai (#1156) Co-authored-by: angiejones --- crates/goose/src/providers/formats/openai.rs | 130 +++++++++++++++++- documentation/docs/tutorials/jetbrains-mcp.md | 4 - 2 files changed, 129 insertions(+), 5 deletions(-) diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index 528c3f232..20b5aaaae 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -253,6 +253,55 @@ pub fn get_usage(data: &Value) -> Result { Ok(Usage::new(input_tokens, output_tokens, total_tokens)) } +/// Validates and fixes tool schemas to ensure they have proper parameter structure. +/// If parameters exist, ensures they have properties and required fields, or removes parameters entirely. +pub fn validate_tool_schemas(tools: &mut [Value]) { + for tool in tools.iter_mut() { + if let Some(function) = tool.get_mut("function") { + if let Some(parameters) = function.get_mut("parameters") { + if parameters.is_object() { + ensure_valid_json_schema(parameters); + } + } + } + } +} + +/// Ensures that the given JSON value follows the expected JSON Schema structure. +fn ensure_valid_json_schema(schema: &mut Value) { + if let Some(params_obj) = schema.as_object_mut() { + // Check if this is meant to be an object type schema + let is_object_type = params_obj + .get("type") + .and_then(|t| t.as_str()) + .map_or(true, |t| t == "object"); // Default to true if no type is specified + + // Only apply full schema validation to object types + if is_object_type { + // Ensure required fields exist with default values + params_obj.entry("properties").or_insert_with(|| json!({})); + params_obj.entry("required").or_insert_with(|| json!([])); + params_obj.entry("type").or_insert_with(|| json!("object")); + + // Recursively validate properties if it exists + if let Some(properties) = params_obj.get_mut("properties") { + if let Some(properties_obj) = properties.as_object_mut() { + for (_key, prop) in properties_obj.iter_mut() { + if prop.is_object() + && prop + .get("type") + .and_then(|t| t.as_str()) + .map_or(false, |t| t == "object") + { + ensure_valid_json_schema(prop); + } + } + } + } + } + } +} + pub fn create_request( model_config: &ModelConfig, system: &str, @@ -275,12 +324,15 @@ pub fn create_request( }); let messages_spec = format_messages(messages, image_format); - let tools_spec = if !tools.is_empty() { + let mut tools_spec = if !tools.is_empty() { format_tools(tools)? } else { vec![] }; + // Validate tool schemas + validate_tool_schemas(&mut tools_spec); + let mut messages_array = vec![system_message]; messages_array.extend(messages_spec); @@ -326,6 +378,82 @@ mod tests { use mcp_core::content::Content; use serde_json::json; + #[test] + fn test_validate_tool_schemas() { + // Test case 1: Empty parameters object + // Input JSON with an incomplete parameters object + let mut actual = vec![json!({ + "type": "function", + "function": { + "name": "test_func", + "description": "test description", + "parameters": { + "type": "object" + } + } + })]; + + // Run the function to validate and update schemas + validate_tool_schemas(&mut actual); + + // Expected JSON after validation + let expected = vec![json!({ + "type": "function", + "function": { + "name": "test_func", + "description": "test description", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + })]; + + // Compare entire JSON structures instead of individual fields + assert_eq!(actual, expected); + + // Test case 2: Missing type field + let mut tools = vec![json!({ + "type": "function", + "function": { + "name": "test_func", + "description": "test description", + "parameters": { + "properties": {} + } + } + })]; + + validate_tool_schemas(&mut tools); + + let params = tools[0]["function"]["parameters"].as_object().unwrap(); + assert_eq!(params["type"], "object"); + + // Test case 3: Complete valid schema should remain unchanged + let original_schema = json!({ + "type": "function", + "function": { + "name": "test_func", + "description": "test description", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City and country" + } + }, + "required": ["location"] + } + } + }); + + let mut tools = vec![original_schema.clone()]; + validate_tool_schemas(&mut tools); + assert_eq!(tools[0], original_schema); + } + const OPENAI_TOOL_USE_RESPONSE: &str = r#"{ "choices": [{ "role": "assistant", diff --git a/documentation/docs/tutorials/jetbrains-mcp.md b/documentation/docs/tutorials/jetbrains-mcp.md index 2ae08830e..feb41fcc0 100644 --- a/documentation/docs/tutorials/jetbrains-mcp.md +++ b/documentation/docs/tutorials/jetbrains-mcp.md @@ -13,10 +13,6 @@ The JetBrains extension is designed to work within your IDE. Goose can accomplis This tutorial covers how to enable and use the JetBrains MCP Server as a built-in Goose extension to integrate with any JetBrains IDE. -:::warning Known Limitation -The JetBrains extension [does not work](https://github.com/block/goose/issues/933) with OpenAI models (e.g. gpt-4o). -::: - ## Configuration 1. Add the [MCP Server plugin](https://plugins.jetbrains.com/plugin/26071-mcp-server) to your IDE.