Skip to content

Commit

Permalink
Merge pull request #18 from allegro/mock-vertexai
Browse files Browse the repository at this point in the history
Mock VertexAI class for tests
  • Loading branch information
megatron6000 authored Mar 7, 2024
2 parents 9c49484 + e55e562 commit 5c973cc
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 55 deletions.
5 changes: 0 additions & 5 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: "3.10.10"
- id: 'auth'
name: 'Authenticate to Google Cloud'
uses: 'google-github-actions/auth@v1'
with:
credentials_json: '${{ secrets.GOOGLE_CREDENTIALS }}'
- name: Install dependencies
run: |
make install-poetry
Expand Down
113 changes: 63 additions & 50 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import typing
from dataclasses import dataclass
from unittest.mock import patch

import pytest
from aioresponses import aioresponses
from langchain_community.llms.fake import FakeListLLM

from allms.domain.configuration import (
AzureOpenAIConfiguration, AzureSelfDeployedConfiguration, VertexAIConfiguration, VertexAIModelGardenConfiguration)
Expand All @@ -25,59 +27,70 @@ class GenerativeModels:
vertex_palm: typing.Optional[VertexAIPalmModel] = None


class VertexAIMock(FakeListLLM):
def __init__(self, *args, **kwargs):
super().__init__(responses=["{}"])


@pytest.fixture(scope="function")
def models():
event_loop = asyncio.new_event_loop()
return {
"azure_open_ai": AzureOpenAIModel(
config=AzureOpenAIConfiguration(
api_key="dummy_api_key",
base_url=AzureOpenAIEnv.OPENAI_API_BASE,
api_version=AzureOpenAIEnv.OPENAI_API_VERSION,
deployment=AzureOpenAIEnv.OPENAI_DEPLOYMENT_NAME,
model_name="gpt-4"
),
event_loop=event_loop
),
"vertex_palm": VertexAIPalmModel(
config=VertexAIConfiguration(
cloud_project="dummy-project-id",
cloud_location="us-central1"
),
event_loop=event_loop
),
"vertex_gemini": VertexAIGeminiModel(
config=VertexAIConfiguration(
cloud_project="dummy-project-id",
cloud_location="us-central1"
),
event_loop=event_loop
),
"vertex_gemma": VertexAIGemmaModel(
config=VertexAIModelGardenConfiguration(
cloud_project="dummy-project-id",
cloud_location="us-central1",
endpoint_id="dummy-endpoint-id"
),
event_loop=event_loop
),
"azure_llama2": AzureLlama2Model(
config=AzureSelfDeployedConfiguration(
api_key="dummy_api_key",
endpoint_url="https://dummy-endpoint.dummy-region.inference.ml.azure.com/score",
deployment="dummy_deployment_name"
),
event_loop=event_loop
),
"azure_mistral": AzureMistralModel(
config=AzureSelfDeployedConfiguration(
api_key="dummy_api_key",
endpoint_url="https://dummy-endpoint.dummy-region.inference.ml.azure.com/score",
deployment="dummy_deployment_name"
),
event_loop=event_loop
)
}

with (
patch("allms.models.vertexai_palm.CustomVertexAI", VertexAIMock),
patch("allms.models.vertexai_gemini.CustomVertexAI", VertexAIMock),
patch("allms.models.vertexai_gemma.VertexAIModelGardenWrapper", VertexAIMock)
):
return {
"azure_open_ai": AzureOpenAIModel(
config=AzureOpenAIConfiguration(
api_key="dummy_api_key",
base_url=AzureOpenAIEnv.OPENAI_API_BASE,
api_version=AzureOpenAIEnv.OPENAI_API_VERSION,
deployment=AzureOpenAIEnv.OPENAI_DEPLOYMENT_NAME,
model_name="gpt-4"
),
event_loop=event_loop
),
"vertex_palm": VertexAIPalmModel(
config=VertexAIConfiguration(
cloud_project="dummy-project-id",
cloud_location="us-central1"
),
event_loop=event_loop
),
"vertex_gemini": VertexAIGeminiModel(
config=VertexAIConfiguration(
cloud_project="dummy-project-id",
cloud_location="us-central1"
),
event_loop=event_loop
),
"vertex_gemma": VertexAIGemmaModel(
config=VertexAIModelGardenConfiguration(
cloud_project="dummy-project-id",
cloud_location="us-central1",
endpoint_id="dummy-endpoint-id"
),
event_loop=event_loop
),
"azure_llama2": AzureLlama2Model(
config=AzureSelfDeployedConfiguration(
api_key="dummy_api_key",
endpoint_url="https://dummy-endpoint.dummy-region.inference.ml.azure.com/score",
deployment="dummy_deployment_name"
),
event_loop=event_loop
),
"azure_mistral": AzureMistralModel(
config=AzureSelfDeployedConfiguration(
api_key="dummy_api_key",
endpoint_url="https://dummy-endpoint.dummy-region.inference.ml.azure.com/score",
deployment="dummy_deployment_name"
),
event_loop=event_loop
)
}


@pytest.fixture
Expand Down

0 comments on commit 5c973cc

Please sign in to comment.