feat: Add option to disable asynchronous memory addition in completion.create method
lfcunha07 committed Jan 8, 2025
1 parent c63c0ac commit a7f2fb3
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions mem0/proxy/main.py
@@ -17,7 +17,9 @@
             subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])
             import litellm
         except subprocess.CalledProcessError:
-            print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.")
+            print(
+                "Failed to install 'litellm'. Please install it manually using 'pip install litellm'."
+            )
             sys.exit(1)
     else:
         raise ImportError("The required 'litellm' library is not installed.")
@@ -96,6 +98,8 @@ def create(
         api_version: Optional[str] = None,
         api_key: Optional[str] = None,
         model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
+        # New parameter
+        add_to_memory: bool = True,
     ):
         if not any([user_id, agent_id, run_id]):
             raise ValueError("One of user_id, agent_id, run_id must be provided")
@@ -107,10 +111,17 @@

         prepared_messages = self._prepare_messages(messages)
         if prepared_messages[-1]["role"] == "user":
-            self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
-            relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
+            if add_to_memory:
+                self._async_add_to_memory(
+                    messages, user_id, agent_id, run_id, metadata, filters
+                )
+            relevant_memories = self._fetch_relevant_memories(
+                messages, user_id, agent_id, run_id, filters, limit
+            )
             logger.debug(f"Retrieved {len(relevant_memories)} relevant memories")
-            prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)
+            prepared_messages[-1]["content"] = self._format_query_with_memories(
+                messages, relevant_memories
+            )

         response = litellm.completion(
             model=model,
@@ -151,7 +162,9 @@ def _prepare_messages(self, messages: List[dict]) -> List[dict]:
             return [{"role": "system", "content": MEMORY_ANSWER_PROMPT}] + messages
         return messages

-    def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
+    def _async_add_to_memory(
+        self, messages, user_id, agent_id, run_id, metadata, filters
+    ):
         def add_task():
             logger.debug("Adding to memory asynchronously")
             self.mem0_client.add(
@@ -165,9 +178,13 @@ def add_task():

         threading.Thread(target=add_task, daemon=True).start()

-    def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
+    def _fetch_relevant_memories(
+        self, messages, user_id, agent_id, run_id, filters, limit
+    ):
         # Currently, only pass the last 6 messages to the search API to prevent long query
-        message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
+        message_input = [
+            f"{message['role']}: {message['content']}" for message in messages
+        ][-6:]
         # TODO: Make it better by summarizing the past conversation
         return self.mem0_client.search(
             query="\n".join(message_input),
@@ -182,7 +199,9 @@ def _format_query_with_memories(self, messages, relevant_memories):
         # Check if self.mem0_client is an instance of Memory or MemoryClient

         if isinstance(self.mem0_client, mem0.memory.main.Memory):
-            memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"])
+            memories_text = "\n".join(
+                memory["memory"] for memory in relevant_memories["results"]
+            )
         elif isinstance(self.mem0_client, mem0.client.main.MemoryClient):
             memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
         return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}"
