Skip to content

Commit

Permalink
feat: simplify memory management (#670)
Browse files Browse the repository at this point in the history
  • Loading branch information
wendytang authored Jan 22, 2025
1 parent bc6fd1a commit 17ca675
Showing 1 changed file with 97 additions and 107 deletions.
204 changes: 97 additions & 107 deletions crates/goose-mcp/src/memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ use std::{
io::{self, Read, Write},
path::PathBuf,
pin::Pin,
sync::{Arc, Mutex},
};
use tracing::info;
use url::Url;

use mcp_core::{
handler::{ResourceError, ToolError},
Expand All @@ -28,7 +25,6 @@ use mcp_server::Router;
pub struct MemoryRouter {
tools: Vec<Tool>,
instructions: String,
active_resources: Arc<Mutex<HashMap<String, Resource>>>,
global_memory_dir: PathBuf,
local_memory_dir: PathBuf,
}
Expand Down Expand Up @@ -135,10 +131,18 @@ impl MemoryRouter {
- **Search by Category**:
- Provides all memories within the specified context.
- Use: `retrieve_memories(category="development", is_global=False)`
- Note: If you want to retrieve all local memories, use `retrieve_memories(category="*", is_global=False)`
- Note: If you want to retrieve all global memories, use `retrieve_memories(category="*", is_global=True)`
- **Filter by Tags**:
- Enables targeted retrieval based on specific tags.
- Use: Provide tag filters to refine search.
The Protocol is:
To remove a memory, use the following protocol:
- **Remove by Category**:
- Removes all memories within the specified category.
- Use: `remove_memory_category(category="development", is_global=False)`
- Note: If you want to remove all local memories, use `remove_memory_category(category="*", is_global=False)`
- Note: If you want to remove all global memories, use `remove_memory_category(category="*", is_global=True)`
The Protocol is:
1. Confirm what kind of information the user seeks by category or keyword.
2. Suggest categories or relevant tags based on the user's request.
3. Use the retrieve function to access relevant memory entries.
Expand All @@ -160,29 +164,84 @@ impl MemoryRouter {
"#};

// Check for .goose/memory in current directory
let local_memory_dir = std::env::current_dir()
.unwrap()
let local_memory_dir = std::env::var("GOOSE_WORKING_DIR")
.map(PathBuf::from)
.unwrap_or_else(|_| std::env::current_dir().unwrap())
.join(".goose")
.join("memory");

// Check for .config/goose/memory in user's home directory
let global_memory_dir = dirs::home_dir()
.map(|home| home.join(".config/goose/memory"))
.unwrap_or_else(|| PathBuf::from(".config/goose/memory"));

fs::create_dir_all(&global_memory_dir).unwrap();
fs::create_dir_all(&local_memory_dir).unwrap();

Self {
let mut memory_router = Self {
tools: vec![
remember_memory,
retrieve_memories,
remove_memory_category,
remove_specific_memory,
],
instructions,
active_resources: Arc::new(Mutex::new(HashMap::new())),
instructions: instructions.clone(),
global_memory_dir,
local_memory_dir,
};

let retrieved_global_memories = memory_router.retrieve_all(true);
let retrieved_local_memories = memory_router.retrieve_all(false);

let mut updated_instructions = instructions;

let memories_follow_up_instructions = formatdoc! {r#"
**Here are the user's currently saved memories:**
Please keep this information in mind when answering future questions.
Do not bring up memories unless relevant.
Note: if the user has not saved any memories, this section will be empty.
Note: if the user removes a memory that was previously loaded into the system, please remove it from the system instructions.
"#};

updated_instructions.push_str("\n\n");
updated_instructions.push_str(&memories_follow_up_instructions);

if let Ok(global_memories) = retrieved_global_memories {
if !global_memories.is_empty() {
updated_instructions.push_str("\n\nGlobal Memories:\n");
for (category, memories) in global_memories {
updated_instructions.push_str(&format!("\nCategory: {}\n", category));
for memory in memories {
updated_instructions.push_str(&format!("- {}\n", memory));
}
}
}
}

if let Ok(local_memories) = retrieved_local_memories {
if !local_memories.is_empty() {
updated_instructions.push_str("\n\nLocal Memories:\n");
for (category, memories) in local_memories {
updated_instructions.push_str(&format!("\nCategory: {}\n", category));
for memory in memories {
updated_instructions.push_str(&format!("- {}\n", memory));
}
}
}
}

memory_router.set_instructions(updated_instructions);

memory_router
}

// Add a setter method for instructions
pub fn set_instructions(&mut self, new_instructions: String) {
self.instructions = new_instructions;
}

pub fn get_instructions(&self) -> &str {
&self.instructions
}

fn get_memory_file(&self, category: &str, is_global: bool) -> PathBuf {
Expand Down Expand Up @@ -227,7 +286,6 @@ impl MemoryRouter {
is_global: bool,
) -> io::Result<()> {
let memory_file_path = self.get_memory_file(category, is_global);
let uri = Url::from_file_path(&memory_file_path).unwrap().to_string();

let mut file = fs::OpenOptions::new()
.append(true)
Expand All @@ -238,10 +296,6 @@ impl MemoryRouter {
}
writeln!(file, "{}\n", data)?;

// Create and store the resource
let resource = Resource::new(uri.clone(), Some("text".to_string()), None).unwrap();
self.add_active_resource(uri, resource);

Ok(())
}

Expand All @@ -255,8 +309,6 @@ impl MemoryRouter {
return Ok(HashMap::new());
}

let uri = Url::from_file_path(&memory_file_path).unwrap().to_string();

let mut file = fs::File::open(memory_file_path)?;
let mut content = String::new();
file.read_to_string(&mut content)?;
Expand All @@ -283,11 +335,6 @@ impl MemoryRouter {
}
}

// Update resource
if let Some(resource) = self.active_resources.lock().unwrap().get_mut(&uri) {
resource.update_timestamp();
}

Ok(memories)
}

Expand All @@ -302,8 +349,6 @@ impl MemoryRouter {
return Ok(());
}

let uri = Url::from_file_path(&memory_file_path).unwrap().to_string();

let mut file = fs::File::open(&memory_file_path)?;
let mut content = String::new();
file.read_to_string(&mut content)?;
Expand All @@ -317,24 +362,25 @@ impl MemoryRouter {

fs::write(memory_file_path, new_content.join("\n\n"))?;

// Update resource
if let Some(resource) = self.active_resources.lock().unwrap().get_mut(&uri) {
resource.update_timestamp();
}

Ok(())
}

pub fn clear_memory(&self, category: &str, is_global: bool) -> io::Result<()> {
let memory_file_path = self.get_memory_file(category, is_global);
let uri = Url::from_file_path(&memory_file_path).unwrap().to_string();
if memory_file_path.exists() {
fs::remove_file(memory_file_path)?;
}

// Remove resource from active resources
self.active_resources.lock().unwrap().remove(&uri);
Ok(())
}

pub fn clear_all_global_or_local_memories(&self, is_global: bool) -> io::Result<()> {
let base_dir = if is_global {
&self.global_memory_dir
} else {
&self.local_memory_dir
};
fs::remove_dir_all(base_dir)?;
Ok(())
}

Expand All @@ -353,13 +399,25 @@ impl MemoryRouter {
}
"retrieve_memories" => {
let args = MemoryArgs::from_value(&tool_call.arguments)?;
let memories = self.retrieve(args.category, args.is_global)?;
let memories = if args.category == "*" {
self.retrieve_all(args.is_global)?
} else {
self.retrieve(args.category, args.is_global)?
};
Ok(format!("Retrieved memories: {:?}", memories))
}
"remove_memory_category" => {
let args = MemoryArgs::from_value(&tool_call.arguments)?;
self.clear_memory(args.category, args.is_global)?;
Ok(format!("Cleared memories in category: {}", args.category))
if args.category == "*" {
self.clear_all_global_or_local_memories(args.is_global)?;
Ok(format!(
"Cleared all memory {} categories",
if args.is_global { "global" } else { "local" }
))
} else {
self.clear_memory(args.category, args.is_global)?;
Ok(format!("Cleared memories in category: {}", args.category))
}
}
"remove_specific_memory" => {
let args = MemoryArgs::from_value(&tool_call.arguments)?;
Expand All @@ -373,55 +431,6 @@ impl MemoryRouter {
_ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Unknown tool")),
}
}

fn add_active_resource(&self, uri: String, resource: Resource) {
self.active_resources.lock().unwrap().insert(uri, resource);
}

async fn read_memory_resources(&self, uri: &str) -> Result<String, ResourceError> {
// Ensure the resource exists in the active resources map
let active_resources = self.active_resources.lock().unwrap();
let resource = active_resources
.get(uri)
.ok_or_else(|| ResourceError::NotFound(format!("Resource '{}' not found", uri)))?;

let url =
Url::parse(uri).map_err(|e| ResourceError::NotFound(format!("Invalid URI: {}", e)))?;

// Read content based on scheme and mime_type
match url.scheme() {
"file" => {
let path = url
.to_file_path()
.map_err(|_| ResourceError::NotFound("Invalid file path in URI".into()))?;

// Ensure file exists
if !path.exists() {
return Err(ResourceError::NotFound(format!(
"File does not exist: {}",
path.display()
)));
}

match resource.mime_type.as_str() {
"text" => {
// Read the file as UTF-8 text
fs::read_to_string(&path).map_err(|e| {
ResourceError::ExecutionError(format!("Failed to read file: {}", e))
})
}
mime_type => Err(ResourceError::ExecutionError(format!(
"Unsupported mime type: {}",
mime_type
))),
}
}
scheme => Err(ResourceError::NotFound(format!(
"Unsupported URI scheme: {}",
scheme
))),
}
}
}

#[async_trait]
Expand All @@ -435,10 +444,7 @@ impl Router for MemoryRouter {
}

fn capabilities(&self) -> ServerCapabilities {
CapabilitiesBuilder::new()
.with_tools(false)
.with_resources(false, false)
.build()
CapabilitiesBuilder::new().with_tools(false).build()
}

fn list_tools(&self) -> Vec<Tool> {
Expand Down Expand Up @@ -466,30 +472,14 @@ impl Router for MemoryRouter {
}

fn list_resources(&self) -> Vec<Resource> {
let resources = self
.active_resources
.lock()
.unwrap()
.values()
.cloned()
.collect();
info!("Listing resources: {:?}", resources);
resources
Vec::new()
}

fn read_resource(
&self,
uri: &str,
_uri: &str,
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
let this = self.clone();
let uri = uri.to_string();
info!("Reading resource: {}", uri);
Box::pin(async move {
match this.read_memory_resources(&uri).await {
Ok(content) => Ok(content),
Err(e) => Err(e),
}
})
Box::pin(async move { Ok("".to_string()) })
}
}

Expand Down

0 comments on commit 17ca675

Please sign in to comment.