Implemented Gemini (#1490) #1965

Merged · 12 commits · Oct 21, 2024
33 changes: 33 additions & 0 deletions docs/components/llms/models/gemini.mdx
@@ -0,0 +1,33 @@
---
title: Gemini
---

To use the Gemini model, set the `GEMINI_API_KEY` environment variable. You can obtain a Gemini API key from [Google AI Studio](https://aistudio.google.com/app/apikey).

## Usage

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "your-api-key" # used for embedding model
os.environ["GEMINI_API_KEY"] = "your-api-key"

config = {
    "llm": {
        "provider": "gemini",
        "config": {
            "model": "gemini-1.5-flash-latest",
            "temperature": 0.2,
            "max_tokens": 1500,
        }
    }
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```

## Config

All available parameters for the `Gemini` config are present in [Master List of All Params in Config](../config).
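If you prefer not to rely on the environment variable, the API key can also be passed in the config. A minimal sketch, assuming the `api_key` and `top_p` parameter names from the master list; per the provider implementation, a config-level `api_key` takes precedence over `GEMINI_API_KEY`:

```python
config = {
    "llm": {
        "provider": "gemini",
        "config": {
            "model": "gemini-1.5-flash-latest",
            "api_key": "your-api-key",  # assumed field name; overrides GEMINI_API_KEY when set
            "temperature": 0.2,
            "top_p": 1.0,
            "max_tokens": 1500,
        }
    }
}
```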
1 change: 1 addition & 0 deletions docs/components/llms/overview.mdx
@@ -23,6 +23,7 @@ To view all supported llms, visit the [Supported LLMs](./models).
<Card title="Mistral AI" href="/components/llms/models/mistral_ai"></Card>
<Card title="Google AI" href="/components/llms/models/google_ai"></Card>
<Card title="AWS bedrock" href="/components/llms/models/aws_bedrock"></Card>
<Card title="Gemini" href="/components/llms/models/gemini"></Card>
</CardGroup>

## Structured vs Unstructured Outputs
3 changes: 2 additions & 1 deletion docs/mint.json
@@ -93,7 +93,8 @@
"components/llms/models/litellm",
"components/llms/models/mistral_AI",
"components/llms/models/google_AI",
"components/llms/models/aws_bedrock"
"components/llms/models/aws_bedrock",
"components/llms/models/gemini"
]
}
]
154 changes: 154 additions & 0 deletions mem0/llms/gemini.py
@@ -0,0 +1,154 @@
import os
from typing import Dict, List, Optional

try:
    import google.generativeai as genai
    from google.generativeai import GenerativeModel
    from google.generativeai.types import content_types
except ImportError:
    raise ImportError(
        "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
    )

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class GeminiLLM(LLMBase):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)

        if not self.config.model:
            self.config.model = "gemini-1.5-flash-latest"

        api_key = self.config.api_key or os.getenv("GEMINI_API_KEY")
        genai.configure(api_key=api_key)
        self.client = GenerativeModel(model_name=self.config.model)

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        if tools:
            processed_response = {
                "content": content if (content := response.candidates[0].content.parts[0].text) else None,
                "tool_calls": [],
            }

            for part in response.candidates[0].content.parts:
                if fn := part.function_call:
                    processed_response["tool_calls"].append(
                        {
                            "name": fn.name,
                            "arguments": {key: val for key, val in fn.args.items()},
                        }
                    )

            return processed_response
        else:
            return response.candidates[0].content.parts[0].text

    def _reformat_messages(self, messages: List[Dict[str, str]]):
        """
        Reformat messages for Gemini.

        Args:
            messages: The list of messages provided in the request.

        Returns:
            list: The list of messages in the required format.
        """
        new_messages = []

        for message in messages:
            if message["role"] == "system":
                # Gemini's content format has no system role here, so the system
                # prompt is folded into a user turn with an explicit preamble.
                content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
            else:
                content = message["content"]

            new_messages.append(
                {
                    "parts": content,
                    "role": "model" if message["role"] == "model" else "user",
                }
            )

        return new_messages

    def _reformat_tools(self, tools: Optional[List[Dict]]):
        """
        Reformat tools for Gemini.

        Args:
            tools: The list of tools provided in the request.

        Returns:
            list: The list of tools in the required format.
        """

        def remove_additional_properties(data):
            """Recursively removes 'additionalProperties' from nested dictionaries."""
            if isinstance(data, dict):
                return {
                    key: remove_additional_properties(value)
                    for key, value in data.items()
                    if key != "additionalProperties"
                }
            return data

        new_tools = []
        if tools:
            for tool in tools:
                func = tool["function"].copy()
                new_tools.append({"function_declarations": [remove_additional_properties(func)]})
            return new_tools
        else:
            return None

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using Gemini.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format for the response. Defaults to None.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str or dict: The generated response.
        """
        params = {
            "temperature": self.config.temperature,
            "max_output_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

        if response_format:
            params["response_mime_type"] = "application/json"
            params["response_schema"] = list[response_format]

        tool_config = None  # avoid a NameError when tool_choice is falsy
        if tool_choice:
            tool_config = content_types.to_tool_config(
                {
                    "function_calling_config": {
                        "mode": tool_choice,
                        "allowed_function_names": [tool["function"]["name"] for tool in tools]
                        if tool_choice == "any"
                        else None,
                    }
                }
            )

        response = self.client.generate_content(
            contents=self._reformat_messages(messages),
            tools=self._reformat_tools(tools),
            generation_config=genai.GenerationConfig(**params),
            tool_config=tool_config,
        )

        return self._parse_response(response, tools)
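For reviewers, a rough sketch of the class's observable behavior when driven directly (not part of this diff; the return shapes mirror `_parse_response` above):

```python
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.gemini import GeminiLLM

llm = GeminiLLM(BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.2, max_tokens=1500))

# Without tools: returns the plain text of the first candidate.
text = llm.generate_response([{"role": "user", "content": "Say hello."}])

# With tools: returns a dict shaped like
# {"content": <text or None>, "tool_calls": [{"name": ..., "arguments": {...}}]}.
```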
1 change: 1 addition & 0 deletions mem0/utils/factory.py
@@ -22,6 +22,7 @@ class LlmFactory:
"openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
"anthropic": "mem0.llms.anthropic.AnthropicLLM",
"azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM",
"gemini": "mem0.llms.gemini.GeminiLLM",
}

@classmethod
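Registering the `"gemini"` key is what lets `Memory.from_config` resolve `"provider": "gemini"`. A rough sketch of the lookup, assuming the factory's `create` classmethod (the `@classmethod` above) takes the provider key plus a config dict and builds a `BaseLlmConfig` from it:

```python
from mem0.utils.factory import LlmFactory

# Hypothetical direct use of the factory; "gemini" now maps to
# mem0.llms.gemini.GeminiLLM in provider_to_class.
llm = LlmFactory.create("gemini", {"model": "gemini-1.5-flash-latest", "temperature": 0.2})
response = llm.generate_response([{"role": "user", "content": "Hello!"}])
```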
118 changes: 118 additions & 0 deletions tests/llms/test_gemini_llm.py
@@ -0,0 +1,118 @@
from unittest.mock import Mock, patch

import pytest
from google.generativeai import GenerationConfig
from google.generativeai.types import content_types

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.gemini import GeminiLLM


@pytest.fixture
def mock_gemini_client():
    with patch("mem0.llms.gemini.GenerativeModel") as mock_gemini:
        mock_client = Mock()
        mock_gemini.return_value = mock_client
        yield mock_client


def test_generate_response_without_tools(mock_gemini_client: Mock):
    config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
    llm = GeminiLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
    ]

    mock_part = Mock(text="I'm doing well, thank you for asking!")
    mock_content = Mock(parts=[mock_part])
    mock_message = Mock(content=mock_content)
    mock_response = Mock(candidates=[mock_message])
    mock_gemini_client.generate_content.return_value = mock_response

    response = llm.generate_response(messages)

    mock_gemini_client.generate_content.assert_called_once_with(
        contents=[
            {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
            {"parts": "Hello, how are you?", "role": "user"},
        ],
        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
        tools=None,
        tool_config=content_types.to_tool_config(
            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
        ),
    )
    assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_gemini_client: Mock):
    config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
    llm = GeminiLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "add_memory",
                "description": "Add a memory",
                "parameters": {
                    "type": "object",
                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
                    "required": ["data"],
                },
            },
        }
    ]

    mock_tool_call = Mock()
    mock_tool_call.name = "add_memory"
    mock_tool_call.args = {"data": "Today is a sunny day."}

    mock_part = Mock()
    mock_part.function_call = mock_tool_call
    mock_part.text = "I've added the memory for you."

    mock_content = Mock()
    mock_content.parts = [mock_part]

    mock_message = Mock()
    mock_message.content = mock_content

    mock_response = Mock(candidates=[mock_message])
    mock_gemini_client.generate_content.return_value = mock_response

    response = llm.generate_response(messages, tools=tools)

    mock_gemini_client.generate_content.assert_called_once_with(
        contents=[
            {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
            {"parts": "Add a new memory: Today is a sunny day.", "role": "user"},
        ],
        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
        tools=[
            {
                "function_declarations": [
                    {
                        "name": "add_memory",
                        "description": "Add a memory",
                        "parameters": {
                            "type": "object",
                            "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
                            "required": ["data"],
                        },
                    }
                ]
            }
        ],
        tool_config=content_types.to_tool_config(
            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
        ),
    )

    assert response["content"] == "I've added the memory for you."
    assert len(response["tool_calls"]) == 1
    assert response["tool_calls"][0]["name"] == "add_memory"
    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}