Skip to content

Commit

Permalink
Adding Gemini (#1862)
Browse files Browse the repository at this point in the history
  • Loading branch information
PranavPuranik authored Sep 27, 2024
1 parent 699741c commit aaf8e6e
Show file tree
Hide file tree
Showing 9 changed files with 379 additions and 4 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ install:

install_all:
poetry install
poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai
poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \
google-generativeai

# Format code with ruff
format:
Expand Down
41 changes: 41 additions & 0 deletions docs/components/embedders/models/gemini.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
---
title: Gemini
---

To use Gemini embedding models, set the `GOOGLE_API_KEY` environment variables. You can obtain the Gemini API key from [here](https://aistudio.google.com/app/apikey).

### Usage

```python
import os
from mem0 import Memory

os.environ["GOOGLE_API_KEY"] = "key"

config = {
"embedder": {
"provider": "gemini",
"config": {
"model": "models/text-embedding-004"
}
},
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": "test",
"embedding_model_dims": 768,
}
},
}

m = Memory.from_config(config)
m.add("I'm visiting Paris", user_id="john")
```

### Config

Here are the parameters available for configuring Gemini embedder:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `model` | The name of the embedding model to use | `models/text-embedding-004` |
1 change: 1 addition & 0 deletions docs/components/embedders/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the list of supported embedders below.
<Card title="Azure OpenAI" href="/components/embedders/models/azure_openai"></Card>
<Card title="Ollama" href="/components/embedders/models/ollama"></Card>
<Card title="Hugging Face" href="/components/embedders/models/huggingface"></Card>
<Card title="Gemini" href="/components/embedders/models/gemini"></Card>
<Card title="Vertex AI" href="/components/embedders/models/vertexai"></Card>
</CardGroup>

Expand Down
3 changes: 2 additions & 1 deletion docs/mint.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@
"components/embedders/models/openai",
"components/embedders/models/azure_openai",
"components/embedders/models/ollama",
"components/embedders/models/huggingface"
"components/embedders/models/huggingface",
"components/embedders/models/gemini"
]
}
]
Expand Down
2 changes: 1 addition & 1 deletion mem0/embeddings/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class EmbedderConfig(BaseModel):
@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
if provider in ["openai", "ollama", "huggingface", "azure_openai", "vertexai"]:
if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai"]:
return v
else:
raise ValueError(f"Unsupported embedding provider: {provider}")
27 changes: 27 additions & 0 deletions mem0/embeddings/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
from typing import Optional
import google.generativeai as genai

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase


class GoogleGenAIEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
if self.config.model is None:
self.config.model = "models/text-embedding-004" # embedding-dim = 768

genai.configure(api_key=self.config.api_key or os.getenv("GOOGLE_API_KEY"))

def embed(self, text):
"""
Get the embedding for the given text using Google Generative AI.
Args:
text (str): The text to embed.
Returns:
list: The embedding vector.
"""
text = text.replace("\n", " ")
response = genai.embed_content(model=self.config.model, content=text)
return response['embedding']
1 change: 1 addition & 0 deletions mem0/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class EmbedderFactory:
"ollama": "mem0.embeddings.ollama.OllamaEmbedding",
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
"azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
"gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
}

@classmethod
Expand Down
Loading

0 comments on commit aaf8e6e

Please sign in to comment.