diff --git a/Makefile b/Makefile
index c3fb713da3..7e61ed8a18 100644
--- a/Makefile
+++ b/Makefile
@@ -12,7 +12,8 @@ install:
install_all:
poetry install
- poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai
+ poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \
+ google-generativeai
# Format code with ruff
format:
diff --git a/docs/components/embedders/models/gemini.mdx b/docs/components/embedders/models/gemini.mdx
new file mode 100644
index 0000000000..0913a8a7ea
--- /dev/null
+++ b/docs/components/embedders/models/gemini.mdx
@@ -0,0 +1,41 @@
+---
+title: Gemini
+---
+
+To use Gemini embedding models, set the `GOOGLE_API_KEY` environment variables. You can obtain the Gemini API key from [here](https://aistudio.google.com/app/apikey).
+
+### Usage
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["GOOGLE_API_KEY"] = "key"
+
+config = {
+ "embedder": {
+ "provider": "gemini",
+ "config": {
+ "model": "models/text-embedding-004"
+ }
+ },
+ "vector_store": {
+ "provider": "qdrant",
+ "config": {
+ "collection_name": "test",
+ "embedding_model_dims": 768,
+ }
+ },
+}
+
+m = Memory.from_config(config)
+m.add("I'm visiting Paris", user_id="john")
+```
+
+### Config
+
+Here are the parameters available for configuring Gemini embedder:
+
+| Parameter | Description | Default Value |
+| --- | --- | --- |
+| `model` | The name of the embedding model to use | `models/text-embedding-004` |
diff --git a/docs/components/embedders/overview.mdx b/docs/components/embedders/overview.mdx
index f1d7d7e584..48e2cfdc78 100644
--- a/docs/components/embedders/overview.mdx
+++ b/docs/components/embedders/overview.mdx
@@ -13,6 +13,7 @@ See the list of supported embedders below.
+
diff --git a/docs/mint.json b/docs/mint.json
index 268229be24..5e301898bb 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -125,7 +125,8 @@
"components/embedders/models/openai",
"components/embedders/models/azure_openai",
"components/embedders/models/ollama",
- "components/embedders/models/huggingface"
+ "components/embedders/models/huggingface",
+ "components/embedders/models/gemini"
]
}
]
diff --git a/mem0/embeddings/configs.py b/mem0/embeddings/configs.py
index 213493440b..90219ce6c3 100644
--- a/mem0/embeddings/configs.py
+++ b/mem0/embeddings/configs.py
@@ -13,7 +13,7 @@ class EmbedderConfig(BaseModel):
@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
- if provider in ["openai", "ollama", "huggingface", "azure_openai", "vertexai"]:
+ if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai"]:
return v
else:
raise ValueError(f"Unsupported embedding provider: {provider}")
diff --git a/mem0/embeddings/gemini.py b/mem0/embeddings/gemini.py
new file mode 100644
index 0000000000..06efde83da
--- /dev/null
+++ b/mem0/embeddings/gemini.py
@@ -0,0 +1,27 @@
+import os
+from typing import Optional
+import google.generativeai as genai
+
+from mem0.configs.embeddings.base import BaseEmbedderConfig
+from mem0.embeddings.base import EmbeddingBase
+
+
+class GoogleGenAIEmbedding(EmbeddingBase):
+ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+ super().__init__(config)
+ if self.config.model is None:
+ self.config.model = "models/text-embedding-004" # embedding-dim = 768
+
+ genai.configure(api_key=self.config.api_key or os.getenv("GOOGLE_API_KEY"))
+
+ def embed(self, text):
+ """
+ Get the embedding for the given text using Google Generative AI.
+ Args:
+ text (str): The text to embed.
+ Returns:
+ list: The embedding vector.
+ """
+ text = text.replace("\n", " ")
+ response = genai.embed_content(model=self.config.model, content=text)
+ return response['embedding']
\ No newline at end of file
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 21c0445914..034fc6fc5b 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -41,6 +41,7 @@ class EmbedderFactory:
"ollama": "mem0.embeddings.ollama.OllamaEmbedding",
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
"azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
+ "gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
}
@classmethod
diff --git a/poetry.lock b/poetry.lock
index 10a27332b7..88058ce0a7 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -211,6 +211,17 @@ files = [
{file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
]
+[[package]]
+name = "cachetools"
+version = "5.5.0"
+description = "Extensible memoizing collections and decorators"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"},
+ {file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"},
+]
+
[[package]]
name = "certifi"
version = "2024.7.4"
@@ -458,6 +469,150 @@ files = [
{file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
]
+[[package]]
+name = "google-ai-generativelanguage"
+version = "0.6.9"
+description = "Google Ai Generativelanguage API client library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google_ai_generativelanguage-0.6.9-py3-none-any.whl", hash = "sha256:50360cd80015d1a8cc70952e98560f32fa06ddee2e8e9f4b4b98e431dc561e0b"},
+ {file = "google_ai_generativelanguage-0.6.9.tar.gz", hash = "sha256:899f1d3a06efa9739f1cd9d2788070178db33c89d4a76f2e8f4da76f649155fa"},
+]
+
+[package.dependencies]
+google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
+google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
+proto-plus = ">=1.22.3,<2.0.0dev"
+protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
+
+[[package]]
+name = "google-api-core"
+version = "2.20.0"
+description = "Google API client core library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google_api_core-2.20.0-py3-none-any.whl", hash = "sha256:ef0591ef03c30bb83f79b3d0575c3f31219001fc9c5cf37024d08310aeffed8a"},
+ {file = "google_api_core-2.20.0.tar.gz", hash = "sha256:f74dff1889ba291a4b76c5079df0711810e2d9da81abfdc99957bc961c1eb28f"},
+]
+
+[package.dependencies]
+google-auth = ">=2.14.1,<3.0.dev0"
+googleapis-common-protos = ">=1.56.2,<2.0.dev0"
+grpcio = [
+ {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
+ {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+]
+grpcio-status = [
+ {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
+ {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+]
+proto-plus = ">=1.22.3,<2.0.0dev"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
+requests = ">=2.18.0,<3.0.0.dev0"
+
+[package.extras]
+grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"]
+grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
+grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
+
+[[package]]
+name = "google-api-python-client"
+version = "2.146.0"
+description = "Google API Client Library for Python"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google_api_python_client-2.146.0-py2.py3-none-any.whl", hash = "sha256:b1e62c9889c5ef6022f11d30d7ef23dc55100300f0e8aaf8aa09e8e92540acad"},
+ {file = "google_api_python_client-2.146.0.tar.gz", hash = "sha256:41f671be10fa077ee5143ee9f0903c14006d39dc644564f4e044ae96b380bf68"},
+]
+
+[package.dependencies]
+google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0"
+google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0"
+google-auth-httplib2 = ">=0.2.0,<1.0.0"
+httplib2 = ">=0.19.0,<1.dev0"
+uritemplate = ">=3.0.1,<5"
+
+[[package]]
+name = "google-auth"
+version = "2.35.0"
+description = "Google Authentication Library"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "google_auth-2.35.0-py2.py3-none-any.whl", hash = "sha256:25df55f327ef021de8be50bad0dfd4a916ad0de96da86cd05661c9297723ad3f"},
+ {file = "google_auth-2.35.0.tar.gz", hash = "sha256:f4c64ed4e01e8e8b646ef34c018f8bf3338df0c8e37d8b3bba40e7f574a3278a"},
+]
+
+[package.dependencies]
+cachetools = ">=2.0.0,<6.0"
+pyasn1-modules = ">=0.2.1"
+rsa = ">=3.1.4,<5"
+
+[package.extras]
+aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"]
+enterprise-cert = ["cryptography", "pyopenssl"]
+pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
+reauth = ["pyu2f (>=0.1.5)"]
+requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
+
+[[package]]
+name = "google-auth-httplib2"
+version = "0.2.0"
+description = "Google Authentication Library: httplib2 transport"
+optional = false
+python-versions = "*"
+files = [
+ {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"},
+ {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"},
+]
+
+[package.dependencies]
+google-auth = "*"
+httplib2 = ">=0.19.0"
+
+[[package]]
+name = "google-generativeai"
+version = "0.8.1"
+description = "Google Generative AI High level API client library and tools."
+optional = false
+python-versions = ">=3.9"
+files = [
+ {file = "google_generativeai-0.8.1-py3-none-any.whl", hash = "sha256:b031877f24d51af0945207657c085896a0a886eceec7a1cb7029327b0aa6e2f6"},
+]
+
+[package.dependencies]
+google-ai-generativelanguage = "0.6.9"
+google-api-core = "*"
+google-api-python-client = "*"
+google-auth = ">=2.15.0"
+protobuf = "*"
+pydantic = "*"
+tqdm = "*"
+typing-extensions = "*"
+
+[package.extras]
+dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
+
+[[package]]
+name = "googleapis-common-protos"
+version = "1.65.0"
+description = "Common protobufs used in Google APIs"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "googleapis_common_protos-1.65.0-py2.py3-none-any.whl", hash = "sha256:2972e6c496f435b92590fd54045060867f3fe9be2c82ab148fc8885035479a63"},
+ {file = "googleapis_common_protos-1.65.0.tar.gz", hash = "sha256:334a29d07cddc3aa01dee4988f9afd9b2916ee2ff49d6b757155dc0d197852c0"},
+]
+
+[package.dependencies]
+protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
+
+[package.extras]
+grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
+
[[package]]
name = "greenlet"
version = "3.0.3"
@@ -587,6 +742,22 @@ files = [
[package.extras]
protobuf = ["grpcio-tools (>=1.64.1)"]
+[[package]]
+name = "grpcio-status"
+version = "1.62.3"
+description = "Status proto mapping for gRPC"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "grpcio-status-1.62.3.tar.gz", hash = "sha256:289bdd7b2459794a12cf95dc0cb727bd4a1742c37bd823f760236c937e53a485"},
+ {file = "grpcio_status-1.62.3-py3-none-any.whl", hash = "sha256:f9049b762ba8de6b1086789d8315846e094edac2c50beaf462338b301a8fd4b8"},
+]
+
+[package.dependencies]
+googleapis-common-protos = ">=1.5.5"
+grpcio = ">=1.62.3"
+protobuf = ">=4.21.6"
+
[[package]]
name = "grpcio-tools"
version = "1.62.2"
@@ -713,6 +884,20 @@ http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<0.26.0)"]
+[[package]]
+name = "httplib2"
+version = "0.22.0"
+description = "A comprehensive HTTP client library."
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"},
+ {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"},
+]
+
+[package.dependencies]
+pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""}
+
[[package]]
name = "httpx"
version = "0.27.0"
@@ -1281,6 +1466,23 @@ dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"]
sentry = ["django", "sentry-sdk"]
test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"]
+[[package]]
+name = "proto-plus"
+version = "1.24.0"
+description = "Beautiful, Pythonic protocol buffers."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "proto-plus-1.24.0.tar.gz", hash = "sha256:30b72a5ecafe4406b0d339db35b56c4059064e69227b8c3bda7462397f966445"},
+ {file = "proto_plus-1.24.0-py3-none-any.whl", hash = "sha256:402576830425e5f6ce4c2a6702400ac79897dab0b4343821aa5188b0fab81a12"},
+]
+
+[package.dependencies]
+protobuf = ">=3.19.0,<6.0.0dev"
+
+[package.extras]
+testing = ["google-api-core (>=1.31.5)"]
+
[[package]]
name = "protobuf"
version = "4.25.4"
@@ -1301,6 +1503,31 @@ files = [
{file = "protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d"},
]
+[[package]]
+name = "pyasn1"
+version = "0.6.1"
+description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
+ {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
+]
+
+[[package]]
+name = "pyasn1-modules"
+version = "0.4.1"
+description = "A collection of ASN.1-based protocols modules"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"},
+ {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.4.6,<0.7.0"
+
[[package]]
name = "pydantic"
version = "2.8.2"
@@ -1424,6 +1651,20 @@ files = [
[package.dependencies]
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
+[[package]]
+name = "pyparsing"
+version = "3.1.4"
+description = "pyparsing module - Classes and methods to define and execute parsing grammars"
+optional = false
+python-versions = ">=3.6.8"
+files = [
+ {file = "pyparsing-3.1.4-py3-none-any.whl", hash = "sha256:a6a7ee4235a3f944aa1fa2249307708f893fe5717dc603503c6c7969c070fb7c"},
+ {file = "pyparsing-3.1.4.tar.gz", hash = "sha256:f86ec8d1a83f11977c9a6ea7598e8c27fc5cddfa5b07ea2241edbbde1d7bc032"},
+]
+
+[package.extras]
+diagrams = ["jinja2", "railroad-diagrams"]
+
[[package]]
name = "pytest"
version = "8.2.2"
@@ -1621,6 +1862,20 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+[[package]]
+name = "rsa"
+version = "4.9"
+description = "Pure-Python RSA implementation"
+optional = false
+python-versions = ">=3.6,<4"
+files = [
+ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
+ {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.1.3"
+
[[package]]
name = "ruff"
version = "0.6.5"
@@ -1844,6 +2099,17 @@ files = [
mypy-extensions = ">=0.3.0"
typing-extensions = ">=3.7.4"
+[[package]]
+name = "uritemplate"
+version = "4.1.1"
+description = "Implementation of RFC 6570 URI Templates"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"},
+ {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"},
+]
+
[[package]]
name = "urllib3"
version = "2.2.2"
@@ -1967,4 +2233,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<4.0"
-content-hash = "56197730e020f77ee9824292f34348bbe935b42519b4027f6fb131084b88300b"
+content-hash = "d0b1ade58154f0acd26f2d8ee6cb05c9ddcd10efa0fc5835af12a82e43628bbe"
diff --git a/tests/embeddings/test_gemini.py b/tests/embeddings/test_gemini.py
new file mode 100644
index 0000000000..07691632fa
--- /dev/null
+++ b/tests/embeddings/test_gemini.py
@@ -0,0 +1,37 @@
+from unittest.mock import patch
+import pytest
+from mem0.configs.embeddings.base import BaseEmbedderConfig
+from mem0.embeddings.gemini import GoogleGenAIEmbedding
+
+
+@pytest.fixture
+def mock_genai():
+ with patch("mem0.embeddings.gemini.genai.embed_content") as mock_genai:
+ yield mock_genai
+
+
+@pytest.fixture
+def config():
+ return BaseEmbedderConfig(
+ api_key="dummy_api_key",
+ model="test_model"
+ )
+
+
+def test_embed_query(mock_genai, config):
+
+ mock_embedding_response = {
+ 'embedding': [0.1, 0.2, 0.3, 0.4]
+ }
+ mock_genai.return_value = mock_embedding_response
+
+ embedder = GoogleGenAIEmbedding(config)
+
+ text = "Hello, world!"
+ embedding = embedder.embed(text)
+
+ assert embedding == [0.1, 0.2, 0.3, 0.4]
+ mock_genai.assert_called_once_with(
+ model="test_model",
+ content="Hello, world!"
+ )