Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom prompt override in memory.add function #1998

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/features/custom-prompts.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,20 @@ m.add("I like going to hikes", user_id="alice")
}
```
</CodeGroup>


## Customizing Prompts per Memory Addition

In addition to setting a default prompt in the configuration, you can also override prompts for individual memory entries by using the `prompt` and `graph_prompt` parameters in `m.add()`. This allows you to tailor specific entries without changing the overall configuration.

For example, to add a memory with a custom prompt:

```python Code
m.add("Yesterday, I ordered a laptop, the order id is 12345", user_id="alice", prompt=custom_prompt)
```

You can also use graph_prompt to customize the prompt specifically for graph memory entries:

```python Code
m.add("Yesterday, I ordered a laptop, the order id is 12345", user_id="alice", graph_prompt=graph_prompt)
```
6 changes: 6 additions & 0 deletions docs/open-source/graph_memory/features.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ config = {
m = Memory.from_config(config_dict=config)
```

You can also **override prompts** for individual memory additions by using the `graph_prompt` parameter in `m.add()`:

```python Code
m.add("Yesterday, I ordered a laptop, the order id is 12345", user_id="alice", graph_prompt=graph_prompt)
```

If you want to use a managed version of Mem0, please check out [Mem0](https://mem0.dev/pd). If you have any questions, please feel free to reach out to us using one of the following methods:

<Snippet file="get-help.mdx" />
23 changes: 12 additions & 11 deletions mem0/memory/graph_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, config):
self.user_id = None
self.threshold = 0.7

def add(self, data, filters):
def add(self, data, filters, graph_prompt=None):
"""
Adds data to the graph.

Expand All @@ -55,7 +55,7 @@ def add(self, data, filters):
filters (dict): A dictionary containing filters to be applied during the addition.
"""
entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map, graph_prompt)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)

Expand Down Expand Up @@ -173,14 +173,15 @@ def _retrieve_nodes_from_data(self, data, filters):
logger.debug(f"Entity type map: {entity_type_map}")
return entity_type_map

def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map, graph_prompt=None):
"""Establish relations among the extracted nodes."""
if self.config.graph_store.custom_prompt:
custom_prompt = graph_prompt if graph_prompt else self.config.graph_store.custom_prompt
if custom_prompt:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
"CUSTOM_PROMPT", f"4. {custom_prompt}"
),
},
{"role": "user", "content": data},
Expand Down Expand Up @@ -294,7 +295,7 @@ def _delete_entities(self, to_be_deleted, user_id):
-[r:{relatationship}]->
(m {{name: $dest_name, user_id: $user_id}})
DELETE r
RETURN
RETURN
n.name AS source,
m.name AS target,
type(r) AS relationship
Expand Down Expand Up @@ -339,7 +340,7 @@ def _add_entities(self, to_be_added, user_id, entity_type_map):
destination.created = timestamp(),
destination.embedding = $destination_embedding
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
ON CREATE SET
r.created = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
Expand All @@ -364,7 +365,7 @@ def _add_entities(self, to_be_added, user_id, entity_type_map):
source.created = timestamp(),
source.embedding = $source_embedding
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
ON CREATE SET
r.created = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
Expand All @@ -387,7 +388,7 @@ def _add_entities(self, to_be_added, user_id, entity_type_map):
MATCH (destination)
WHERE elementId(destination) = $destination_id
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
ON CREATE SET
r.created_at = timestamp(),
r.updated_at = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
Expand Down Expand Up @@ -436,7 +437,7 @@ def _remove_spaces_from_entities(self, entity_list):
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
cypher = """
MATCH (source_candidate)
WHERE source_candidate.embedding IS NOT NULL
WHERE source_candidate.embedding IS NOT NULL
AND source_candidate.user_id = $user_id

WITH source_candidate,
Expand Down Expand Up @@ -469,7 +470,7 @@ def _search_source_node(self, source_embedding, user_id, threshold=0.9):
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
cypher = """
MATCH (destination_candidate)
WHERE destination_candidate.embedding IS NOT NULL
WHERE destination_candidate.embedding IS NOT NULL
AND destination_candidate.user_id = $user_id

WITH destination_candidate,
Expand Down
17 changes: 10 additions & 7 deletions mem0/memory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def add(
metadata=None,
filters=None,
prompt=None,
graph_prompt=None,
):
"""
Create a new memory.
Expand All @@ -83,6 +84,7 @@ def add(
metadata (dict, optional): Metadata to store with the memory. Defaults to None.
filters (dict, optional): Filters to apply to the search. Defaults to None.
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
graph_prompt (str, optional): Prompt to use for graph memory deduction. Defaults to None.

Returns:
dict: A dictionary containing the result of the memory addition operation.
Expand Down Expand Up @@ -115,8 +117,8 @@ def add(
messages = [{"role": "user", "content": messages}]

with concurrent.futures.ThreadPoolExecutor() as executor:
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
future2 = executor.submit(self._add_to_graph, messages, filters)
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters, prompt)
future2 = executor.submit(self._add_to_graph, messages, filters, graph_prompt)

concurrent.futures.wait([future1, future2])

Expand All @@ -138,11 +140,12 @@ def add(
)
return vector_store_result

def _add_to_vector_store(self, messages, metadata, filters):
def _add_to_vector_store(self, messages, metadata, filters, prompt=None):
parsed_messages = parse_messages(messages)

if self.custom_prompt:
system_prompt = self.custom_prompt
custom_prompt = prompt if prompt else self.custom_prompt
if custom_prompt:
system_prompt = custom_prompt
user_prompt = f"Input: {parsed_messages}"
else:
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
Expand Down Expand Up @@ -244,14 +247,14 @@ def _add_to_vector_store(self, messages, metadata, filters):

return returned_memories

def _add_to_graph(self, messages, filters):
def _add_to_graph(self, messages, filters, graph_prompt=None):
added_entities = []
if self.api_version == "v1.1" and self.enable_graph:
if filters.get("user_id") is None:
filters["user_id"] = "user"

data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
added_entities = self.graph.add(data, filters)
added_entities = self.graph.add(data, filters, graph_prompt)

return added_entities

Expand Down
23 changes: 18 additions & 5 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,40 @@ def memory_instance():
return Memory(config)


@pytest.mark.parametrize(
    "version, enable_graph, custom_prompt",
    [
        ("v1.0", False, None),
        ("v1.1", True, None),
        ("v1.0", False, "CustomPrompt"),
        ("v1.1", True, "CustomPrompt"),
    ],
)
def test_add(memory_instance, version, enable_graph, custom_prompt):
    """add() forwards prompt/graph_prompt to the vector-store and graph helpers."""
    memory_instance.config.version = version
    memory_instance.enable_graph = enable_graph
    memory_instance._add_to_vector_store = Mock(return_value=[{"memory": "Test memory", "event": "ADD"}])
    memory_instance._add_to_graph = Mock(return_value=[])

    result = memory_instance.add(
        messages=[{"role": "user", "content": "Test message"}],
        user_id="test_user",
        prompt=custom_prompt,
        graph_prompt=custom_prompt,
    )

    assert "results" in result
    assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
    assert "relations" in result
    assert result["relations"] == []

    # prompt is passed through as the 4th positional argument.
    memory_instance._add_to_vector_store.assert_called_once_with(
        [{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}, custom_prompt
    )

    # _add_to_graph is always invoked; it no-ops internally when graph
    # memory is disabled, so no conditional assertion is needed here.
    memory_instance._add_to_graph.assert_called_once_with(
        [{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, custom_prompt
    )


Expand Down
Loading