From a86fe7769740132a995dca3aebe849b80f6094eb Mon Sep 17 00:00:00 2001
From: Mayank Solanki
Date: Sat, 16 Nov 2024 00:15:24 +0530
Subject: [PATCH 1/2] redis db

---
 docs/components/vectordbs/dbs/redis.mdx |  34 ++++
 docs/components/vectordbs/overview.mdx  |   1 +
 docs/mint.json                          |   3 +-
 mem0/configs/vector_stores/redis.py     |  23 +++
 mem0/utils/factory.py                   |   1 +
 mem0/vector_stores/configs.py           |   1 +
 mem0/vector_stores/redis.py             | 211 ++++++++++++++++++++++++
 7 files changed, 273 insertions(+), 1 deletion(-)
 create mode 100644 docs/components/vectordbs/dbs/redis.mdx
 create mode 100644 mem0/configs/vector_stores/redis.py
 create mode 100644 mem0/vector_stores/redis.py

diff --git a/docs/components/vectordbs/dbs/redis.mdx b/docs/components/vectordbs/dbs/redis.mdx
new file mode 100644
index 0000000000..b1c0f87b96
--- /dev/null
+++ b/docs/components/vectordbs/dbs/redis.mdx
@@ -0,0 +1,34 @@
+[Redis](https://redis.io/) is a scalable, real-time database that can store, search, and analyze vector data.
+
+### Usage
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["OPENAI_API_KEY"] = "sk-xx"
+
+config = {
+    "vector_store": {
+        "provider": "redis",
+        "config": {
+            "collection_name": "mem0",
+            "embedding_model_dims": 1536,
+            "redis_url": "redis://localhost:6379"
+        }
+    }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
+
+### Config
+
+Let's see the available parameters for the `redis` config:
+
+| Parameter | Description | Default Value |
+| --- | --- | --- |
+| `collection_name` | The name of the collection to store the vectors | `mem0` |
+| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
+| `redis_url` | The URL of the Redis server | `None` |
\ No newline at end of file
diff --git a/docs/components/vectordbs/overview.mdx b/docs/components/vectordbs/overview.mdx
index 822637f639..5364507a7e 100644
--- a/docs/components/vectordbs/overview.mdx
+++ b/docs/components/vectordbs/overview.mdx
@@ -14,6 +14,7 @@ See the list of supported vector databases below.
+
 
 ## Usage
diff --git a/docs/mint.json b/docs/mint.json
index 5d02819c5e..50e5819a60 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -111,7 +111,8 @@
           "components/vectordbs/dbs/chroma",
           "components/vectordbs/dbs/pgvector",
           "components/vectordbs/dbs/milvus",
-          "components/vectordbs/dbs/azure_ai_search"
+          "components/vectordbs/dbs/azure_ai_search",
+          "components/vectordbs/dbs/redis"
         ]
       }
     ]
diff --git a/mem0/configs/vector_stores/redis.py b/mem0/configs/vector_stores/redis.py
new file mode 100644
index 0000000000..2b92c5297c
--- /dev/null
+++ b/mem0/configs/vector_stores/redis.py
@@ -0,0 +1,23 @@
+from typing import Any, Dict
+
+from pydantic import BaseModel, Field, model_validator
+
+#TODO: Upgrade to latest pydantic version
+class RedisDBConfig(BaseModel):
+    redis_url: str = Field(..., description="Redis URL")
+    collection_name: str = Field("mem0", description="Collection name")
+    embedding_model_dims: int = Field(1536, description="Embedding model dimensions")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 5e0defc307..bdff8fe234 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -65,6 +65,7 @@ class VectorStoreFactory:
         "pgvector": "mem0.vector_stores.pgvector.PGVector",
         "milvus": "mem0.vector_stores.milvus.MilvusDB",
         "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
+        "redis": "mem0.vector_stores.redis.RedisDB",
     }
 
     @classmethod
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py
index c76e3a1178..75768d9661 100644
--- a/mem0/vector_stores/configs.py
+++ b/mem0/vector_stores/configs.py
@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
         "pgvector": "PGVectorConfig",
         "milvus": "MilvusDBConfig",
         "azure_ai_search": "AzureAISearchConfig",
+        "redis": "RedisDBConfig",
     }
 
     @model_validator(mode="after")
diff --git a/mem0/vector_stores/redis.py b/mem0/vector_stores/redis.py
new file mode 100644
index 0000000000..a623a524e9
--- /dev/null
+++ b/mem0/vector_stores/redis.py
@@ -0,0 +1,211 @@
+from datetime import datetime
+from functools import reduce
+import json
+import logging
+import numpy as np
+import pytz
+from mem0.vector_stores.base import VectorStoreBase
+from redisvl.query.filter import Tag
+from redisvl.query import VectorQuery
+from redis.commands.search.query import Query
+
+
+import redis
+from redisvl.index import SearchIndex
+
+
+logger = logging.getLogger(__name__)
+
+# TODO: Improve as these are not the best fields for the Redis's perspective. Might do away with them.
+DEFAULT_FIELDS = [
+    {"name": "memory_id", "type": "tag"},
+    {"name": "hash", "type": "tag"},
+    {"name": "agent_id", "type": "tag"},
+    {"name": "run_id", "type": "tag"},
+    {"name": "user_id", "type": "tag"},
+    {"name": "memory", "type": "text"},
+    {"name": "metadata", "type": "text"},
+    #TODO: Although it is numeric but also accepts string
+    {"name": "created_at", "type": "numeric"},
+    {"name": "updated_at", "type": "numeric"},
+    {
+        "name": "embedding",
+        "type": "vector",
+        "attrs": {
+            "distance_metric": "cosine",
+            "algorithm": "flat",
+            "datatype": "float32"
+        }
+    }
+]
+
+excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
+
+
+class MemoryResult:
+    def __init__(self, id: str, payload: dict, score: float = None):
+        self.id = id
+        self.payload = payload
+        self.score = score
+
+class RedisDB(VectorStoreBase):
+    def __init__(
+        self,
+        redis_url: str,
+        collection_name: str,
+        embedding_model_dims: int,
+    ):
+        """
+        Initialize the Redis vector store.
+
+        Args:
+            redis_url (str): Redis URL.
+            collection_name (str): Collection name.
+            embedding_model_dims (int): Embedding model dimensions.
+        """
+        index_schema = {
+            "name": collection_name,
+            "prefix": f"mem0:{collection_name}",
+        }
+
+        fields = DEFAULT_FIELDS.copy()
+        fields[-1]["attrs"]["dims"] = embedding_model_dims
+
+        self.schema = {
+            "index": index_schema,
+            "fields": fields
+        }
+
+        self.client = redis.Redis.from_url(redis_url)
+        self.index = SearchIndex.from_dict(self.schema)
+        self.index.set_client(self.client)
+        self.index.create(overwrite=True)
+
+    #TODO: Implement multiindex support.
+    def create_col(self, name, vector_size, distance):
+        raise NotImplementedError("Collection/Index creation not supported yet.")
+
+    def insert(self, vectors: list, payloads: list = None, ids: list = None):
+        data = []
+        for vector, payload, id in zip(vectors, payloads, ids):
+            # Start with required fields
+            entry = {
+                "memory_id": id,
+                "hash": payload["hash"],
+                "memory": payload["data"],
+                "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
+                "embedding": np.array(vector, dtype=np.float32).tobytes(),
+            }
+
+            # Conditionally add optional fields
+            for field in ["agent_id", "run_id", "user_id"]:
+                if field in payload:
+                    entry[field] = payload[field]
+
+            # Add metadata excluding specific keys
+            entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
+
+            data.append(entry)
+        self.index.load(data, id_field="memory_id")
+
+    def search(self, query: list, limit: int = 5, filters: dict = None):
+        conditions = [
+            Tag(key) == value
+            for key, value in filters.items()
+            if value is not None
+        ]
+        filter = reduce(lambda x, y: x & y, conditions)
+
+        v = VectorQuery(
+            vector=np.array(query, dtype=np.float32).tobytes(),
+            vector_field_name="embedding",
+            return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
+            filter_expression=filter,
+            num_results=limit
+        )
+
+        results = self.index.query(v)
+
+        return [MemoryResult(id=result["memory_id"],
+                             score=result["vector_distance"],
+                             payload={
+                                 "hash": result["hash"],
+                                 "data": result["memory"],
+                                 "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds'),
+                                 **({"updated_at": datetime.fromtimestamp(int(result["updated_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds')} if "updated_at" in result else {}),
+                                 **{field: result[field]
+                                    for field in ["agent_id", "run_id", "user_id"]
+                                    if field in result},
+                                 **{k: v for k, v in json.loads(result["metadata"]).items()}
+                             }) for result in results]
+
+    def delete(self, vector_id):
+        self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")
+
+    def update(self, vector_id=None, vector=None, payload=None):
+        data = {
+            "memory_id": vector_id,
+            "hash": payload["hash"],
+            "memory": payload["data"],
+            "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
+            "updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
+            "embedding": np.array(vector, dtype=np.float32).tobytes()
+        }
+
+        for field in ["agent_id", "run_id", "user_id"]:
+            if field in payload:
+                data[field] = payload[field]
+
+        data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
+        self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")
+
+    def get(self, vector_id):
+        result = self.index.fetch(vector_id)
+        payload = {
+            "hash": result["hash"],
+            "data": result["memory"],
+            "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds'),
+            **({"updated_at": datetime.fromtimestamp(int(result["updated_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds')} if "updated_at" in result else {}),
+            **{field: result[field]
+               for field in ["agent_id", "run_id", "user_id"]
+               if field in result},
+            **{k: v for k, v in json.loads(result["metadata"]).items()}
+        }
+
+        return MemoryResult(id=result["memory_id"], payload=payload)
+
+    def list_cols(self):
+        return self.index.listall()
+
+    def delete_col(self):
+        self.index.delete()
+
+    def col_info(self, name):
+        return self.index.info()
+
+    def list(self, filters: dict = None, limit: int = None) -> list:
+        """
+        List all recent created memories from the vector store.
+        """
+        conditions = [
+            Tag(key) == value
+            for key, value in filters.items()
+            if value is not None
+        ]
+        filter = reduce(lambda x, y: x & y, conditions)
+        query = Query(str(filter)).sort_by("created_at", asc=False)
+        if limit is not None:
+            query = Query(str(filter)).sort_by("created_at", asc=False).paging(0, limit)
+
+        results = self.index.search(query)
+        return [[MemoryResult(id=result["memory_id"],
+                              payload={
+                                  "hash": result["hash"],
+                                  "data": result["memory"],
+                                  "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds'),
+                                  **({"updated_at": datetime.fromtimestamp(int(result["updated_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds')} if result.__dict__.get("updated_at") else {}),
+                                  **{field: result[field]
+                                     for field in ["agent_id", "run_id", "user_id"]
+                                     if field in result.__dict__},
+                                  **{k: v for k, v in json.loads(result["metadata"]).items()}
+                              }) for result in results.docs]]

From be2ad460a36f0c6bbad85944877c1c349429c108 Mon Sep 17 00:00:00 2001
From: Mayank Solanki
Date: Sat, 16 Nov 2024 11:30:59 +0530
Subject: [PATCH 2/2] lint and docs

---
 docs/components/vectordbs/dbs/redis.mdx |  10 ++
 mem0/configs/vector_stores/redis.py     |   7 +-
 mem0/vector_stores/redis.py             | 183 ++++++++++++++----------
 3 files changed, 119 insertions(+), 81 deletions(-)

diff --git a/docs/components/vectordbs/dbs/redis.mdx b/docs/components/vectordbs/dbs/redis.mdx
index b1c0f87b96..771a6589f8 100644
--- a/docs/components/vectordbs/dbs/redis.mdx
+++ b/docs/components/vectordbs/dbs/redis.mdx
@@ -1,5 +1,15 @@
 [Redis](https://redis.io/) is a scalable, real-time database that can store, search, and analyze vector data.
 
+### Installation
+```bash
+pip install redis redisvl
+```
+
+Redis Stack using Docker:
+```bash
+docker run -d --name redis-stack -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
+```
+
 ### Usage
 
 ```python
diff --git a/mem0/configs/vector_stores/redis.py b/mem0/configs/vector_stores/redis.py
index 2b92c5297c..efa442dc12 100644
--- a/mem0/configs/vector_stores/redis.py
+++ b/mem0/configs/vector_stores/redis.py
@@ -2,7 +2,8 @@
 
 from pydantic import BaseModel, Field, model_validator
 
-#TODO: Upgrade to latest pydantic version
+
+# TODO: Upgrade to latest pydantic version
 class RedisDBConfig(BaseModel):
     redis_url: str = Field(..., description="Redis URL")
     collection_name: str = Field("mem0", description="Collection name")
@@ -15,7 +16,9 @@ def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         input_fields = set(values.keys())
         extra_fields = input_fields - allowed_fields
         if extra_fields:
-            raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
         return values
 
     model_config = {
diff --git a/mem0/vector_stores/redis.py b/mem0/vector_stores/redis.py
index a623a524e9..0f553f32b5 100644
--- a/mem0/vector_stores/redis.py
+++ b/mem0/vector_stores/redis.py
@@ -1,18 +1,17 @@
-from datetime import datetime
-from functools import reduce
 import json
 import logging
+from datetime import datetime
+from functools import reduce
+
 import numpy as np
 import pytz
-from mem0.vector_stores.base import VectorStoreBase
-from redisvl.query.filter import Tag
-from redisvl.query import VectorQuery
-from redis.commands.search.query import Query
-
-
 import redis
+from redis.commands.search.query import Query
 from redisvl.index import SearchIndex
+from redisvl.query import VectorQuery
+from redisvl.query.filter import Tag
 
+from mem0.vector_stores.base import VectorStoreBase
 
 logger = logging.getLogger(__name__)
 
@@ -25,18 +24,14 @@
     {"name": "user_id", "type": "tag"},
     {"name": "memory", "type": "text"},
     {"name": "metadata", "type": "text"},
-    #TODO: Although it is numeric but also accepts string
+    # TODO: Although it is numeric but also accepts string
     {"name": "created_at", "type": "numeric"},
     {"name": "updated_at", "type": "numeric"},
     {
         "name": "embedding",
         "type": "vector",
-        "attrs": {
-            "distance_metric": "cosine",
-            "algorithm": "flat",
-            "datatype": "float32"
-        }
-    }
+        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
+    },
 ]
 
 excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
@@ -48,6 +43,7 @@ def __init__(self, id: str, payload: dict, score: float = None):
         self.payload = payload
         self.score = score
 
+
 class RedisDB(VectorStoreBase):
     def __init__(
         self,
@@ -64,27 +60,24 @@ def __init__(
             embedding_model_dims (int): Embedding model dimensions.
         """
         index_schema = {
-            "name": collection_name,
-            "prefix": f"mem0:{collection_name}",
+            "name": collection_name,
+            "prefix": f"mem0:{collection_name}",
         }
 
         fields = DEFAULT_FIELDS.copy()
         fields[-1]["attrs"]["dims"] = embedding_model_dims
 
-        self.schema = {
-            "index": index_schema,
-            "fields": fields
-        }
+        self.schema = {"index": index_schema, "fields": fields}
 
         self.client = redis.Redis.from_url(redis_url)
         self.index = SearchIndex.from_dict(self.schema)
         self.index.set_client(self.client)
         self.index.create(overwrite=True)
-
-    #TODO: Implement multiindex support.
+
+    # TODO: Implement multiindex support.
     def create_col(self, name, vector_size, distance):
         raise NotImplementedError("Collection/Index creation not supported yet.")
-
+
     def insert(self, vectors: list, payloads: list = None, ids: list = None):
@@ -96,48 +89,57 @@ def insert(self, vectors: list, payloads: list = None, ids: list = None):
                 "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
                 "embedding": np.array(vector, dtype=np.float32).tobytes(),
             }
-
+
             # Conditionally add optional fields
             for field in ["agent_id", "run_id", "user_id"]:
                 if field in payload:
                     entry[field] = payload[field]
-
+
             # Add metadata excluding specific keys
             entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
-
+
             data.append(entry)
         self.index.load(data, id_field="memory_id")
 
     def search(self, query: list, limit: int = 5, filters: dict = None):
-        conditions = [
-            Tag(key) == value
-            for key, value in filters.items()
-            if value is not None
-        ]
+        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
         filter = reduce(lambda x, y: x & y, conditions)
 
         v = VectorQuery(
             vector=np.array(query, dtype=np.float32).tobytes(),
             vector_field_name="embedding",
             return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
-            filter_expression=filter,
-            num_results=limit
+            filter_expression=filter,
+            num_results=limit,
         )
 
         results = self.index.query(v)
 
-        return [MemoryResult(id=result["memory_id"],
-                             score=result["vector_distance"],
-                             payload={
-                                 "hash": result["hash"],
-                                 "data": result["memory"],
-                                 "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds'),
-                                 **({"updated_at": datetime.fromtimestamp(int(result["updated_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds')} if "updated_at" in result else {}),
-                                 **{field: result[field]
-                                    for field in ["agent_id", "run_id", "user_id"]
-                                    if field in result},
-                                 **{k: v for k, v in json.loads(result["metadata"]).items()}
-                             }) for result in results]
+        return [
+            MemoryResult(
+                id=result["memory_id"],
+                score=result["vector_distance"],
+                payload={
+                    "hash": result["hash"],
+                    "data": result["memory"],
+                    "created_at": datetime.fromtimestamp(
+                        int(result["created_at"]), tz=pytz.timezone("US/Pacific")
+                    ).isoformat(timespec="microseconds"),
+                    **(
+                        {
+                            "updated_at": datetime.fromtimestamp(
+                                int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
+                            ).isoformat(timespec="microseconds")
+                        }
+                        if "updated_at" in result
+                        else {}
+                    ),
+                    **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
+                    **{k: v for k, v in json.loads(result["metadata"]).items()},
+                },
+            )
+            for result in results
+        ]
 
     def delete(self, vector_id):
         self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")
@@ -149,14 +151,14 @@ def update(self, vector_id=None, vector=None, payload=None):
             "memory": payload["data"],
             "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
             "updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
-            "embedding": np.array(vector, dtype=np.float32).tobytes()
+            "embedding": np.array(vector, dtype=np.float32).tobytes(),
         }
-
+
         for field in ["agent_id", "run_id", "user_id"]:
             if field in payload:
                 data[field] = payload[field]
-
-        data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
+
+        data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
         self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")
 
     def get(self, vector_id):
@@ -164,48 +166,71 @@ def get(self, vector_id):
         payload = {
             "hash": result["hash"],
             "data": result["memory"],
-            "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds'),
-            **({"updated_at": datetime.fromtimestamp(int(result["updated_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds')} if "updated_at" in result else {}),
-            **{field: result[field]
-               for field in ["agent_id", "run_id", "user_id"]
-               if field in result},
-            **{k: v for k, v in json.loads(result["metadata"]).items()}
+            "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
+                timespec="microseconds"
+            ),
+            **(
+                {
+                    "updated_at": datetime.fromtimestamp(
+                        int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
+                    ).isoformat(timespec="microseconds")
+                }
+                if "updated_at" in result
+                else {}
+            ),
+            **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
+            **{k: v for k, v in json.loads(result["metadata"]).items()},
         }
-
+
         return MemoryResult(id=result["memory_id"], payload=payload)
-
+
     def list_cols(self):
         return self.index.listall()
-
+
     def delete_col(self):
         self.index.delete()
-
+
     def col_info(self, name):
         return self.index.info()
-
+
     def list(self, filters: dict = None, limit: int = None) -> list:
         """
        List all recent created memories from the vector store.
         """
-        conditions = [
-            Tag(key) == value
-            for key, value in filters.items()
-            if value is not None
-        ]
+        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
         filter = reduce(lambda x, y: x & y, conditions)
-        query = Query(str(filter)).sort_by("created_at", asc=False)
+        query = Query(str(filter)).sort_by("created_at", asc=False)
         if limit is not None:
             query = Query(str(filter)).sort_by("created_at", asc=False).paging(0, limit)
-
+
         results = self.index.search(query)
-        return [[MemoryResult(id=result["memory_id"],
-                              payload={
-                                  "hash": result["hash"],
-                                  "data": result["memory"],
-                                  "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds'),
-                                  **({"updated_at": datetime.fromtimestamp(int(result["updated_at"]), tz=pytz.timezone('US/Pacific')).isoformat(timespec='microseconds')} if result.__dict__.get("updated_at") else {}),
-                                  **{field: result[field]
-                                     for field in ["agent_id", "run_id", "user_id"]
-                                     if field in result.__dict__},
-                                  **{k: v for k, v in json.loads(result["metadata"]).items()}
-                              }) for result in results.docs]]
+        return [
+            [
+                MemoryResult(
+                    id=result["memory_id"],
+                    payload={
+                        "hash": result["hash"],
+                        "data": result["memory"],
+                        "created_at": datetime.fromtimestamp(
+                            int(result["created_at"]), tz=pytz.timezone("US/Pacific")
+                        ).isoformat(timespec="microseconds"),
+                        **(
+                            {
+                                "updated_at": datetime.fromtimestamp(
+                                    int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
+                                ).isoformat(timespec="microseconds")
+                            }
+                            if result.__dict__.get("updated_at")
+                            else {}
+                        ),
+                        **{
+                            field: result[field]
+                            for field in ["agent_id", "run_id", "user_id"]
+                            if field in result.__dict__
+                        },
+                        **{k: v for k, v in json.loads(result["metadata"]).items()},
+                    },
+                )
+                for result in results.docs
+            ]
+        ]
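
For anyone who wants to exercise the new backend directly rather than through `Memory.from_config`, below is a minimal smoke-test sketch against the `RedisDB` class added in this patch. It assumes a local Redis Stack instance at `redis://localhost:6379` (the Docker command from the docs above); the `mem0_smoke_test` collection name and the toy 4-dimensional vectors are illustrative stand-ins for real embedder output:

```python
from datetime import datetime, timezone

from mem0.vector_stores.redis import RedisDB

# Assumes Redis Stack is reachable on localhost:6379; "mem0_smoke_test" and the
# 4-dim vectors are illustrative only (an OpenAI embedder would use 1536 dims).
store = RedisDB(
    redis_url="redis://localhost:6379",
    collection_name="mem0_smoke_test",
    embedding_model_dims=4,
)

store.insert(
    vectors=[[0.1, 0.2, 0.3, 0.4]],
    payloads=[
        {
            "hash": "abc123",  # required: stored as a tag field
            "data": "Likes to play cricket on weekends",  # required: stored as "memory"
            "created_at": datetime.now(timezone.utc).isoformat(),  # must parse with datetime.fromisoformat()
            "user_id": "alice",  # optional tag field
            "category": "hobbies",  # everything else lands in the JSON "metadata" blob
        }
    ],
    ids=["mem-1"],
)

# search() ANDs all non-None filters into a Tag expression, so pass at least one.
for hit in store.search(query=[0.1, 0.2, 0.3, 0.4], limit=1, filters={"user_id": "alice"}):
    print(hit.id, hit.score, hit.payload["data"])

store.delete_col()  # drop the throwaway index when finished
```

Two behaviours of the implementation are worth keeping in mind when trying this: `__init__` calls `self.index.create(overwrite=True)`, so constructing a second `RedisDB` against the same collection rebuilds the index, and returned timestamps are rendered in the hard-coded `US/Pacific` timezone via `pytz`.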