Implemented Gemini (#1490) #1965

Merged · 12 commits · Oct 21, 2024
33 changes: 33 additions & 0 deletions docs/components/llms/models/gemini.mdx
@@ -0,0 +1,33 @@
---
title: Gemini
---

To use the Gemini model, you must set the `GEMINI_API_KEY` environment variable. You can obtain a Gemini API key from [Google AI Studio](https://aistudio.google.com/app/apikey).

## Usage

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "your-api-key" # used for embedding model
os.environ["GEMINI_API_KEY"] = "your-api-key"

config = {
"llm": {
"provider": "gemini",
"config": {
"model": "gemini-1.5-flash-latest",
"temperature": 0.2,
"max_tokens": 1500,
}
}
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```

## Config

All available parameters for the `Gemini` config are listed in the [Master List of All Params in Config](../config).
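
As a quick illustration, the `GeminiLLM` class in this PR reads `model`, `temperature`, `max_tokens`, `top_p`, and `api_key` from the config, so a fuller config could look like the following sketch (values are placeholders):

```python
config = {
    "llm": {
        "provider": "gemini",
        "config": {
            "model": "gemini-1.5-flash-latest",
            "temperature": 0.2,         # sampling temperature
            "max_tokens": 1500,         # cap on generated tokens
            "top_p": 1.0,               # nucleus-sampling cutoff
            "api_key": "your-api-key",  # takes precedence over GEMINI_API_KEY
        }
    }
}
```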
139 changes: 139 additions & 0 deletions mem0/llms/gemini.py
@@ -0,0 +1,139 @@
import os
from typing import Dict, List, Optional

import google.generativeai as genai
from google.generativeai import GenerativeModel
from google.generativeai.types import content_types

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class GeminiLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)

if not self.config.model:
self.config.model = "gemini-1.5-flash-latest"

api_key = self.config.api_key or os.getenv("GEMINI_API_KEY")
genai.configure(api_key=api_key)
self.client = GenerativeModel(model_name=self.config.model)

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
            response: The raw response from the API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": content if (content := response.candidates[0].content.parts[0].text) else None,
"tool_calls": [],
}

for part in response.candidates[0].content.parts:
if fn := part.function_call:
processed_response["tool_calls"].append(
{
"name": fn.name,
"arguments": {key:val for key, val in fn.args.items()},
}
)

return processed_response
else:
return response.candidates[0].content.parts[0].text

    def _reformat_messages(self, messages: List[Dict[str, str]]):
"""
Reformat messages for Gemini.

Args:
messages: The list of messages provided in the request.

Returns:
list: The list of messages in the required format.
"""
new_messages = []

for message in messages:
if message["role"] == "system":
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]

else:
content = message["content"]

            new_messages.append(
                {"parts": content, "role": "model" if message["role"] == "model" else "user"}
            )

return new_messages

def _reformat_tools(self, tools: Optional[List[Dict]]):
"""
Reformat tools for Gemini.

Args:
tools: The list of tools provided in the request.

Returns:
list: The list of tools in the required format.
"""
if tools:
new_tools = []

for tool in tools:
func = tool["function"].copy()
func["parameters"].pop("additionalProperties", None)
                new_tools.append({"function_declarations": [func]})

return new_tools
else:
return None

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generate a response based on the given messages using Gemini.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format for the response. Defaults to None.
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".

Returns:
str: The generated response.
"""

params = {
"temperature": self.config.temperature,
"max_output_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}

if response_format:
params["response_mime_type"] = "application/json"
params["response_schema"] = list[response_format]
        # Initialize to None so the generate_content call below never hits an
        # unbound name when tool_choice is falsy.
        tool_config = None
        if tool_choice:
            tool_config = content_types.to_tool_config(
                {
                    "function_calling_config": {
                        "mode": tool_choice,
                        "allowed_function_names": [tool["function"]["name"] for tool in tools]
                        if tool_choice == "any"
                        else None,
                    }
                }
            )

        response = self.client.generate_content(
            contents=self._reformat_messages(messages),
            tools=self._reformat_tools(tools),
            generation_config=genai.GenerationConfig(**params),
            tool_config=tool_config,
        )

return self._parse_response(response, tools)
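
A minimal sketch of exercising the new class directly (the `BaseLlmConfig` constructor fields are assumed to mirror the attributes `GeminiLLM` reads above; they are not defined in this diff):

```python
import os

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.gemini import GeminiLLM

os.environ["GEMINI_API_KEY"] = "your-api-key"

# Assumed constructor fields, mirroring self.config.model / temperature / max_tokens.
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.2, max_tokens=1500)
llm = GeminiLLM(config)

reply = llm.generate_response(
    messages=[
        {"role": "system", "content": "You extract user facts."},
        {"role": "user", "content": "I like to play cricket on weekends."},
    ]
)
print(reply)
```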
1 change: 1 addition & 0 deletions mem0/utils/factory.py
@@ -22,6 +22,7 @@ class LlmFactory:
"openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
"anthropic": "mem0.llms.anthropic.AnthropicLLM",
"azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM",
"gemini": "mem0.llms.gemini.GeminiLLM",
}
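    # With this entry, resolving the "gemini" provider through the factory's
    # existing lookup (per the mapping above) loads mem0.llms.gemini.GeminiLLM.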

@classmethod