diff --git a/Makefile b/Makefile index c3fb713da3..7e61ed8a18 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,8 @@ install: install_all: poetry install - poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai + poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \ + google-generativeai # Format code with ruff format: diff --git a/docs/components/embedders/models/gemini.mdx b/docs/components/embedders/models/gemini.mdx new file mode 100644 index 0000000000..0913a8a7ea --- /dev/null +++ b/docs/components/embedders/models/gemini.mdx @@ -0,0 +1,41 @@ +--- +title: Gemini +--- + +To use Gemini embedding models, set the `GOOGLE_API_KEY` environment variables. You can obtain the Gemini API key from [here](https://aistudio.google.com/app/apikey). + +### Usage + +```python +import os +from mem0 import Memory + +os.environ["GOOGLE_API_KEY"] = "key" + +config = { + "embedder": { + "provider": "gemini", + "config": { + "model": "models/text-embedding-004" + } + }, + "vector_store": { + "provider": "qdrant", + "config": { + "collection_name": "test", + "embedding_model_dims": 768, + } + }, +} + +m = Memory.from_config(config) +m.add("I'm visiting Paris", user_id="john") +``` + +### Config + +Here are the parameters available for configuring Gemini embedder: + +| Parameter | Description | Default Value | +| --- | --- | --- | +| `model` | The name of the embedding model to use | `models/text-embedding-004` | diff --git a/docs/components/embedders/overview.mdx b/docs/components/embedders/overview.mdx index f1d7d7e584..48e2cfdc78 100644 --- a/docs/components/embedders/overview.mdx +++ b/docs/components/embedders/overview.mdx @@ -13,6 +13,7 @@ See the list of supported embedders below. + diff --git a/docs/mint.json b/docs/mint.json index 268229be24..5e301898bb 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -125,7 +125,8 @@ "components/embedders/models/openai", "components/embedders/models/azure_openai", "components/embedders/models/ollama", - "components/embedders/models/huggingface" + "components/embedders/models/huggingface", + "components/embedders/models/gemini" ] } ] diff --git a/mem0/embeddings/configs.py b/mem0/embeddings/configs.py index 213493440b..90219ce6c3 100644 --- a/mem0/embeddings/configs.py +++ b/mem0/embeddings/configs.py @@ -13,7 +13,7 @@ class EmbedderConfig(BaseModel): @field_validator("config") def validate_config(cls, v, values): provider = values.data.get("provider") - if provider in ["openai", "ollama", "huggingface", "azure_openai", "vertexai"]: + if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai"]: return v else: raise ValueError(f"Unsupported embedding provider: {provider}") diff --git a/mem0/embeddings/gemini.py b/mem0/embeddings/gemini.py new file mode 100644 index 0000000000..06efde83da --- /dev/null +++ b/mem0/embeddings/gemini.py @@ -0,0 +1,27 @@ +import os +from typing import Optional +import google.generativeai as genai + +from mem0.configs.embeddings.base import BaseEmbedderConfig +from mem0.embeddings.base import EmbeddingBase + + +class GoogleGenAIEmbedding(EmbeddingBase): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + super().__init__(config) + if self.config.model is None: + self.config.model = "models/text-embedding-004" # embedding-dim = 768 + + genai.configure(api_key=self.config.api_key or os.getenv("GOOGLE_API_KEY")) + + def embed(self, text): + """ + Get the embedding for the given text using Google Generative AI. + Args: + text (str): The text to embed. + Returns: + list: The embedding vector. + """ + text = text.replace("\n", " ") + response = genai.embed_content(model=self.config.model, content=text) + return response['embedding'] \ No newline at end of file diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 21c0445914..034fc6fc5b 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -41,6 +41,7 @@ class EmbedderFactory: "ollama": "mem0.embeddings.ollama.OllamaEmbedding", "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding", "azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding", + "gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding", } @classmethod diff --git a/poetry.lock b/poetry.lock index 10a27332b7..88058ce0a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -211,6 +211,17 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] +[[package]] +name = "cachetools" +version = "5.5.0" +description = "Extensible memoizing collections and decorators" +optional = false +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"}, + {file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"}, +] + [[package]] name = "certifi" version = "2024.7.4" @@ -458,6 +469,150 @@ files = [ {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, ] +[[package]] +name = "google-ai-generativelanguage" +version = "0.6.9" +description = "Google Ai Generativelanguage API client library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_ai_generativelanguage-0.6.9-py3-none-any.whl", hash = "sha256:50360cd80015d1a8cc70952e98560f32fa06ddee2e8e9f4b4b98e431dc561e0b"}, + {file = "google_ai_generativelanguage-0.6.9.tar.gz", hash = "sha256:899f1d3a06efa9739f1cd9d2788070178db33c89d4a76f2e8f4da76f649155fa"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" +proto-plus = ">=1.22.3,<2.0.0dev" +protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" + +[[package]] +name = "google-api-core" +version = "2.20.0" +description = "Google API client core library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_api_core-2.20.0-py3-none-any.whl", hash = "sha256:ef0591ef03c30bb83f79b3d0575c3f31219001fc9c5cf37024d08310aeffed8a"}, + {file = "google_api_core-2.20.0.tar.gz", hash = "sha256:f74dff1889ba291a4b76c5079df0711810e2d9da81abfdc99957bc961c1eb28f"}, +] + +[package.dependencies] +google-auth = ">=2.14.1,<3.0.dev0" +googleapis-common-protos = ">=1.56.2,<2.0.dev0" +grpcio = [ + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, +] +grpcio-status = [ + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, +] +proto-plus = ">=1.22.3,<2.0.0dev" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +requests = ">=2.18.0,<3.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"] +grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] + +[[package]] +name = "google-api-python-client" +version = "2.146.0" +description = "Google API Client Library for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_api_python_client-2.146.0-py2.py3-none-any.whl", hash = "sha256:b1e62c9889c5ef6022f11d30d7ef23dc55100300f0e8aaf8aa09e8e92540acad"}, + {file = "google_api_python_client-2.146.0.tar.gz", hash = "sha256:41f671be10fa077ee5143ee9f0903c14006d39dc644564f4e044ae96b380bf68"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0" +google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +google-auth-httplib2 = ">=0.2.0,<1.0.0" +httplib2 = ">=0.19.0,<1.dev0" +uritemplate = ">=3.0.1,<5" + +[[package]] +name = "google-auth" +version = "2.35.0" +description = "Google Authentication Library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_auth-2.35.0-py2.py3-none-any.whl", hash = "sha256:25df55f327ef021de8be50bad0dfd4a916ad0de96da86cd05661c9297723ad3f"}, + {file = "google_auth-2.35.0.tar.gz", hash = "sha256:f4c64ed4e01e8e8b646ef34c018f8bf3338df0c8e37d8b3bba40e7f574a3278a"}, +] + +[package.dependencies] +cachetools = ">=2.0.0,<6.0" +pyasn1-modules = ">=0.2.1" +rsa = ">=3.1.4,<5" + +[package.extras] +aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] +enterprise-cert = ["cryptography", "pyopenssl"] +pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] +reauth = ["pyu2f (>=0.1.5)"] +requests = ["requests (>=2.20.0,<3.0.0.dev0)"] + +[[package]] +name = "google-auth-httplib2" +version = "0.2.0" +description = "Google Authentication Library: httplib2 transport" +optional = false +python-versions = "*" +files = [ + {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, + {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"}, +] + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.19.0" + +[[package]] +name = "google-generativeai" +version = "0.8.1" +description = "Google Generative AI High level API client library and tools." +optional = false +python-versions = ">=3.9" +files = [ + {file = "google_generativeai-0.8.1-py3-none-any.whl", hash = "sha256:b031877f24d51af0945207657c085896a0a886eceec7a1cb7029327b0aa6e2f6"}, +] + +[package.dependencies] +google-ai-generativelanguage = "0.6.9" +google-api-core = "*" +google-api-python-client = "*" +google-auth = ">=2.15.0" +protobuf = "*" +pydantic = "*" +tqdm = "*" +typing-extensions = "*" + +[package.extras] +dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] + +[[package]] +name = "googleapis-common-protos" +version = "1.65.0" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "googleapis_common_protos-1.65.0-py2.py3-none-any.whl", hash = "sha256:2972e6c496f435b92590fd54045060867f3fe9be2c82ab148fc8885035479a63"}, + {file = "googleapis_common_protos-1.65.0.tar.gz", hash = "sha256:334a29d07cddc3aa01dee4988f9afd9b2916ee2ff49d6b757155dc0d197852c0"}, +] + +[package.dependencies] +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] + [[package]] name = "greenlet" version = "3.0.3" @@ -587,6 +742,22 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.64.1)"] +[[package]] +name = "grpcio-status" +version = "1.62.3" +description = "Status proto mapping for gRPC" +optional = false +python-versions = ">=3.6" +files = [ + {file = "grpcio-status-1.62.3.tar.gz", hash = "sha256:289bdd7b2459794a12cf95dc0cb727bd4a1742c37bd823f760236c937e53a485"}, + {file = "grpcio_status-1.62.3-py3-none-any.whl", hash = "sha256:f9049b762ba8de6b1086789d8315846e094edac2c50beaf462338b301a8fd4b8"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.62.3" +protobuf = ">=4.21.6" + [[package]] name = "grpcio-tools" version = "1.62.2" @@ -713,6 +884,20 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<0.26.0)"] +[[package]] +name = "httplib2" +version = "0.22.0" +description = "A comprehensive HTTP client library." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, + {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"}, +] + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "httpx" version = "0.27.0" @@ -1281,6 +1466,23 @@ dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"] sentry = ["django", "sentry-sdk"] test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"] +[[package]] +name = "proto-plus" +version = "1.24.0" +description = "Beautiful, Pythonic protocol buffers." +optional = false +python-versions = ">=3.7" +files = [ + {file = "proto-plus-1.24.0.tar.gz", hash = "sha256:30b72a5ecafe4406b0d339db35b56c4059064e69227b8c3bda7462397f966445"}, + {file = "proto_plus-1.24.0-py3-none-any.whl", hash = "sha256:402576830425e5f6ce4c2a6702400ac79897dab0b4343821aa5188b0fab81a12"}, +] + +[package.dependencies] +protobuf = ">=3.19.0,<6.0.0dev" + +[package.extras] +testing = ["google-api-core (>=1.31.5)"] + [[package]] name = "protobuf" version = "4.25.4" @@ -1301,6 +1503,31 @@ files = [ {file = "protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d"}, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, + {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.1" +description = "A collection of ASN.1-based protocols modules" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, + {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, +] + +[package.dependencies] +pyasn1 = ">=0.4.6,<0.7.0" + [[package]] name = "pydantic" version = "2.8.2" @@ -1424,6 +1651,20 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pyparsing" +version = "3.1.4" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.1.4-py3-none-any.whl", hash = "sha256:a6a7ee4235a3f944aa1fa2249307708f893fe5717dc603503c6c7969c070fb7c"}, + {file = "pyparsing-3.1.4.tar.gz", hash = "sha256:f86ec8d1a83f11977c9a6ea7598e8c27fc5cddfa5b07ea2241edbbde1d7bc032"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pytest" version = "8.2.2" @@ -1621,6 +1862,20 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = false +python-versions = ">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + +[package.dependencies] +pyasn1 = ">=0.1.3" + [[package]] name = "ruff" version = "0.6.5" @@ -1844,6 +2099,17 @@ files = [ mypy-extensions = ">=0.3.0" typing-extensions = ">=3.7.4" +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +optional = false +python-versions = ">=3.6" +files = [ + {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, + {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"}, +] + [[package]] name = "urllib3" version = "2.2.2" @@ -1967,4 +2233,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "56197730e020f77ee9824292f34348bbe935b42519b4027f6fb131084b88300b" +content-hash = "d0b1ade58154f0acd26f2d8ee6cb05c9ddcd10efa0fc5835af12a82e43628bbe" diff --git a/tests/embeddings/test_gemini.py b/tests/embeddings/test_gemini.py new file mode 100644 index 0000000000..07691632fa --- /dev/null +++ b/tests/embeddings/test_gemini.py @@ -0,0 +1,37 @@ +from unittest.mock import patch +import pytest +from mem0.configs.embeddings.base import BaseEmbedderConfig +from mem0.embeddings.gemini import GoogleGenAIEmbedding + + +@pytest.fixture +def mock_genai(): + with patch("mem0.embeddings.gemini.genai.embed_content") as mock_genai: + yield mock_genai + + +@pytest.fixture +def config(): + return BaseEmbedderConfig( + api_key="dummy_api_key", + model="test_model" + ) + + +def test_embed_query(mock_genai, config): + + mock_embedding_response = { + 'embedding': [0.1, 0.2, 0.3, 0.4] + } + mock_genai.return_value = mock_embedding_response + + embedder = GoogleGenAIEmbedding(config) + + text = "Hello, world!" + embedding = embedder.embed(text) + + assert embedding == [0.1, 0.2, 0.3, 0.4] + mock_genai.assert_called_once_with( + model="test_model", + content="Hello, world!" + )