Make API keys specific to providers. Remove unused classes.
PiperOrigin-RevId: 634834438
Chris Rawles authored and The android_world Authors committed May 17, 2024
1 parent 170bee4 commit 824a2b5
Showing 2 changed files with 16 additions and 151 deletions.
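
In practice, the rename means each provider wrapper now reads its own
environment variable instead of a shared API_KEY. A minimal setup sketch
(placeholder key values; GeminiGcpWrapper is the class shown in the diff
below, and the OpenAI-backed wrapper reads OPENAI_API_KEY):

import os

# Each provider now has its own key variable; setting one no longer
# unlocks the other.
os.environ['GCP_API_KEY'] = '<your-gemini-key>'     # read by GeminiGcpWrapper
os.environ['OPENAI_API_KEY'] = '<your-openai-key>'  # read by the OpenAI wrapper

# With GCP_API_KEY unset, GeminiGcpWrapper.__init__ raises
# RuntimeError('GCP API key not set.'), as the diff below shows.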
164 changes: 14 additions & 150 deletions android_world/agents/infer.py
@@ -19,14 +19,12 @@
 import io
 import os
 import time
-from typing import Any, Optional, Sequence, Union
+from typing import Any
 import google.generativeai as genai
 from google.generativeai.types import generation_types
 import numpy as np
 import openai
 from PIL import Image
 import requests
-import tenacity
 
 
 ERROR_CALLING_LLM = 'Error calling LLM'
@@ -80,14 +78,11 @@ def predict_mm(
 
 
 class GeminiGcpWrapper(LlmWrapper, MultimodalLlmWrapper):
-  """Gemini GCP interface.
-
-  Attributes:
-    llm: Gemini multimodal GCP client.
-    max_retry: Max number of retries when some error happens.
-  """
+  """Gemini GCP interface."""
 
-  RETRY_WAITING_SECONDS = 20
+  # As of 05/15/2024, there is a limit of 5 RPM:
+  # https://cloud.google.com/vertex-ai/generative-ai/docs/quotas.
+  TIME_BETWEEN_REQUESTS = 15
 
   def __init__(
       self,
@@ -97,9 +92,9 @@ def __init__(
       temperature: float = 0.0,
       top_p: float = 0.95,
   ):
-    if 'API_KEY' not in os.environ:
+    if 'GCP_API_KEY' not in os.environ:
       raise RuntimeError('GCP API key not set.')
-    genai.configure(api_key=os.environ['API_KEY'])
+    genai.configure(api_key=os.environ['GCP_API_KEY'])
     if model_name is None:
       model_name = (
           'gemini-1.0-pro-vision-latest'
@@ -116,6 +111,7 @@ def __init__(
       max_retry = 3
       print('Max_retry must be positive. Reset it to 3')
     self.max_retry = min(max_retry, 5)
+    self.time_since_last_request = time.time() - self.TIME_BETWEEN_REQUESTS
 
   def predict(
       self,
@@ -127,16 +123,17 @@ def predict_mm(
       self, text_prompt: str, images: list[np.ndarray]
   ) -> tuple[str, Any]:
     counter = self.max_retry
-    wait_seconds = self.RETRY_WAITING_SECONDS
     while counter > 0:
       try:
+        time_since_last_request = time.time() - self.time_since_last_request
+        if time_since_last_request < self.TIME_BETWEEN_REQUESTS:
+          time.sleep(self.TIME_BETWEEN_REQUESTS - time_since_last_request)
         output = self.llm.generate_content(
             [text_prompt] + [Image.fromarray(image) for image in images]
         )
+        self.time_since_last_request = time.time()
         return output.text, output
       except Exception as e:  # pylint: disable=broad-exception-caught
-        time.sleep(wait_seconds)
-        wait_seconds *= 2
         counter -= 1
         print('Error calling LLM, will retry soon...')
         print(e)
@@ -162,9 +159,9 @@ def __init__(
       max_retry: int = 3,
       temperature: float = 0.0,
   ):
-    if 'API_KEY' not in os.environ:
+    if 'OPENAI_API_KEY' not in os.environ:
       raise RuntimeError('OpenAI API key not set.')
-    self.openai_api_key = os.environ['API_KEY']
+    self.openai_api_key = os.environ['OPENAI_API_KEY']
     if max_retry <= 0:
       max_retry = 3
       print('Max_retry must be positive. Reset it to 3')
@@ -237,136 +234,3 @@ def predict_mm(
         print('Error calling LLM, will retry soon...')
         print(e)
     return ERROR_CALLING_LLM, None
-
-
-class ClaudeWrapper(LlmWrapper, MultimodalLlmWrapper):
-  """Claude 3 wrapper for both text-only and multimodal model.
-
-  Attributes:
-    claude_api_key: The class gets the Claude API key from the environment.
-    max_retry: Max number of retries when some error happens.
-    temperature: The temperature parameter in LLM to control result stability.
-  """
-
-  RETRY_WAITING_SECONDS = 20
-
-  def __init__(
-      self,
-      max_retry: int = 3,
-      temperature: float = 0.0,
-  ):
-    if 'API_KEY' not in os.environ:
-      raise RuntimeError('Claude API key not set.')
-    self.claude_api_key = os.environ['API_KEY']
-    if max_retry <= 0:
-      max_retry = 3
-      print('Max_retry must be positive. Reset it to 3')
-    self.max_retry = min(max_retry, 5)
-    self.temperature = temperature
-
-  @classmethod
-  def encode_image(cls, image: np.ndarray) -> str:
-    return base64.b64encode(_array_to_jpeg_bytes(image)).decode('utf-8')
-
-  def predict(
-      self,
-      text_prompt: str,
-  ) -> tuple[str, Any]:
-    return self.predict_mm(text_prompt, [])
-
-  def predict_mm(
-      self, text_prompt: str, images: list[np.ndarray]
-  ) -> tuple[str, Any]:
-    headers = {
-        'x-api-key': self.claude_api_key,
-        'anthropic-version': '2023-06-01',
-        'Content-Type': 'application/json',
-    }
-
-    payload = {
-        'model': 'claude-3-opus-20240229',
-        'max_tokens': 1024,
-        'temperature': self.temperature,
-        'messages': [{
-            'role': 'user',
-            'content': [
-                {'type': 'text', 'text': text_prompt},
-            ],
-        }],
-    }
-
-    # Claude supports multiple images; they just need to be inserted into the
-    # content list.
-    for image in images:
-      payload['messages'][0]['content'].append({
-          'type': 'image',
-          'source': {
-              'type': 'base64',
-              'media_type': 'image/jpeg',
-              'data': self.encode_image(image),
-          },
-      })
-
-    counter = self.max_retry
-    wait_seconds = self.RETRY_WAITING_SECONDS
-    while counter > 0:
-      try:
-        response = requests.post(
-            'https://api.anthropic.com/v1/messages',
-            headers=headers,
-            json=payload,
-        )
-        if response.ok and 'content' in response.json():
-          return response.json()['content'][0]['text'], response
-        print(
-            'Error calling Claude API with error message: '
-            + response.json()['error']['message']
-        )
-        time.sleep(wait_seconds)
-        wait_seconds *= 2
-      except Exception as e:  # pylint: disable=broad-exception-caught
-        # Want to catch all exceptions that happen during LLM calls.
-        time.sleep(wait_seconds)
-        wait_seconds *= 2
-        counter -= 1
-        print('Error calling LLM, will retry soon...')
-        print(e)
-    return ERROR_CALLING_LLM, None
-
-
-@tenacity.retry(
-    wait=tenacity.wait_random_exponential(min=1, max=60),
-    stop=tenacity.stop_after_attempt(6),
-)
-def chat_completion_with_backoff(**kwargs: Any) -> Any:
-  return openai.ChatCompletion.create(**kwargs)
-
-
-class OpenAI:
-  """Makes inference through the OpenAI API."""
-
-  def __init__(
-      self,
-      model: str,
-      openai_api_key: Optional[str] = None,
-  ):
-    self._model = model
-    if not openai_api_key:
-      openai_api_key = os.getenv('OPENAI_API_KEY')
-    openai.api_key = openai_api_key
-
-  def predict(
-      self,
-      input_text: Union[str, Sequence[str]],
-      temperature: Optional[float] = None,
-  ) -> str:
-    """Calls OpenAI and returns the text output."""
-    # Generate a response from the model.
-    response = chat_completion_with_backoff(
-        model=self._model,
-        messages=[{'role': 'user', 'content': input_text}],
-        temperature=temperature,
-    )
-    output = response.choices[0].message['content']
-
-    return output
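
Beyond the key rename, the predict_mm hunk above replaces GeminiGcpWrapper's
exponential backoff (wait_seconds doubling after each failure) with fixed
request spacing: each call sleeps until at least TIME_BETWEEN_REQUESTS seconds
have elapsed since the last successful request, keeping the client under the
5 RPM Vertex AI quota cited in the comment. The timestamp is backdated at
construction so the first request is not delayed, and it is refreshed only on
success. A self-contained sketch of the same pattern (illustrative names, not
from the repo):

import time


class RequestSpacer:
  """Illustrative sketch of the fixed request spacing used above."""

  def __init__(self, min_interval_s: float = 15.0):
    self._min_interval_s = min_interval_s
    # Backdate so the first request goes out immediately.
    self._last_request_time = time.time() - min_interval_s

  def wait(self) -> None:
    """Sleeps just long enough to keep calls min_interval_s apart."""
    elapsed = time.time() - self._last_request_time
    if elapsed < self._min_interval_s:
      time.sleep(self._min_interval_s - elapsed)

  def mark_success(self) -> None:
    """Records the time of the latest successful request."""
    self._last_request_time = time.time()


# Usage sketch: call spacer.wait() before each request and
# spacer.mark_success() after it returns, mirroring predict_mm above.

One naming note: in the diff, time_since_last_request actually stores a
timestamp rather than a duration; the sketch uses a clearer name for the same
value.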
3 changes: 2 additions & 1 deletion android_world/agents/infer_test.py
@@ -29,7 +29,8 @@ def setUp(self):
     super().setUp()
     self.mock_post = mock.patch.object(requests, "post").start()
     self.mock_sleep = mock.patch.object(time, "sleep").start()
-    os.environ["API_KEY"] = "fake_api_key"
+    os.environ["OPENAI_API_KEY"] = "fake_api_key"
+    os.environ["GCP_API_KEY"] = "fake_api_key"
 
   @mock.patch.object(genai.GenerativeModel, "generate_content")
   def test_gemini_gcp(self, mock_generate_content):
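
Since setUp is shared by the Gemini and OpenAI test cases, both provider
variables get fake values. An alternative sketch using mock.patch.dict, which
also restores the real environment after each test (an assumed variation, not
what the repo does):

import os
import unittest
from unittest import mock


class EnvKeysTest(unittest.TestCase):
  """Sketch: scope fake provider keys to each test via patch.dict."""

  def setUp(self):
    super().setUp()
    env_patcher = mock.patch.dict(
        os.environ,
        {"OPENAI_API_KEY": "fake_api_key", "GCP_API_KEY": "fake_api_key"},
    )
    env_patcher.start()
    self.addCleanup(env_patcher.stop)  # undoes the patch after each test

  def test_fake_keys_visible(self):
    self.assertEqual(os.environ["GCP_API_KEY"], "fake_api_key")


if __name__ == "__main__":
  unittest.main()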