From 5822b143f0231c7a8b738a2de468ec6b63ad3211 Mon Sep 17 00:00:00 2001
From: Dev-Khant
Date: Tue, 22 Oct 2024 12:36:25 +0530
Subject: [PATCH] formatting

---
 mem0/client/main.py            | 22 +++++------
 mem0/configs/base.py           |  4 +-
 mem0/embeddings/gemini.py      |  5 ++-
 mem0/embeddings/huggingface.py |  4 +-
 mem0/llms/gemini.py            | 69 +++++++++++++++++++---------------
 mem0/llms/openai.py            |  4 +-
 mem0/memory/main.py            | 24 +++++++++---
 mem0/proxy/main.py             |  4 +-
 mem0/vector_stores/milvus.py   |  7 +---
 9 files changed, 83 insertions(+), 60 deletions(-)

diff --git a/mem0/client/main.py b/mem0/client/main.py
index 5b4c9161d5..ef90183f07 100644
--- a/mem0/client/main.py
+++ b/mem0/client/main.py
@@ -56,12 +56,12 @@ class MemoryClient:
     """

     def __init__(
-            self,
-            api_key: Optional[str] = None,
-            host: Optional[str] = None,
-            organization: Optional[str] = None,
-            project: Optional[str] = None
-        ):
+        self,
+        api_key: Optional[str] = None,
+        host: Optional[str] = None,
+        organization: Optional[str] = None,
+        project: Optional[str] = None,
+    ):
         """Initialize the MemoryClient.

         Args:
@@ -275,9 +275,7 @@ def delete_users(self) -> Dict[str, str]:
         params = {"org_name": self.organization, "project_name": self.project}
         entities = self.users()
         for entity in entities["results"]:
-            response = self.client.delete(
-                f"/v1/entities/{entity['type']}/{entity['id']}/", params=params
-            )
+            response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
             response.raise_for_status()

         capture_client_event("client.delete_users", self)
