Merge pull request #26 from allegro/gemini-configurable-version
Made Gemini and Palm versions configurable
alicja-raczkowska authored May 9, 2024
2 parents 7a863d9 + b901000 commit 765e242
Showing 8 changed files with 47 additions and 13 deletions.
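
In short, both Vertex AI wrappers now read the model version from `VertexAIConfiguration` instead of a hard-coded default. A minimal sketch of the new call pattern, assuming placeholder project, region, and model values (omitting `gemini_model_name` falls back to `GeminiModelDefaults.GCP_MODEL_NAME`):

```python
from allms.domain.configuration import VertexAIConfiguration
from allms.models.vertexai_gemini import VertexAIGeminiModel

# Placeholders - substitute your own GCP project, region and model version.
configuration = VertexAIConfiguration(
    cloud_project="<GCP_PROJECT_ID>",
    cloud_location="<MODEL_REGION>",
    gemini_model_name="gemini-pro"  # optional; defaults to GeminiModelDefaults.GCP_MODEL_NAME
)

vertex_model = VertexAIGeminiModel(config=configuration)
```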
20 changes: 18 additions & 2 deletions allms/domain/configuration.py
@@ -1,4 +1,7 @@
from dataclasses import dataclass
from typing import Optional

from allms.defaults.vertex_ai import GeminiModelDefaults, PalmModelDefaults


@dataclass
@@ -21,8 +24,21 @@ class AzureSelfDeployedConfiguration:
class VertexAIConfiguration:
cloud_project: str
cloud_location: str
palm_model_name: Optional[str] = PalmModelDefaults.GCP_MODEL_NAME
gemini_model_name: Optional[str] = GeminiModelDefaults.GCP_MODEL_NAME


@dataclass
class VertexAIModelGardenConfiguration(VertexAIConfiguration):
endpoint_id: str
def __init__(
self,
cloud_project: str,
cloud_location: str,
endpoint_id: str
):
super().__init__(
cloud_project=cloud_project,
cloud_location=cloud_location,
palm_model_name=None,
gemini_model_name=None
)
self.endpoint_id = endpoint_id
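
For Model Garden deployments the model is identified by its endpoint, so the overridden `__init__` above pins both `palm_model_name` and `gemini_model_name` to `None`. A minimal construction sketch, with all values as placeholders:

```python
from allms.domain.configuration import VertexAIModelGardenConfiguration

# Placeholders - a Model Garden deployment is addressed by its endpoint,
# so no palm_model_name / gemini_model_name is passed here.
garden_configuration = VertexAIModelGardenConfiguration(
    cloud_project="<GCP_PROJECT_ID>",
    cloud_location="<MODEL_REGION>",
    endpoint_id="<ENDPOINT_ID>"
)
```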
2 changes: 1 addition & 1 deletion allms/models/vertexai_gemini.py
@@ -39,7 +39,7 @@ def __init__(

def _create_llm(self) -> VertexAI:
return CustomVertexAI(
model_name=GeminiModelDefaults.GCP_MODEL_NAME,
model_name=self._config.gemini_model_name,
max_output_tokens=self._max_output_tokens,
temperature=self._temperature,
top_p=self._top_p,
2 changes: 1 addition & 1 deletion allms/models/vertexai_palm.py
@@ -39,7 +39,7 @@ def __init__(

def _create_llm(self) -> VertexAI:
return CustomVertexAI(
model_name=PalmModelDefaults.GCP_MODEL_NAME,
model_name=self._config.palm_model_name,
max_output_tokens=self._max_output_tokens,
temperature=self._temperature,
top_p=self._top_p,
7 changes: 5 additions & 2 deletions docs/api/models/vertexai_gemini_model.md
@@ -57,12 +57,14 @@ is not provided, the length of this list is equal 1, and the first element is th
```python
VertexAIConfiguration(
cloud_project: str,
cloud_location: str
cloud_location: str,
gemini_model_name: str
)
```
#### Parameters
- `cloud_project` (`str`): The GCP project to use when making Vertex API calls.
- `cloud_location` (`str`): The region to use when making API calls.
- `gemini_model_name` (`str`): The specific Gemini version you want to use. Default value: `gemini-pro` (i.e. Gemini 1.0).

---

@@ -74,7 +76,8 @@ from allms.domain.configuration import VertexAIConfiguration

configuration = VertexAIConfiguration(
cloud_project="<GCP_PROJECT_ID>",
cloud_location="<MODEL_REGION>"
cloud_location="<MODEL_REGION>",
gemini_model_name="<MODEL_NAME>"
)

vertex_model = VertexAIGeminiModel(config=configuration)
7 changes: 5 additions & 2 deletions docs/api/models/vertexai_palm_model.md
@@ -57,12 +57,14 @@ is not provided, the length of this list is equal 1, and the first element is th
```python
VertexAIConfiguration(
cloud_project: str,
cloud_location: str
cloud_location: str,
palm_model_name: str
)
```
#### Parameters
- `cloud_project` (`str`): The GCP project to use when making Vertex API calls.
- `cloud_location` (`str`): The region to use when making API calls.
- `palm_model_name` (`str`): The specific PaLM version you want to use. Default value: `text-bison@001`.

---

@@ -74,7 +76,8 @@ from allms.domain.configuration import VertexAIConfiguration

configuration = VertexAIConfiguration(
cloud_project="<GCP_PROJECT_ID>",
cloud_location="<MODEL_REGION>"
cloud_location="<MODEL_REGION>",
palm_model_name="<MODEL_NAME>"
)

vertex_model = VertexAIPalmModel(config=configuration)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "allms"
version = "1.0.3"
version = "1.0.4"
description = ""
authors = ["Allegro Opensource <[email protected]>"]
readme = "README.md"
4 changes: 0 additions & 4 deletions tests/test_available_models_added_to_all.py
@@ -1,7 +1,3 @@
# import html.entities
# from unittest.mock import patch
#
# import pytest
from allms import models


16 changes: 16 additions & 0 deletions tests/test_end_to_end.py
@@ -4,7 +4,9 @@
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate, SystemMessagePromptTemplate

from allms.constants.input_data import IODataConstants
from allms.domain.configuration import VertexAIConfiguration
from allms.domain.prompt_dto import KeywordsOutputClass
from allms.models.vertexai_gemini import VertexAIGeminiModel
from allms.utils import io_utils
from tests.conftest import AzureOpenAIEnv

@@ -146,6 +148,20 @@ def test_prompt_is_not_modified_for_open_source_models(self, mock_aioresponse, m
)
])

def test_gemini_version_is_passed_to_model(self):
# GIVEN
model_config = VertexAIConfiguration(
cloud_project="dummy-project-id",
cloud_location="us-central1",
gemini_model_name="gemini-model-name"
)

# WHEN
gemini_model = VertexAIGeminiModel(config=model_config)

# THEN
assert gemini_model._llm.model_name == "gemini-model-name"

def test_model_times_out(
self,
mock_aioresponse,
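
The commit adds a version-propagation test only for Gemini; a parallel check for the PaLM wrapper might look like the sketch below, assuming `VertexAIPalmModel` lives in `allms.models.vertexai_palm` and exposes the same `_llm` attribute as the Gemini model:

```python
from allms.domain.configuration import VertexAIConfiguration
from allms.models.vertexai_palm import VertexAIPalmModel


def test_palm_version_is_passed_to_model():
    # GIVEN a configuration carrying an explicit PaLM version
    model_config = VertexAIConfiguration(
        cloud_project="dummy-project-id",
        cloud_location="us-central1",
        palm_model_name="palm-model-name"
    )

    # WHEN the model wrapper is built from that configuration
    palm_model = VertexAIPalmModel(config=model_config)

    # THEN the configured version reaches the underlying LLM client
    assert palm_model._llm.model_name == "palm-model-name"
```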
