Skip to content

Commit

Permalink
Merge pull request #29 from allegro/enable-passing-gemini-safety-settings
Browse files Browse the repository at this point in the history

Enable passing gemini safety settings
  • Loading branch information
megatron6000 authored Jul 1, 2024
2 parents 4cda7ac + 3d939c2 commit 7461c6c
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 57 deletions.
5 changes: 4 additions & 1 deletion allms/domain/configuration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass
from typing import Optional
from typing import Dict, Optional

from langchain_google_vertexai import HarmBlockThreshold, HarmCategory

from allms.defaults.vertex_ai import GeminiModelDefaults, PalmModelDefaults

Expand All @@ -26,6 +28,7 @@ class VertexAIConfiguration:
cloud_location: str
palm_model_name: Optional[str] = PalmModelDefaults.GCP_MODEL_NAME
gemini_model_name: Optional[str] = GeminiModelDefaults.GCP_MODEL_NAME
gemini_safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None


class VertexAIModelGardenConfiguration(VertexAIConfiguration):
Expand Down
3 changes: 3 additions & 0 deletions allms/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, Type

from allms.domain.configuration import HarmBlockThreshold, HarmCategory
from allms.domain.enumerables import AvailableModels
from allms.models.abstract import AbstractModel
from allms.models.azure_llama2 import AzureLlama2Model
Expand All @@ -16,6 +17,8 @@
"VertexAIPalmModel",
"VertexAIGeminiModel",
"VertexAIGemmaModel",
"HarmCategory",
"HarmBlockThreshold",
"get_available_models"
]

Expand Down
2 changes: 1 addition & 1 deletion allms/models/vertexai_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional, Any, Dict

from google.cloud.aiplatform.models import Prediction
from langchain_community.llms.vertexai import VertexAI, VertexAIModelGarden
from langchain_google_vertexai import VertexAI, VertexAIModelGarden
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun
from langchain_core.outputs import LLMResult, Generation
from pydash import chain
Expand Down
3 changes: 2 additions & 1 deletion allms/models/vertexai_gemini.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from asyncio import AbstractEventLoop
from langchain_community.llms.vertexai import VertexAI
from langchain_google_vertexai import VertexAI
from typing import Optional

from allms.defaults.general_defaults import GeneralDefaults
Expand Down Expand Up @@ -44,6 +44,7 @@ def _create_llm(self) -> VertexAI:
temperature=self._temperature,
top_p=self._top_p,
top_k=self._top_k,
safety_settings=self._config.gemini_safety_settings,
verbose=self._verbose,
project=self._config.cloud_project,
location=self._config.cloud_location
Expand Down
2 changes: 1 addition & 1 deletion allms/models/vertexai_gemma.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from asyncio import AbstractEventLoop

from langchain_community.llms.vertexai import VertexAIModelGarden
from langchain_google_vertexai import VertexAIModelGarden
from typing import Optional

from allms.defaults.general_defaults import GeneralDefaults
Expand Down
2 changes: 1 addition & 1 deletion allms/models/vertexai_palm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from asyncio import AbstractEventLoop
from langchain_community.llms.vertexai import VertexAI
from langchain_google_vertexai import VertexAI
from typing import Optional

from allms.defaults.general_defaults import GeneralDefaults
Expand Down
84 changes: 53 additions & 31 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "allms"
version = "1.0.6"
version = "1.0.7"
description = ""
authors = ["Allegro Opensource <[email protected]>"]
readme = "README.md"
Expand All @@ -9,11 +9,12 @@ packages = [{include = "allms"}]
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
fsspec = "^2023.6.0"
google-cloud-aiplatform = "1.38.0"
google-cloud-aiplatform = "^1.47.0"
pydash = "^7.0.6"
transformers = "^4.34.1"
pydantic = "1.10.13"
langchain = "^0.1.8"
langchain-google-vertexai = "1.0.4"
aioresponses = "^0.7.6"
tiktoken = "^0.6.0"
openai = "^0.27.8"
Expand Down
26 changes: 18 additions & 8 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from allms.constants.input_data import IODataConstants
from allms.domain.configuration import VertexAIConfiguration
from allms.domain.prompt_dto import KeywordsOutputClass
from allms.models.vertexai_gemini import VertexAIGeminiModel
from allms.models import VertexAIGeminiModel, HarmBlockThreshold, HarmCategory
from allms.utils import io_utils
from tests.conftest import AzureOpenAIEnv

Expand Down Expand Up @@ -148,19 +148,29 @@ def test_prompt_is_not_modified_for_open_source_models(self, mock_aioresponse, m
)
])

def test_gemini_version_is_passed_to_model(self):
def test_gemini_specific_args_are_passed_to_model(self):
# GIVEN
gemini_model_name = "gemini-model-name"
gemini_safety_settings = {
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
model_config = VertexAIConfiguration(
cloud_project="dummy-project-id",
cloud_location="us-central1",
gemini_model_name="gemini-model-name"
)
cloud_project="dummy-project-id",
cloud_location="us-central1",
gemini_model_name=gemini_model_name,
gemini_safety_settings=gemini_safety_settings
)

# WHEN
gemini_model = VertexAIGeminiModel(config=model_config)

# WHEN
gemini_model._llm.model_name == "gemini-model-name"
# THEN
assert gemini_model._llm.model_name == gemini_model_name
assert gemini_model._llm.safety_settings == gemini_safety_settings

def test_model_times_out(
self,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_model_behavior_for_different_input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_exception_when_input_data_is_missing_and_prompt_contains_input_key(self
model.generate(prompt, None)

@patch("langchain.chains.base.Chain.arun")
@patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens")
@patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens")
def test_exception_when_num_prompt_tokens_larger_than_model_total_max_tokens(self, tokens_mock, chain_run_mock, models):
# GIVEN
chain_run_mock.return_value = "{}"
Expand All @@ -97,7 +97,7 @@ def test_exception_when_num_prompt_tokens_larger_than_model_total_max_tokens(sel
assert "Value Error has occurred: Prompt is too long" in response.error

@patch("langchain.chains.base.Chain.arun")
@patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens")
@patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens")
def test_whether_curly_brackets_are_not_breaking_the_prompt(self, tokens_mock, chain_run_mock, models):
# GIVEN
chain_run_mock.return_value = "{}"
Expand All @@ -115,7 +115,7 @@ def test_whether_curly_brackets_are_not_breaking_the_prompt(self, tokens_mock, c
assert response.response is not None

@patch("langchain.chains.base.Chain.arun")
@patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens")
@patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens")
def test_warning_when_num_prompt_tokens_plus_max_output_tokens_larger_than_model_total_max_tokens(
self,
tokens_mock,
Expand Down
Loading

0 comments on commit 7461c6c

Please sign in to comment.