Skip to content

Commit

Permalink
maintain extension name consistent (#721)
Browse files Browse the repository at this point in the history
Co-authored-by: Bradley Axen <[email protected]>
  • Loading branch information
Kvadratni and baxen authored Jan 24, 2025
1 parent f60d6c9 commit 79fa2af
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 75 deletions.
70 changes: 30 additions & 40 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,12 @@ pub async fn handle_configure() -> Result<(), Box<dyn Error>> {
style("goose configure").cyan()
);
// Since we are setting up for the first time, we'll also enable the developer system
ExtensionManager::set(
"developer",
ExtensionEntry {
enabled: true,
config: ExtensionConfig::Builtin {
name: "developer".to_string(),
},
ExtensionManager::set(ExtensionEntry {
enabled: true,
config: ExtensionConfig::Builtin {
name: "developer".to_string(),
},
)?;
})?;
} else {
let _ = config.clear();
println!(
Expand Down Expand Up @@ -267,7 +264,7 @@ pub fn toggle_extensions_dialog() -> Result<(), Box<dyn Error>> {
// Create a list of extension names and their enabled status
let extension_status: Vec<(String, bool)> = extensions
.iter()
.map(|(name, entry)| (name.clone(), entry.enabled))
.map(|entry| (entry.config.name().to_string(), entry.enabled))
.collect();

// Get currently enabled extensions for the selection
Expand Down Expand Up @@ -347,26 +344,23 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
.interact()?
.to_string();

ExtensionManager::set(
&extension,
ExtensionEntry {
enabled: true,
config: ExtensionConfig::Builtin {
name: extension.clone(),
},
ExtensionManager::set(ExtensionEntry {
enabled: true,
config: ExtensionConfig::Builtin {
name: extension.clone(),
},
)?;
})?;

cliclack::outro(format!("Enabled {} extension", style(extension).green()))?;
}
"stdio" => {
let extensions = ExtensionManager::get_all()?;
let extensions = ExtensionManager::get_all_names()?;
let name: String = cliclack::input("What would you like to call this extension?")
.placeholder("my-extension")
.validate(move |input: &String| {
if input.is_empty() {
Err("Please enter a name")
} else if extensions.contains_key(input) {
} else if extensions.contains(input) {
Err("An extension with this name already exists")
} else {
Ok(())
Expand Down Expand Up @@ -412,28 +406,26 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
}
}

ExtensionManager::set(
&name,
ExtensionEntry {
enabled: true,
config: ExtensionConfig::Stdio {
cmd,
args,
envs: Envs::new(envs),
},
ExtensionManager::set(ExtensionEntry {
enabled: true,
config: ExtensionConfig::Stdio {
name: name.clone(),
cmd,
args,
envs: Envs::new(envs),
},
)?;
})?;

cliclack::outro(format!("Added {} extension", style(name).green()))?;
}
"sse" => {
let extensions = ExtensionManager::get_all()?;
let extensions = ExtensionManager::get_all_names()?;
let name: String = cliclack::input("What would you like to call this extension?")
.placeholder("my-remote-extension")
.validate(move |input: &String| {
if input.is_empty() {
Err("Please enter a name")
} else if extensions.contains_key(input) {
} else if extensions.contains(input) {
Err("An extension with this name already exists")
} else {
Ok(())
Expand Down Expand Up @@ -476,16 +468,14 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
}
}

ExtensionManager::set(
&name,
ExtensionEntry {
enabled: true,
config: ExtensionConfig::Sse {
uri,
envs: Envs::new(envs),
},
ExtensionManager::set(ExtensionEntry {
enabled: true,
config: ExtensionConfig::Sse {
name: name.clone(),
uri,
envs: Envs::new(envs),
},
)?;
})?;

cliclack::outro(format!("Added {} extension", style(name).green()))?;
}
Expand Down
19 changes: 15 additions & 4 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ pub async fn build_session(
.expect("Failed to create agent");

// Setup extensions for the agent
for (name, extension) in ExtensionManager::get_all().expect("should load extensions") {
for extension in ExtensionManager::get_all().expect("should load extensions") {
if extension.enabled {
let config = extension.config.clone();
agent
.add_extension(extension.config.clone())
.add_extension(config.clone())
.await
.unwrap_or_else(|e| {
let err = match e {
Expand All @@ -53,8 +54,11 @@ pub async fn build_session(
}
_ => e.to_string(),
};
println!("Failed to start extension: {}, {:?}", name, err);
println!("Please check extension configuration for {}.", name);
println!("Failed to start extension: {}, {:?}", config.name(), err);
println!(
"Please check extension configuration for {}.",
config.name()
);
process::exit(1);
});
}
Expand All @@ -81,7 +85,14 @@ pub async fn build_session(
}

let cmd = parts.remove(0).to_string();
//this is an ephemeral extension so name does not matter
let name = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(8)
.map(char::from)
.collect();
let config = ExtensionConfig::Stdio {
name,
cmd,
args: parts.iter().map(|s| s.to_string()).collect(),
envs: Envs::new(envs),
Expand Down
13 changes: 12 additions & 1 deletion crates/goose-server/src/routes/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ enum ExtensionConfigRequest {
/// Server-Sent Events (SSE) extension.
#[serde(rename = "sse")]
Sse {
/// The name to identify this extension
name: String,
/// The URI endpoint for the SSE extension.
uri: String,
/// List of environment variable keys. The server will fetch their values from the keyring.
Expand All @@ -24,6 +26,8 @@ enum ExtensionConfigRequest {
/// Standard I/O (stdio) extension.
#[serde(rename = "stdio")]
Stdio {
/// The name to identify this extension
name: String,
/// The command to execute.
cmd: String,
/// Arguments for the command.
Expand Down Expand Up @@ -73,7 +77,11 @@ async fn add_extension(

// Construct ExtensionConfig with Envs populated from keyring based on provided env_keys.
let extension_config: ExtensionConfig = match request {
ExtensionConfigRequest::Sse { uri, env_keys } => {
ExtensionConfigRequest::Sse {
name,
uri,
env_keys,
} => {
let mut env_map = HashMap::new();
for key in env_keys {
match config.get_secret(&key) {
Expand All @@ -97,11 +105,13 @@ async fn add_extension(
}

ExtensionConfig::Sse {
name,
uri,
envs: Envs::new(env_map),
}
}
ExtensionConfigRequest::Stdio {
name,
cmd,
args,
env_keys,
Expand Down Expand Up @@ -129,6 +139,7 @@ async fn add_extension(
}

ExtensionConfig::Stdio {
name,
cmd,
args,
envs: Envs::new(env_map),
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/examples/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async fn main() {
// Setup an agent with the developer extension
let mut agent = AgentFactory::create("reference", provider).expect("default should exist");

let config = ExtensionConfig::stdio("./target/debug/developer");
let config = ExtensionConfig::stdio("developer", "./target/debug/developer");
agent.add_extension(config).await.unwrap();

println!("Extensions:");
Expand Down
26 changes: 11 additions & 15 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,22 @@ impl Capabilities {
/// Add a new MCP extension based on the provided client type
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
let mut client: Box<dyn McpClientTrait> = match config {
ExtensionConfig::Sse { ref uri, ref envs } => {
let mut client: Box<dyn McpClientTrait> = match &config {
ExtensionConfig::Sse { uri, envs, .. } => {
let transport = SseTransport::new(uri, envs.get_env());
let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(300));
Box::new(McpClient::new(service))
}
ExtensionConfig::Stdio {
ref cmd,
ref args,
ref envs,
cmd, args, envs, ..
} => {
let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env());
let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(300));
Box::new(McpClient::new(service))
}
ExtensionConfig::Builtin { ref name } => {
ExtensionConfig::Builtin { name } => {
// For builtin extensions, we run the current executable with mcp and extension name
let cmd = std::env::current_exe()
.expect("should find the current executable")
Expand Down Expand Up @@ -148,18 +146,18 @@ impl Capabilities {
// Store instructions if provided
if let Some(instructions) = init_result.instructions {
self.instructions
.insert(init_result.server_info.name.clone(), instructions);
.insert(config.name().to_string(), instructions);
}

// if the server is capable if resources we track it
if init_result.capabilities.resources.is_some() {
self.resource_capable_extensions
.insert(sanitize(init_result.server_info.name.clone()));
.insert(sanitize(config.name().to_string()));
}

// Store the client
// Store the client using the provided name
self.clients.insert(
sanitize(init_result.server_info.name.clone()),
sanitize(config.name().to_string()),
Arc::new(Mutex::new(client)),
);

Expand All @@ -180,15 +178,13 @@ impl Capabilities {
/// Get aggregated usage statistics
pub async fn remove_extension(&mut self, name: &str) -> ExtensionResult<()> {
self.clients.remove(name);
self.instructions.remove(name);
self.resource_capable_extensions.remove(name);
Ok(())
}

pub async fn list_extensions(&self) -> ExtensionResult<Vec<String>> {
let mut extensions = Vec::new();
for name in self.clients.keys() {
extensions.push(name.clone());
}
Ok(extensions)
Ok(self.clients.keys().cloned().collect())
}

pub async fn get_usage(&self) -> Vec<ProviderUsage> {
Expand Down
37 changes: 30 additions & 7 deletions crates/goose/src/agents/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,28 @@ pub enum ExtensionConfig {
/// Server-sent events client with a URI endpoint
#[serde(rename = "sse")]
Sse {
/// The name used to identify this extension
name: String,
uri: String,
#[serde(default)]
envs: Envs,
},
/// Standard I/O client with command and arguments
#[serde(rename = "stdio")]
Stdio {
/// The name used to identify this extension
name: String,
cmd: String,
args: Vec<String>,
#[serde(default)]
envs: Envs,
},
/// Built-in extension that is part of the goose binary
#[serde(rename = "builtin")]
Builtin { name: String },
Builtin {
/// The name used to identify this extension
name: String,
},
}

impl Default for ExtensionConfig {
Expand All @@ -73,15 +80,17 @@ impl Default for ExtensionConfig {
}

impl ExtensionConfig {
pub fn sse<S: Into<String>>(uri: S) -> Self {
pub fn sse<S: Into<String>>(name: S, uri: S) -> Self {
Self::Sse {
name: name.into(),
uri: uri.into(),
envs: Envs::default(),
}
}

pub fn stdio<S: Into<String>>(cmd: S) -> Self {
pub fn stdio<S: Into<String>>(name: S, cmd: S) -> Self {
Self::Stdio {
name: name.into(),
cmd: cmd.into(),
args: vec![],
envs: Envs::default(),
Expand All @@ -94,22 +103,36 @@ impl ExtensionConfig {
S: Into<String>,
{
match self {
Self::Stdio { cmd, envs, .. } => Self::Stdio {
Self::Stdio {
name, cmd, envs, ..
} => Self::Stdio {
name,
cmd,
envs,
args: args.into_iter().map(Into::into).collect(),
},
other => other,
}
}

/// Get the extension name regardless of variant
pub fn name(&self) -> &str {
match self {
Self::Sse { name, .. } => name,
Self::Stdio { name, .. } => name,
Self::Builtin { name } => name,
}
}
}

impl std::fmt::Display for ExtensionConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExtensionConfig::Sse { uri, .. } => write!(f, "SSE({})", uri),
ExtensionConfig::Stdio { cmd, args, .. } => {
write!(f, "Stdio({} {})", cmd, args.join(" "))
ExtensionConfig::Sse { name, uri, .. } => write!(f, "SSE({}: {})", name, uri),
ExtensionConfig::Stdio {
name, cmd, args, ..
} => {
write!(f, "Stdio({}: {} {})", name, cmd, args.join(" "))
}
ExtensionConfig::Builtin { name } => write!(f, "Builtin({})", name),
}
Expand Down
Loading

0 comments on commit 79fa2af

Please sign in to comment.