@@ -372,7 +370,7 @@ def __init__(
         api_key: Optional[str] = None,
         host: Optional[str] = None,
         organization: Optional[str] = None,
-        project: Optional[str] = None
+        project: Optional[str] = None,
     ):
         self.sync_client = MemoryClient(api_key, host, organization, project)
         self.async_client = httpx.AsyncClient(
@@ -410,7 +408,9 @@ async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
         elif version == "v2":
             response = await self.async_client.post(f"/{version}/memories/", json=params)
         response.raise_for_status()
-        capture_client_event("async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)})
+        capture_client_event(
+            "async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)}
+        )
         return response.json()

     @api_error_handler
diff --git a/mem0/configs/base.py b/mem0/configs/base.py
index c9293c2503..55d6b2e94f 100644
--- a/mem0/configs/base.py
+++ b/mem0/configs/base.py
@@ -73,4 +73,6 @@ class AzureConfig(BaseModel):
     azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
     azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
     api_version: str = Field(description="The version of the Azure API being used.", default=None)
-    default_headers: Optional[Dict[str, str]] = Field(description="Headers to include in requests to the Azure API.", default=None)
+    default_headers: Optional[Dict[str, str]] = Field(
+        description="Headers to include in requests to the Azure API.", default=None
+    )
diff --git a/mem0/embeddings/gemini.py b/mem0/embeddings/gemini.py
index 7ef429a968..210848e39a 100644
--- a/mem0/embeddings/gemini.py
+++ b/mem0/embeddings/gemini.py
@@ -1,5 +1,6 @@
 import os
 from typing import Optional
+
 import google.generativeai as genai

 from mem0.configs.embeddings.base import BaseEmbedderConfig
@@ -9,7 +10,7 @@
 class GoogleGenAIEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)
-
+
         self.config.model = self.config.model or "models/text-embedding-004"
         self.config.embedding_dims = self.config.embedding_dims or 768

@@ -27,4 +28,4 @@ def embed(self, text):
         """
         text = text.replace("\n", " ")
         response = genai.embed_content(model=self.config.model, content=text)
-        return response['embedding']
\ No newline at end of file
+        return response["embedding"]
diff --git a/mem0/embeddings/huggingface.py b/mem0/embeddings/huggingface.py
index d2bf5b8263..b8641cacab 100644
--- a/mem0/embeddings/huggingface.py
+++ b/mem0/embeddings/huggingface.py
@@ -14,7 +14,7 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):

         self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)

-        self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()
+        self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()

     def embed(self, text):
         """
@@ -26,4 +26,4 @@ def embed(self, text):
         Returns:
             list: The embedding vector.
         """
-        return self.model.encode(text, convert_to_numpy = True).tolist()
+        return self.model.encode(text, convert_to_numpy=True).tolist()
diff --git a/mem0/llms/gemini.py b/mem0/llms/gemini.py
index a475226cbe..7fdf5e4e4d 100644
--- a/mem0/llms/gemini.py
+++ b/mem0/llms/gemini.py
@@ -6,7 +6,9 @@
     from google.generativeai import GenerativeModel
     from google.generativeai.types import content_types
 except ImportError:
-    raise ImportError("The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'.")
+    raise ImportError(
+        "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
+    )

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
@@ -44,8 +46,8 @@ def _parse_response(self, response, tools):
                 if fn := part.function_call:
                     processed_response["tool_calls"].append(
                         {
-                            "name": fn.name,
-                            "arguments": {key:val for key, val in fn.args.items()},
+                            "name": fn.name,
+                            "arguments": {key: val for key, val in fn.args.items()},
                         }
                     )

@@ -53,7 +55,7 @@
         else:
             return response.candidates[0].content.parts[0].text

-    def _reformat_messages(self, messages : List[Dict[str, str]]):
+    def _reformat_messages(self, messages: List[Dict[str, str]]):
         """
         Reformat messages for Gemini.
@@ -71,9 +73,8 @@ def _reformat_messages(self, messages : List[Dict[str, str]]):
             else:
                 content = message["content"]
-
-            new_messages.append({"parts": content,
-                                 "role": "model" if message["role"] == "model" else "user"})
+
+            new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"})

         return new_messages

@@ -89,24 +90,24 @@ def _reformat_tools(self, tools: Optional[List[Dict]]):
         """

         def remove_additional_properties(data):
-             """Recursively removes 'additionalProperties' from nested dictionaries."""
-
-             if isinstance(data, dict):
-                 filtered_dict = {
-                     key: remove_additional_properties(value)
-                     for key, value in data.items()
-                     if not (key == "additionalProperties")
-                 }
-                 return filtered_dict
-             else:
-                 return data
-
+            """Recursively removes 'additionalProperties' from nested dictionaries."""
+
+            if isinstance(data, dict):
+                filtered_dict = {
+                    key: remove_additional_properties(value)
+                    for key, value in data.items()
+                    if not (key == "additionalProperties")
+                }
+                return filtered_dict
+            else:
+                return data
+
         new_tools = []
         if tools:
             for tool in tools:
-                func = tool['function'].copy()
-                new_tools.append({"function_declarations":[remove_additional_properties(func)]})
-
+                func = tool["function"].copy()
+                new_tools.append({"function_declarations": [remove_additional_properties(func)]})
+
             return new_tools
         else:
             return None
@@ -142,13 +143,21 @@ def generate_response(
             params["response_schema"] = list[response_format]
         if tool_choice:
             tool_config = content_types.to_tool_config(
-                {"function_calling_config":
-                    {"mode": tool_choice, "allowed_function_names": [tool['function']['name'] for tool in tools] if tool_choice == "any" else None}
-                })
-
-        response = self.client.generate_content(contents = self._reformat_messages(messages),
-                        tools = self._reformat_tools(tools),
-                        generation_config = genai.GenerationConfig(**params),
-                        tool_config = tool_config)
+                {
+                    "function_calling_config": {
+                        "mode": tool_choice,
+                        "allowed_function_names": [tool["function"]["name"] for tool in tools]
+                        if tool_choice == "any"
+                        else None,
+                    }
+                }
+            )
+
+        response = self.client.generate_content(
+            contents=self._reformat_messages(messages),
+            tools=self._reformat_tools(tools),
+            generation_config=genai.GenerationConfig(**params),
+            tool_config=tool_config,
+        )

         return self._parse_response(response, tools)
diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py
index c585162f29..a9c302f8ce 100644
--- a/mem0/llms/openai.py
+++ b/mem0/llms/openai.py
@@ -18,7 +18,9 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
         if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
             self.client = OpenAI(
                 api_key=os.environ.get("OPENROUTER_API_KEY"),
-                base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1",
+                base_url=self.config.openrouter_base_url
+                or os.getenv("OPENROUTER_API_BASE")
+                or "https://openrouter.ai/api/v1",
             )
         else:
             api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
diff --git a/mem0/memory/main.py b/mem0/memory/main.py
index c8c3c229bf..c4e7b21247 100644
--- a/mem0/memory/main.py
+++ b/mem0/memory/main.py
@@ -4,7 +4,6 @@
 import logging
 import uuid
 import warnings
-from collections import defaultdict
 from datetime import datetime
 from typing import Any, Dict

@@ -186,7 +185,9 @@ def _add_to_vector_store(self, messages, metadata, filters):
             logging.info(resp)
             try:
                 if resp["event"] == "ADD":
-                    memory_id = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
+                    memory_id = self._create_memory(
+                        data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata
+                    )
                     returned_memories.append(
                         {
                             "id": memory_id,
@@ -195,7 +196,12 @@ def _add_to_vector_store(self, messages, metadata, filters):
                         }
                     )
                 elif resp["event"] == "UPDATE":
-                    self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
+                    self._update_memory(
+                        memory_id=resp["id"],
+                        data=resp["text"],
+                        existing_embeddings=new_message_embeddings,
+                        metadata=metadata,
+                    )
                     returned_memories.append(
                         {
                             "id": resp["id"],
@@ -304,10 +310,14 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
             future_graph_entities = (
-                executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None
+                executor.submit(self.graph.get_all, filters, limit)
+                if self.version == "v1.1" and self.enable_graph
+                else None
             )

-            concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
+            concurrent.futures.wait(
+                [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
+            )

             all_memories = future_memories.result()
             graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -399,7 +409,9 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, fil
                 else None
             )

-            concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
+            concurrent.futures.wait(
+                [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
+            )

             original_memories = future_memories.result()
             graph_entities = future_graph_entities.result() if future_graph_entities else None
diff --git a/mem0/proxy/main.py b/mem0/proxy/main.py
index d1db29b260..f52177d0fd 100644
--- a/mem0/proxy/main.py
+++ b/mem0/proxy/main.py
@@ -181,9 +181,9 @@ def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters,

     def _format_query_with_memories(self, messages, relevant_memories):
         # Check if self.mem0_client is an instance of Memory or MemoryClient
-
+
         if isinstance(self.mem0_client, mem0.memory.main.Memory):
-            memories_text = "\n".join(memory["memory"] for memory in relevant_memories['results'])
+            memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"])
         elif isinstance(self.mem0_client, mem0.client.main.MemoryClient):
             memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
         return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}"
diff --git a/mem0/vector_stores/milvus.py b/mem0/vector_stores/milvus.py
index a2fb8002e0..013fc0e3ec 100644
--- a/mem0/vector_stores/milvus.py
+++ b/mem0/vector_stores/milvus.py
@@ -76,11 +76,8 @@ def create_col(
         schema = CollectionSchema(fields, enable_dynamic_field=True)

         index = self.client.prepare_index_params(
-            field_name="vectors",
-            metric_type=metric_type,
-            index_type="AUTOINDEX",
-            index_name="vector_index"
-            )
+            field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index"
+        )
         self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)

     def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):