Merge pull request #34 from allegro/adding-gcp-tokenizer-for-gemini-models

Added GCP tokenizer for gemini models
riccardo-alle authored Aug 8, 2024
2 parents 64d7f5d + f09ccb9 commit 28a785d
Showing 6 changed files with 109 additions and 28 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -26,13 +26,13 @@ ___

## Supported Models

| LLM Family | Hosting | Supported LLMs |
|-------------|---------------------|-----------------------------------------|
| GPT(s) | OpenAI endpoint | `gpt-3.5-turbo`, `gpt-4`, `gpt-4-turbo` |
| Google LLMs | VertexAI deployment | `text-bison@001`, `gemini-pro` |
| Llama2 | Azure deployment | `llama2-7b`, `llama2-13b`, `llama2-70b` |
| Mistral | Azure deployment | `Mistral-7b`, `Mixtral-7bx8` |
| Gemma | GCP deployment | `gemma` |
| LLM Family | Hosting | Supported LLMs |
|-------------|---------------------|------------------------------------------------------------------|
| GPT(s)      | OpenAI endpoint     | `gpt-3.5-turbo`, `gpt-4`, `gpt-4-turbo`, `gpt-4o`, `gpt-4o-mini` |
| Google LLMs | VertexAI deployment | `text-bison@001`, `gemini-pro`, `gemini-flash` |
| Llama2 | Azure deployment | `llama2-7b`, `llama2-13b`, `llama2-70b` |
| Mistral | Azure deployment | `Mistral-7b`, `Mixtral-7bx8` |
| Gemma | GCP deployment | `gemma` |

* Do you already have a subscription to a Cloud Provider for any of the models above? Configure
the model using your credentials and start querying!
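As a usage illustration for the table above, configuring one of the Google LLMs might look roughly like this. This is a hedged sketch, not verbatim allms API: the class name `VertexAIGeminiModel` and the constructor arguments are inferred from identifiers visible elsewhere in this diff (`VertexAIConfiguration`, `cloud_project`, `cloud_location`, `gemini_model_name`) and may not match the library exactly.

```python
# Hypothetical configuration sketch based on names appearing in this diff;
# argument names and the model class are assumptions, not confirmed API.
from allms.domain.configuration import VertexAIConfiguration
from allms.models.vertexai_gemini import VertexAIGeminiModel

configuration = VertexAIConfiguration(
    cloud_project="<your-gcp-project-id>",
    cloud_location="us-central1",
    gemini_model_name="gemini-1.5-flash-001",
)
model = VertexAIGeminiModel(config=configuration)
```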
2 changes: 1 addition & 1 deletion allms/defaults/vertex_ai.py
@@ -10,7 +10,7 @@ class PalmModelDefaults:


class GeminiModelDefaults:
GCP_MODEL_NAME = "gemini-pro"
GCP_MODEL_NAME = "gemini-1.5-flash-001"
MODEL_TOTAL_MAX_TOKENS = 30720
MAX_OUTPUT_TOKENS = 2048
TEMPERATURE = 0.0
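The defaults above imply a token budget: a prompt can only use what `MODEL_TOTAL_MAX_TOKENS` leaves after reserving `MAX_OUTPUT_TOKENS` for the response. A minimal sketch of that arithmetic (the `prompt_fits` helper is an illustration, not allms API):

```python
# Budget check using the constants from GeminiModelDefaults above.
# prompt_fits is a hypothetical helper, not part of allms.
MODEL_TOTAL_MAX_TOKENS = 30720
MAX_OUTPUT_TOKENS = 2048

def prompt_fits(prompt_tokens: int) -> bool:
    # The prompt plus the reserved output budget must stay within the context window.
    return prompt_tokens + MAX_OUTPUT_TOKENS <= MODEL_TOTAL_MAX_TOKENS

print(prompt_fits(28000))  # → True  (28000 + 2048 = 30048 <= 30720)
print(prompt_fits(29000))  # → False (29000 + 2048 = 31048 >  30720)
```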
20 changes: 19 additions & 1 deletion allms/models/vertexai_gemini.py
@@ -1,10 +1,15 @@
import typing
from asyncio import AbstractEventLoop

from langchain_core.prompts import ChatPromptTemplate
from langchain_google_vertexai import VertexAI
from vertexai.preview import tokenization
from typing import Optional

from allms.defaults.general_defaults import GeneralDefaults
from allms.defaults.vertex_ai import GeminiModelDefaults
from allms.domain.configuration import VertexAIConfiguration
from allms.domain.input_data import InputData
from allms.models.vertexai_base import CustomVertexAI
from allms.models.abstract import AbstractModel

@@ -28,6 +33,8 @@ def __init__(
self._verbose = verbose
self._config = config

self._gcp_tokenizer = tokenization.get_tokenizer_for_model(self._config.gemini_model_name)

super().__init__(
temperature=temperature,
model_total_max_tokens=model_total_max_tokens,
@@ -48,4 +55,15 @@ def _create_llm(self) -> VertexAI:
verbose=self._verbose,
project=self._config.cloud_project,
location=self._config.cloud_location
)
)

def _get_prompt_tokens_number(self, prompt: ChatPromptTemplate, input_data: InputData) -> int:
return self._gcp_tokenizer.count_tokens(
prompt.format_prompt(**input_data.input_mappings).to_string()
).total_tokens

def _get_model_response_tokens_number(self, model_response: typing.Optional[str]) -> int:
if model_response:
return self._gcp_tokenizer.count_tokens(model_response).total_tokens
return 0
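The two helpers added above can be exercised offline with a stand-in tokenizer; the real `tokenization.get_tokenizer_for_model` needs the `google-cloud-aiplatform` package and a valid Gemini model name. The stub below is an assumption for illustration only (whitespace splitting, not the real SentencePiece tokenizer), but it reproduces the `count_tokens(...).total_tokens` call shape and the `None`-response guard from this commit:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class CountResult:
    # Mirrors the .total_tokens attribute returned by the real count_tokens.
    total_tokens: int

class StubTokenizer:
    """Stand-in for the vertexai.preview.tokenization tokenizer (illustrative only)."""
    def count_tokens(self, text: str) -> CountResult:
        # Crude whitespace split; the real Gemini tokenizer uses SentencePiece.
        return CountResult(total_tokens=len(text.split()))

def get_model_response_tokens_number(tokenizer, model_response: Optional[str]) -> int:
    # Same guard as the method added in this commit: a missing response costs 0 tokens.
    if model_response:
        return tokenizer.count_tokens(model_response).total_tokens
    return 0

tok = StubTokenizer()
print(get_model_response_tokens_number(tok, "four words right here"))  # → 4
print(get_model_response_tokens_number(tok, None))                     # → 0
```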
