From 605558da9d610e183ee1315f5f6e4e7a83a2f2a3 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Tue, 29 Oct 2024 11:32:07 +0530 Subject: [PATCH] Code formatting (#1986) --- cookbooks/helper/mem0_teachability.py | 13 ++-- mem0/client/main.py | 46 +++++++++----- mem0/configs/vector_stores/chroma.py | 2 - mem0/memory/main.py | 29 +++++---- mem0/memory/telemetry.py | 2 +- .../test_azure_openai_embeddings.py | 10 +-- tests/embeddings/test_gemini.py | 15 +---- tests/embeddings/test_vertexai_embeddings.py | 31 +++------- tests/llms/test_gemini_llm.py | 61 ++++++++++--------- tests/llms/test_openai.py | 5 +- tests/test_main.py | 6 +- tests/vector_stores/test_chroma.py | 22 ++----- tests/vector_stores/test_qdrant.py | 26 +++----- 13 files changed, 119 insertions(+), 149 deletions(-) diff --git a/cookbooks/helper/mem0_teachability.py b/cookbooks/helper/mem0_teachability.py index df7909bedd..221a4b4f64 100644 --- a/cookbooks/helper/mem0_teachability.py +++ b/cookbooks/helper/mem0_teachability.py @@ -12,7 +12,7 @@ from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent from termcolor import colored from mem0 import Memory -from mem0.configs.base import MemoryConfig + class Mem0Teachability(AgentCapability): def __init__( @@ -60,7 +60,6 @@ def process_last_received_message(self, text: Union[Dict, str]): return expanded_text def _consider_memo_storage(self, comment: Union[Dict, str]): - memo_added = False response = self._analyze( comment, "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.", @@ -85,8 +84,9 @@ def _consider_memo_storage(self, comment: Union[Dict, str]): if self.verbosity >= 1: print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow")) - self.memory.add([{"role": "user", "content": f"Task: {general_task}\nAdvice: {advice}"}], agent_id=self.agent_id) - memo_added = True + self.memory.add( + [{"role": "user", "content": f"Task: {general_task}\nAdvice: {advice}"}], agent_id=self.agent_id + ) response = self._analyze( comment, @@ -105,8 +105,9 @@ def _consider_memo_storage(self, comment: Union[Dict, str]): if self.verbosity >= 1: print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow")) - self.memory.add([{"role": "user", "content": f"Question: {question}\nAnswer: {answer}"}], agent_id=self.agent_id) - memo_added = True + self.memory.add( + [{"role": "user", "content": f"Question: {question}\nAnswer: {answer}"}], agent_id=self.agent_id + ) def _consider_memo_retrieval(self, comment: Union[Dict, str]): if self.verbosity >= 1: diff --git a/mem0/client/main.py b/mem0/client/main.py index 5194a0643a..1dc1fba539 100644 --- a/mem0/client/main.py +++ b/mem0/client/main.py @@ -117,7 +117,9 @@ def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, payload = self._prepare_payload(messages, kwargs) response = self.client.post("/v1/memories/", json=payload) response.raise_for_status() - capture_client_event("client.add", self) + if "metadata" in kwargs: + del kwargs["metadata"] + capture_client_event("client.add", self, {"keys": list(kwargs.keys())}) return response.json() @api_error_handler @@ -135,7 +137,7 @@ def get(self, memory_id: str) -> Dict[str, Any]: """ response = self.client.get(f"/v1/memories/{memory_id}/") response.raise_for_status() - capture_client_event("client.get", self) + capture_client_event("client.get", self, {"memory_id": memory_id}) return response.json() @api_error_handler @@ -159,10 +161,12 @@ def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]: elif version == "v2": response = self.client.post(f"/{version}/memories/", json=params) response.raise_for_status() + if "metadata" in kwargs: + del kwargs["metadata"] capture_client_event( "client.get_all", self, - {"filters": len(params), "limit": kwargs.get("limit", 100)}, + {"api_version": version, "keys": list(kwargs.keys())}, ) return response.json() @@ -186,7 +190,9 @@ def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, An payload.update({k: v for k, v in kwargs.items() if v is not None}) response = self.client.post(f"/{version}/memories/search/", json=payload) response.raise_for_status() - capture_client_event("client.search", self, {"limit": kwargs.get("limit", 100)}) + if "metadata" in kwargs: + del kwargs["metadata"] + capture_client_event("client.search", self, {"api_version": version, "keys": list(kwargs.keys())}) return response.json() @api_error_handler @@ -199,7 +205,7 @@ def update(self, memory_id: str, data: str) -> Dict[str, Any]: Returns: Dict[str, Any]: The response from the server. """ - capture_client_event("client.update", self) + capture_client_event("client.update", self, {"memory_id": memory_id}) response = self.client.put(f"/v1/memories/{memory_id}/", json={"text": data}) response.raise_for_status() return response.json() @@ -219,7 +225,7 @@ def delete(self, memory_id: str) -> Dict[str, Any]: """ response = self.client.delete(f"/v1/memories/{memory_id}/") response.raise_for_status() - capture_client_event("client.delete", self) + capture_client_event("client.delete", self, {"memory_id": memory_id}) return response.json() @api_error_handler @@ -239,7 +245,7 @@ def delete_all(self, **kwargs) -> Dict[str, str]: params = self._prepare_params(kwargs) response = self.client.delete("/v1/memories/", params=params) response.raise_for_status() - capture_client_event("client.delete_all", self, {"params": len(params)}) + capture_client_event("client.delete_all", self, {"keys": list(kwargs.keys())}) return response.json() @api_error_handler @@ -257,7 +263,7 @@ def history(self, memory_id: str) -> List[Dict[str, Any]]: """ response = self.client.get(f"/v1/memories/{memory_id}/history/") response.raise_for_status() - capture_client_event("client.history", self) + capture_client_event("client.history", self, {"memory_id": memory_id}) return response.json() @api_error_handler @@ -390,14 +396,16 @@ async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dic payload = self.sync_client._prepare_payload(messages, kwargs) response = await self.async_client.post("/v1/memories/", json=payload) response.raise_for_status() - capture_client_event("async_client.add", self.sync_client) + if "metadata" in kwargs: + del kwargs["metadata"] + capture_client_event("async_client.add", self.sync_client, {"keys": list(kwargs.keys())}) return response.json() @api_error_handler async def get(self, memory_id: str) -> Dict[str, Any]: response = await self.async_client.get(f"/v1/memories/{memory_id}/") response.raise_for_status() - capture_client_event("async_client.get", self.sync_client) + capture_client_event("async_client.get", self.sync_client, {"memory_id": memory_id}) return response.json() @api_error_handler @@ -408,8 +416,10 @@ async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]: elif version == "v2": response = await self.async_client.post(f"/{version}/memories/", json=params) response.raise_for_status() + if "metadata" in kwargs: + del kwargs["metadata"] capture_client_event( - "async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)} + "async_client.get_all", self.sync_client, {"api_version": version, "keys": list(kwargs.keys())} ) return response.json() @@ -419,21 +429,25 @@ async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[s payload.update(self.sync_client._prepare_params(kwargs)) response = await self.async_client.post(f"/{version}/memories/search/", json=payload) response.raise_for_status() - capture_client_event("async_client.search", self.sync_client, {"limit": kwargs.get("limit", 100)}) + if "metadata" in kwargs: + del kwargs["metadata"] + capture_client_event( + "async_client.search", self.sync_client, {"api_version": version, "keys": list(kwargs.keys())} + ) return response.json() @api_error_handler async def update(self, memory_id: str, data: str) -> Dict[str, Any]: response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data}) response.raise_for_status() - capture_client_event("async_client.update", self.sync_client) + capture_client_event("async_client.update", self.sync_client, {"memory_id": memory_id}) return response.json() @api_error_handler async def delete(self, memory_id: str) -> Dict[str, Any]: response = await self.async_client.delete(f"/v1/memories/{memory_id}/") response.raise_for_status() - capture_client_event("async_client.delete", self.sync_client) + capture_client_event("async_client.delete", self.sync_client, {"memory_id": memory_id}) return response.json() @api_error_handler @@ -441,14 +455,14 @@ async def delete_all(self, **kwargs) -> Dict[str, str]: params = self.sync_client._prepare_params(kwargs) response = await self.async_client.delete("/v1/memories/", params=params) response.raise_for_status() - capture_client_event("async_client.delete_all", self.sync_client, {"params": len(params)}) + capture_client_event("async_client.delete_all", self.sync_client, {"keys": list(kwargs.keys())}) return response.json() @api_error_handler async def history(self, memory_id: str) -> List[Dict[str, Any]]: response = await self.async_client.get(f"/v1/memories/{memory_id}/history/") response.raise_for_status() - capture_client_event("async_client.history", self.sync_client) + capture_client_event("async_client.history", self.sync_client, {"memory_id": memory_id}) return response.json() @api_error_handler diff --git a/mem0/configs/vector_stores/chroma.py b/mem0/configs/vector_stores/chroma.py index afff8a8f7b..664807b850 100644 --- a/mem0/configs/vector_stores/chroma.py +++ b/mem0/configs/vector_stores/chroma.py @@ -1,5 +1,3 @@ -import subprocess -import sys from typing import Any, ClassVar, Dict, Optional from pydantic import BaseModel, Field, model_validator diff --git a/mem0/memory/main.py b/mem0/memory/main.py index c4e7b21247..a0d0fbfe3b 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -37,11 +37,11 @@ def __init__(self, config: MemoryConfig = MemoryConfig()): self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config) self.db = SQLiteManager(self.config.history_db_path) self.collection_name = self.config.vector_store.config.collection_name - self.version = self.config.version + self.api_version = self.config.version self.enable_graph = False - if self.version == "v1.1" and self.config.graph_store.config: + if self.api_version == "v1.1" and self.config.graph_store.config: from mem0.memory.graph_memory import MemoryGraph self.graph = MemoryGraph(self.config) @@ -119,7 +119,7 @@ def add( vector_store_result = future1.result() graph_result = future2.result() - if self.version == "v1.1": + if self.api_version == "v1.1": return { "results": vector_store_result, "relations": graph_result, @@ -226,13 +226,13 @@ def _add_to_vector_store(self, messages, metadata, filters): except Exception as e: logging.error(f"Error in new_memories_with_actions: {e}") - capture_event("mem0.add", self) + capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())}) return returned_memories def _add_to_graph(self, messages, filters): added_entities = [] - if self.version == "v1.1" and self.enable_graph: + if self.api_version == "v1.1" and self.enable_graph: if filters["user_id"]: self.graph.user_id = filters["user_id"] elif filters["agent_id"]: @@ -305,13 +305,13 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100): if run_id: filters["run_id"] = run_id - capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit}) + capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())}) with concurrent.futures.ThreadPoolExecutor() as executor: future_memories = executor.submit(self._get_all_from_vector_store, filters, limit) future_graph_entities = ( executor.submit(self.graph.get_all, filters, limit) - if self.version == "v1.1" and self.enable_graph + if self.api_version == "v1.1" and self.enable_graph else None ) @@ -322,7 +322,7 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100): all_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None - if self.version == "v1.1": + if self.api_version == "v1.1": if self.enable_graph: return {"results": all_memories, "relations": graph_entities} else: @@ -398,14 +398,14 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, fil capture_event( "mem0.search", self, - {"filters": len(filters), "limit": limit, "version": self.version}, + {"limit": limit, "version": self.api_version, "keys": list(filters.keys())}, ) with concurrent.futures.ThreadPoolExecutor() as executor: future_memories = executor.submit(self._search_vector_store, query, filters, limit) future_graph_entities = ( executor.submit(self.graph.search, query, filters, limit) - if self.version == "v1.1" and self.enable_graph + if self.api_version == "v1.1" and self.enable_graph else None ) @@ -416,7 +416,7 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, fil original_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None - if self.version == "v1.1": + if self.api_version == "v1.1": if self.enable_graph: return {"results": original_memories, "relations": graph_entities} else: @@ -518,14 +518,14 @@ def delete_all(self, user_id=None, agent_id=None, run_id=None): "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method." ) - capture_event("mem0.delete_all", self, {"filters": len(filters)}) + capture_event("mem0.delete_all", self, {"keys": list(filters.keys())}) memories = self.vector_store.list(filters=filters)[0] for memory in memories: self._delete_memory(memory.id) logger.info(f"Deleted {len(memories)} memories") - if self.version == "v1.1" and self.enable_graph: + if self.api_version == "v1.1" and self.enable_graph: self.graph.delete_all(filters) return {"message": "Memories deleted successfully!"} @@ -561,6 +561,7 @@ def _create_memory(self, data, existing_embeddings, metadata=None): payloads=[metadata], ) self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"]) + capture_event("mem0._create_memory", self, {"memory_id": memory_id}) return memory_id def _update_memory(self, memory_id, data, existing_embeddings, metadata=None): @@ -603,6 +604,7 @@ def _update_memory(self, memory_id, data, existing_embeddings, metadata=None): created_at=new_metadata["created_at"], updated_at=new_metadata["updated_at"], ) + capture_event("mem0._update_memory", self, {"memory_id": memory_id}) return memory_id def _delete_memory(self, memory_id): @@ -611,6 +613,7 @@ def _delete_memory(self, memory_id): prev_value = existing_memory.payload["data"] self.vector_store.delete(vector_id=memory_id) self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1) + capture_event("mem0._delete_memory", self, {"memory_id": memory_id}) return memory_id def reset(self): diff --git a/mem0/memory/telemetry.py b/mem0/memory/telemetry.py index e4e6b86360..f595a458e7 100644 --- a/mem0/memory/telemetry.py +++ b/mem0/memory/telemetry.py @@ -67,7 +67,7 @@ def capture_event(event_name, memory_instance, additional_data=None): "vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}", "llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}", "embedding_model": f"{memory_instance.embedding_model.__class__.__module__}.{memory_instance.embedding_model.__class__.__name__}", - "function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}.{memory_instance.version}", + "function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}.{memory_instance.api_version}", } if additional_data: event_data.update(additional_data) diff --git a/tests/embeddings/test_azure_openai_embeddings.py b/tests/embeddings/test_azure_openai_embeddings.py index f674dc2a45..8e7e6b72a4 100644 --- a/tests/embeddings/test_azure_openai_embeddings.py +++ b/tests/embeddings/test_azure_openai_embeddings.py @@ -33,11 +33,7 @@ def test_embed_text(mock_openai_client): @pytest.mark.parametrize( "default_headers, expected_header", - [ - (None, None), - ({"Test": "test_value"}, "test_value"), - ({}, None) - ], + [(None, None), ({"Test": "test_value"}, "test_value"), ({}, None)], ) def test_embed_text_with_default_headers(default_headers, expected_header): config = BaseEmbedderConfig( @@ -47,8 +43,8 @@ def test_embed_text_with_default_headers(default_headers, expected_header): "api_version": "test_version", "azure_endpoint": "test_endpoint", "azuer_deployment": "test_deployment", - "default_headers": default_headers - } + "default_headers": default_headers, + }, ) embedder = AzureOpenAIEmbedding(config) assert embedder.client.api_key == "test" diff --git a/tests/embeddings/test_gemini.py b/tests/embeddings/test_gemini.py index 07691632fa..a1ae50b14f 100644 --- a/tests/embeddings/test_gemini.py +++ b/tests/embeddings/test_gemini.py @@ -12,17 +12,11 @@ def mock_genai(): @pytest.fixture def config(): - return BaseEmbedderConfig( - api_key="dummy_api_key", - model="test_model" - ) + return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model") def test_embed_query(mock_genai, config): - - mock_embedding_response = { - 'embedding': [0.1, 0.2, 0.3, 0.4] - } + mock_embedding_response = {"embedding": [0.1, 0.2, 0.3, 0.4]} mock_genai.return_value = mock_embedding_response embedder = GoogleGenAIEmbedding(config) @@ -31,7 +25,4 @@ def test_embed_query(mock_genai, config): embedding = embedder.embed(text) assert embedding == [0.1, 0.2, 0.3, 0.4] - mock_genai.assert_called_once_with( - model="test_model", - content="Hello, world!" - ) + mock_genai.assert_called_once_with(model="test_model", content="Hello, world!") diff --git a/tests/embeddings/test_vertexai_embeddings.py b/tests/embeddings/test_vertexai_embeddings.py index 26ac634141..861c84bd15 100644 --- a/tests/embeddings/test_vertexai_embeddings.py +++ b/tests/embeddings/test_vertexai_embeddings.py @@ -1,7 +1,6 @@ import pytest from unittest.mock import Mock, patch from mem0.embeddings.vertexai import VertexAIEmbedding -from mem0.configs.embeddings.base import BaseEmbedderConfig @pytest.fixture @@ -35,15 +34,11 @@ def test_embed_default_model(mock_text_embedding_model, mock_os_environ, mock_co embedder = VertexAIEmbedding(config) mock_embedding = Mock(values=[0.1, 0.2, 0.3]) - mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [ - mock_embedding - ] + mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [mock_embedding] - result = embedder.embed("Hello world") + embedder.embed("Hello world") - mock_text_embedding_model.from_pretrained.assert_called_once_with( - "text-embedding-004" - ) + mock_text_embedding_model.from_pretrained.assert_called_once_with("text-embedding-004") mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_once_with( texts=["Hello world"], output_dimensionality=256 ) @@ -60,15 +55,11 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_con embedder = VertexAIEmbedding(config) mock_embedding = Mock(values=[0.4, 0.5, 0.6]) - mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [ - mock_embedding - ] + mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [mock_embedding] result = embedder.embed("Test embedding") - mock_text_embedding_model.from_pretrained.assert_called_with( - "custom-embedding-model" - ) + mock_text_embedding_model.from_pretrained.assert_called_with("custom-embedding-model") mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_once_with( texts=["Test embedding"], output_dimensionality=512 ) @@ -93,16 +84,12 @@ def test_missing_credentials(mock_os, mock_text_embedding_model, mock_config): config = mock_config() - with pytest.raises( - ValueError, match="Google application credentials JSON is not provided" - ): + with pytest.raises(ValueError, match="Google application credentials JSON is not provided"): VertexAIEmbedding(config) @patch("mem0.embeddings.vertexai.TextEmbeddingModel") -def test_embed_with_different_dimensions( - mock_text_embedding_model, mock_os_environ, mock_config -): +def test_embed_with_different_dimensions(mock_text_embedding_model, mock_os_environ, mock_config): mock_config.vertex_credentials_json = "/path/to/credentials.json" mock_config.return_value.embedding_dims = 1024 @@ -110,9 +97,7 @@ def test_embed_with_different_dimensions( embedder = VertexAIEmbedding(config) mock_embedding = Mock(values=[0.1] * 1024) - mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [ - mock_embedding - ] + mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [mock_embedding] result = embedder.embed("Large embedding test") diff --git a/tests/llms/test_gemini_llm.py b/tests/llms/test_gemini_llm.py index c6244aeb24..ffdec4fba3 100644 --- a/tests/llms/test_gemini_llm.py +++ b/tests/llms/test_gemini_llm.py @@ -33,19 +33,19 @@ def test_generate_response_without_tools(mock_gemini_client: Mock): response = llm.generate_response(messages) mock_gemini_client.generate_content.assert_called_once_with( - contents = [ - {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, - {"parts": "Hello, how are you?", "role": "user"} + contents=[ + {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, + {"parts": "Hello, how are you?", "role": "user"}, ], - generation_config = GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0), - tools = None, - tool_config = content_types.to_tool_config( - {"function_calling_config": - {"mode": 'auto', "allowed_function_names": None} - }) + generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0), + tools=None, + tool_config=content_types.to_tool_config( + {"function_calling_config": {"mode": "auto", "allowed_function_names": None}} + ), ) assert response == "I'm doing well, thank you for asking!" - + + def test_generate_response_with_tools(mock_gemini_client: Mock): config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0) llm = GeminiLLM(config) @@ -74,13 +74,13 @@ def test_generate_response_with_tools(mock_gemini_client: Mock): mock_part = Mock() mock_part.function_call = mock_tool_call - mock_part.text="I've added the memory for you." + mock_part.text = "I've added the memory for you." mock_content = Mock() - mock_content.parts=[mock_part] + mock_content.parts = [mock_part] mock_message = Mock() - mock_message.content=mock_content + mock_message.content = mock_content mock_response = Mock(candidates=[mock_message]) mock_gemini_client.generate_content.return_value = mock_response @@ -88,28 +88,29 @@ def test_generate_response_with_tools(mock_gemini_client: Mock): response = llm.generate_response(messages, tools=tools) mock_gemini_client.generate_content.assert_called_once_with( - contents = [ - {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, - {"parts": "Add a new memory: Today is a sunny day.", "role": "user"} + contents=[ + {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, + {"parts": "Add a new memory: Today is a sunny day.", "role": "user"}, ], - generation_config = GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0), - tools = [ + generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0), + tools=[ { - "function_declarations": [{ - "name": "add_memory", - "description": "Add a memory", - "parameters": { - "type": "object", - "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, - "required": ["data"] + "function_declarations": [ + { + "name": "add_memory", + "description": "Add a memory", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, + "required": ["data"], + }, } - }] + ] } ], - tool_config = content_types.to_tool_config( - {"function_calling_config": - {"mode": 'auto', "allowed_function_names": None} - }) + tool_config=content_types.to_tool_config( + {"function_calling_config": {"mode": "auto", "allowed_function_names": None}} + ), ) assert response["content"] == "I've added the memory for you." diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index 9e62c6f2cb..42f8aa5be4 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -31,8 +31,9 @@ def test_openai_llm_base_url(): # case3: with config.openai_base_url config_base_url = "https://api.config.com/v1" - config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key", - openai_base_url=config_base_url) + config = BaseLlmConfig( + model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key", openai_base_url=config_base_url + ) llm = OpenAILLM(config) # Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash assert str(llm.client.base_url) == config_base_url + "/" diff --git a/tests/test_main.py b/tests/test_main.py index 66bc6d55d8..a311f854b2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -135,7 +135,9 @@ def test_update(memory_instance): result = memory_instance.update("test_id", "Updated memory") - memory_instance._update_memory.assert_called_once_with("test_id", "Updated memory", {"Updated memory": [0.1, 0.2, 0.3]}) + memory_instance._update_memory.assert_called_once_with( + "test_id", "Updated memory", {"Updated memory": [0.1, 0.2, 0.3]} + ) assert result["message"] == "Memory updated successfully!" @@ -177,7 +179,6 @@ def test_reset(memory_instance): memory_instance.db.reset = Mock() with patch.object(VectorStoreFactory, "create", return_value=Mock()) as mock_create: - memory_instance.reset() initial_vector_store.delete_col.assert_called_once() @@ -186,6 +187,7 @@ def test_reset(memory_instance): memory_instance.config.vector_store.provider, memory_instance.config.vector_store.config ) + @pytest.mark.parametrize( "version, enable_graph, expected_result", [ diff --git a/tests/vector_stores/test_chroma.py b/tests/vector_stores/test_chroma.py index 3d0c20b3dc..6995217f0f 100644 --- a/tests/vector_stores/test_chroma.py +++ b/tests/vector_stores/test_chroma.py @@ -1,6 +1,6 @@ from unittest.mock import Mock, patch import pytest -from mem0.vector_stores.chroma import ChromaDB, OutputData +from mem0.vector_stores.chroma import ChromaDB @pytest.fixture @@ -12,13 +12,9 @@ def mock_chromadb_client(): @pytest.fixture def chromadb_instance(mock_chromadb_client): mock_collection = Mock() - mock_chromadb_client.return_value.get_or_create_collection.return_value = ( - mock_collection - ) + mock_chromadb_client.return_value.get_or_create_collection.return_value = mock_collection - return ChromaDB( - collection_name="test_collection", client=mock_chromadb_client.return_value - ) + return ChromaDB(collection_name="test_collection", client=mock_chromadb_client.return_value) def test_insert_vectors(chromadb_instance, mock_chromadb_client): @@ -28,9 +24,7 @@ def test_insert_vectors(chromadb_instance, mock_chromadb_client): chromadb_instance.insert(vectors=vectors, payloads=payloads, ids=ids) - chromadb_instance.collection.add.assert_called_once_with( - ids=ids, embeddings=vectors, metadatas=payloads - ) + chromadb_instance.collection.add.assert_called_once_with(ids=ids, embeddings=vectors, metadatas=payloads) def test_search_vectors(chromadb_instance, mock_chromadb_client): @@ -44,9 +38,7 @@ def test_search_vectors(chromadb_instance, mock_chromadb_client): query = [[0.1, 0.2, 0.3]] results = chromadb_instance.search(query=query, limit=2) - chromadb_instance.collection.query.assert_called_once_with( - query_embeddings=query, where=None, n_results=2 - ) + chromadb_instance.collection.query.assert_called_once_with(query_embeddings=query, where=None, n_results=2) print(results, type(results)) assert len(results) == 2 @@ -68,9 +60,7 @@ def test_update_vector(chromadb_instance): new_vector = [0.7, 0.8, 0.9] new_payload = {"name": "updated_vector"} - chromadb_instance.update( - vector_id=vector_id, vector=new_vector, payload=new_payload - ) + chromadb_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload) chromadb_instance.collection.update.assert_called_once_with( ids=vector_id, embeddings=new_vector, metadatas=new_payload diff --git a/tests/vector_stores/test_qdrant.py b/tests/vector_stores/test_qdrant.py index b398335fad..ab80cde9d3 100644 --- a/tests/vector_stores/test_qdrant.py +++ b/tests/vector_stores/test_qdrant.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import uuid from qdrant_client import QdrantClient from qdrant_client.models import ( @@ -51,9 +51,7 @@ def test_insert(self): def test_search(self): query_vector = [0.1, 0.2] - self.client_mock.search.return_value = [ - {"id": str(uuid.uuid4()), "score": 0.95, "payload": {"key": "value"}} - ] + self.client_mock.search.return_value = [{"id": str(uuid.uuid4()), "score": 0.95, "payload": {"key": "value"}}] results = self.qdrant.search(query=query_vector, limit=1) @@ -83,9 +81,7 @@ def test_update(self): updated_vector = [0.2, 0.3] updated_payload = {"key": "updated_value"} - self.qdrant.update( - vector_id=vector_id, vector=updated_vector, payload=updated_payload - ) + self.qdrant.update(vector_id=vector_id, vector=updated_vector, payload=updated_payload) self.client_mock.upsert.assert_called_once() point = self.client_mock.upsert.call_args[1]["points"][0] @@ -95,9 +91,7 @@ def test_update(self): def test_get(self): vector_id = str(uuid.uuid4()) - self.client_mock.retrieve.return_value = [ - {"id": vector_id, "payload": {"key": "value"}} - ] + self.client_mock.retrieve.return_value = [{"id": vector_id, "payload": {"key": "value"}}] result = self.qdrant.get(vector_id=vector_id) @@ -108,23 +102,17 @@ def test_get(self): self.assertEqual(result["payload"], {"key": "value"}) def test_list_cols(self): - self.client_mock.get_collections.return_value = MagicMock( - collections=[{"name": "test_collection"}] - ) + self.client_mock.get_collections.return_value = MagicMock(collections=[{"name": "test_collection"}]) result = self.qdrant.list_cols() self.assertEqual(result.collections[0]["name"], "test_collection") def test_delete_col(self): self.qdrant.delete_col() - self.client_mock.delete_collection.assert_called_once_with( - collection_name="test_collection" - ) + self.client_mock.delete_collection.assert_called_once_with(collection_name="test_collection") def test_col_info(self): self.qdrant.col_info() - self.client_mock.get_collection.assert_called_once_with( - collection_name="test_collection" - ) + self.client_mock.get_collection.assert_called_once_with(collection_name="test_collection") def tearDown(self): del self.qdrant