diff --git a/mem0/proxy/main.py b/mem0/proxy/main.py index 8689ef3e1c..1fecc7b6a1 100644 --- a/mem0/proxy/main.py +++ b/mem0/proxy/main.py @@ -17,7 +17,9 @@ subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"]) import litellm except subprocess.CalledProcessError: - print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.") + print( + "Failed to install 'litellm'. Please install it manually using 'pip install litellm'." + ) sys.exit(1) else: raise ImportError("The required 'litellm' library is not installed.") @@ -96,6 +98,8 @@ def create( api_version: Optional[str] = None, api_key: Optional[str] = None, model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + # New parameter + add_to_memory: bool = True, ): if not any([user_id, agent_id, run_id]): raise ValueError("One of user_id, agent_id, run_id must be provided") @@ -107,10 +111,17 @@ def create( prepared_messages = self._prepare_messages(messages) if prepared_messages[-1]["role"] == "user": - self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters) - relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit) + if add_to_memory: + self._async_add_to_memory( + messages, user_id, agent_id, run_id, metadata, filters + ) + relevant_memories = self._fetch_relevant_memories( + messages, user_id, agent_id, run_id, filters, limit + ) logger.debug(f"Retrieved {len(relevant_memories)} relevant memories") - prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories) + prepared_messages[-1]["content"] = self._format_query_with_memories( + messages, relevant_memories + ) response = litellm.completion( model=model, @@ -151,7 +162,9 @@ def _prepare_messages(self, messages: List[dict]) -> List[dict]: return [{"role": "system", "content": MEMORY_ANSWER_PROMPT}] + messages return messages - def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters): + def _async_add_to_memory( + self, messages, user_id, agent_id, run_id, metadata, filters + ): def add_task(): logger.debug("Adding to memory asynchronously") self.mem0_client.add( @@ -165,9 +178,13 @@ def add_task(): threading.Thread(target=add_task, daemon=True).start() - def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit): + def _fetch_relevant_memories( + self, messages, user_id, agent_id, run_id, filters, limit + ): # Currently, only pass the last 6 messages to the search API to prevent long query - message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:] + message_input = [ + f"{message['role']}: {message['content']}" for message in messages + ][-6:] # TODO: Make it better by summarizing the past conversation return self.mem0_client.search( query="\n".join(message_input), @@ -182,7 +199,9 @@ def _format_query_with_memories(self, messages, relevant_memories): # Check if self.mem0_client is an instance of Memory or MemoryClient if isinstance(self.mem0_client, mem0.memory.main.Memory): - memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"]) + memories_text = "\n".join( + memory["memory"] for memory in relevant_memories["results"] + ) elif isinstance(self.mem0_client, mem0.client.main.MemoryClient): memories_text = "\n".join(memory["memory"] for memory in relevant_memories) return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}"