Skip to content

Commit

Permalink
keep extension name consistent when inserting and removing from state (
Browse files Browse the repository at this point in the history
  • Loading branch information
kalvinnchau authored Jan 24, 2025
1 parent 79fa2af commit 90d6f1b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
20 changes: 11 additions & 9 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,23 +143,23 @@ impl Capabilities {
.await
.map_err(|e| ExtensionError::Initialization(config.clone(), e))?;

let sanitized_name = sanitize(config.name().to_string());

// Store instructions if provided
if let Some(instructions) = init_result.instructions {
self.instructions
.insert(config.name().to_string(), instructions);
.insert(sanitized_name.clone(), instructions);
}

// if the server is capable if resources we track it
if init_result.capabilities.resources.is_some() {
self.resource_capable_extensions
.insert(sanitize(config.name().to_string()));
.insert(sanitized_name.clone());
}

// Store the client using the provided name
self.clients.insert(
sanitize(config.name().to_string()),
Arc::new(Mutex::new(client)),
);
self.clients
.insert(sanitized_name.clone(), Arc::new(Mutex::new(client)));

Ok(())
}
Expand All @@ -177,9 +177,11 @@ impl Capabilities {

/// Get aggregated usage statistics
pub async fn remove_extension(&mut self, name: &str) -> ExtensionResult<()> {
self.clients.remove(name);
self.instructions.remove(name);
self.resource_capable_extensions.remove(name);
let sanitized_name = sanitize(name.to_string());

self.clients.remove(&sanitized_name);
self.instructions.remove(&sanitized_name);
self.resource_capable_extensions.remove(&sanitized_name);
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/config/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl ExtensionManager {
pub fn get_all() -> Result<Vec<ExtensionEntry>> {
let config = Config::global();
let extensions: HashMap<String, ExtensionEntry> =
config.get("extensions").unwrap_or(HashMap::new());
config.get("extensions").unwrap_or_default();
Ok(Vec::from_iter(extensions.values().cloned()))
}

Expand Down

0 comments on commit 90d6f1b

Please sign in to comment.