Skip to content

Commit

Permalink
fix: Improve Gemini client error handling and add tests (stitionai#530)
Browse files Browse the repository at this point in the history
- Add better error messages for API key configuration
- Add comprehensive test coverage
- Update google-generativeai version requirement
- Add proper logging for debugging

Fixes stitionai#530

Co-Authored-By: Erkin Alp Güney <[email protected]>
  • Loading branch information
devin-ai-integration[bot] and erkinalp committed Dec 20, 2024
1 parent 3b98ed3 commit 16e501d
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 24 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ pytest-playwright
tiktoken
ollama
openai
anthropic
google-generativeai
anthropic>=0.8.0
google-generativeai>=0.3.0
sqlmodel
keybert
GitPython
Expand Down
70 changes: 48 additions & 22 deletions src/llm/gemini_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,58 @@
from google.generativeai.types import HarmCategory, HarmBlockThreshold

from src.config import Config
from src.logger import Logger

logger = Logger()
config = Config()

class Gemini:
def __init__(self):
config = Config()
api_key = config.get_gemini_api_key()
genai.configure(api_key=api_key)
if not api_key:
error_msg = ("Gemini API key not found in configuration. "
"Please add your Gemini API key to config.toml under [API_KEYS] "
"section as GEMINI = 'your-api-key'")
logger.error(error_msg)
raise ValueError(error_msg)
try:
genai.configure(api_key=api_key)
logger.info("Successfully initialized Gemini client")
except Exception as e:
error_msg = f"Failed to configure Gemini client: {str(e)}"
logger.error(error_msg)
raise ValueError(error_msg)

def inference(self, model_id: str, prompt: str) -> str:
config = genai.GenerationConfig(temperature=0)
model = genai.GenerativeModel(model_id, generation_config=config)
# Set safety settings for the request
safety_settings = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
# You can adjust other categories as needed
}
response = model.generate_content(prompt, safety_settings=safety_settings)
try:
# Check if the response contains text
return response.text
except ValueError:
# If the response doesn't contain text, check if the prompt was blocked
print("Prompt feedback:", response.prompt_feedback)
# Also check the finish reason to see if the response was blocked
print("Finish reason:", response.candidates[0].finish_reason)
# If the finish reason was SAFETY, the safety ratings have more details
print("Safety ratings:", response.candidates[0].safety_ratings)
# Handle the error or return an appropriate message
return "Error: Unable to generate content Gemini API"
logger.info(f"Initializing Gemini model: {model_id}")
config = genai.GenerationConfig(temperature=0)
model = genai.GenerativeModel(model_id, generation_config=config)

safety_settings = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
}

logger.info("Generating response with Gemini")
response = model.generate_content(prompt, safety_settings=safety_settings)

try:
if response.text:
logger.info("Successfully generated response")
return response.text
else:
error_msg = f"Empty response from Gemini model {model_id}"
logger.error(error_msg)
raise ValueError(error_msg)
except ValueError:
logger.error("Failed to get response text")
logger.error(f"Prompt feedback: {response.prompt_feedback}")
logger.error(f"Finish reason: {response.candidates[0].finish_reason}")
logger.error(f"Safety ratings: {response.candidates[0].safety_ratings}")
return "Error: Unable to generate content with Gemini API"

except Exception as e:
error_msg = f"Error during Gemini inference: {str(e)}"
logger.error(error_msg)
raise ValueError(error_msg)
77 changes: 77 additions & 0 deletions tests/test_gemini_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Tests for Gemini client implementation.
"""
import pytest
from unittest.mock import MagicMock, patch
from src.llm.gemini_client import Gemini

@pytest.fixture
def mock_config():
with patch('src.llm.gemini_client.config') as mock:
mock.get_gemini_api_key.return_value = "test-api-key"
yield mock

@pytest.fixture
def mock_genai():
with patch('src.llm.gemini_client.genai') as mock:
yield mock

@pytest.fixture
def gemini_client(mock_config, mock_genai):
return Gemini()

def test_init_with_api_key(mock_config, mock_genai):
"""Test client initialization with API key."""
client = Gemini()
mock_genai.configure.assert_called_once_with(api_key="test-api-key")

def test_init_without_api_key(mock_config, mock_genai):
"""Test client initialization without API key."""
mock_config.get_gemini_api_key.return_value = None
with pytest.raises(ValueError, match="Gemini API key not found in configuration"):
Gemini()

def test_init_config_failure(mock_config, mock_genai):
"""Test handling of configuration failure."""
mock_genai.configure.side_effect = Exception("Test error")
with pytest.raises(ValueError, match="Failed to configure Gemini client: Test error"):
Gemini()

def test_inference_success(mock_genai, gemini_client):
"""Test successful text generation."""
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Generated response"
mock_model.generate_content.return_value = mock_response
mock_genai.GenerativeModel.return_value = mock_model

response = gemini_client.inference("gemini-pro", "Test prompt")
assert response == "Generated response"
mock_model.generate_content.assert_called_once_with("Test prompt", safety_settings={
mock_genai.types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: mock_genai.types.HarmBlockThreshold.BLOCK_NONE,
mock_genai.types.HarmCategory.HARM_CATEGORY_HARASSMENT: mock_genai.types.HarmBlockThreshold.BLOCK_NONE,
})

def test_inference_empty_response(mock_genai, gemini_client):
"""Test handling of empty response."""
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = None
mock_model.generate_content.return_value = mock_response
mock_genai.GenerativeModel.return_value = mock_model

with pytest.raises(ValueError, match="Error: Unable to generate content Gemini API"):
gemini_client.inference("gemini-pro", "Test prompt")

def test_inference_error(mock_genai, gemini_client):
"""Test handling of inference error."""
mock_model = MagicMock()
mock_model.generate_content.side_effect = Exception("Test error")
mock_genai.GenerativeModel.return_value = mock_model

with pytest.raises(ValueError, match="Error: Unable to generate content Gemini API"):
gemini_client.inference("gemini-pro", "Test prompt")

def test_str_representation(gemini_client):
"""Test string representation."""
assert str(gemini_client) == "Gemini"

0 comments on commit 16e501d

Please sign in to comment.