Skip to content

Commit

Permalink
First pass at adding DALLE image generator
Browse files Browse the repository at this point in the history
  • Loading branch information
sjrl committed Oct 9, 2024
1 parent e7bfd80 commit ebfaf7b
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/pydoc/config/generators_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ loaders:
"hugging_face_local",
"hugging_face_api",
"openai",
"openai_dalle",
"chat/azure",
"chat/hugging_face_local",
"chat/hugging_face_api",
Expand Down
9 changes: 8 additions & 1 deletion haystack/components/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,12 @@
from haystack.components.generators.azure import AzureOpenAIGenerator
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
from haystack.components.generators.hugging_face_api import HuggingFaceAPIGenerator
from haystack.components.generators.openai_dalle import DALLEImageGenerator

__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceAPIGenerator", "OpenAIGenerator", "AzureOpenAIGenerator"]
__all__ = [
"HuggingFaceLocalGenerator",
"HuggingFaceAPIGenerator",
"OpenAIGenerator",
"AzureOpenAIGenerator",
"DALLEImageGenerator",
]
147 changes: 147 additions & 0 deletions haystack/components/generators/openai_dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import os
from typing import Any, Dict, List, Optional

from openai import OpenAI
from openai.types.image import Image

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import Secret, deserialize_secrets_inplace

logger = logging.getLogger(__name__)


@component
class DALLEImageGenerator:
"""
A component to generate images using OpenAI's DALL-E model.
"""

def __init__(
self,
model: str = "dall-e-3",
quality: str = "standard",
size: str = "1024x1024",
response_format: str = "url",
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Initialize the DALLEImageGenerator component.
:param model: The model to use for image generation. Can be "dall-e-2" or "dall-e-3".
:param quality: The quality of the generated image. Can be "standard" or "hd".
:param size: The size of the generated images.
Must be one of 256x256, 512x512, or 1024x1024 for dall-e-2.
Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models.
:param response_format: The format of the response. Can be "url" or "b64_json".
:param api_key: The OpenAI API key to connect to OpenAI.
:param api_base_url: An optional base URL.
:param organization: The Organization ID, defaults to `None`.
:param timeout:
Timeout for OpenAI Client calls, if not set it is inferred from the `OPENAI_TIMEOUT` environment variable
or set to 30.
:param max_retries:
Maximum retries to establish contact with OpenAI if it returns an internal error, if not set it is inferred
from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
"""
self.model = model
self.quality = quality
self.size = size
self.response_format = response_format
self.api_key = api_key
self.api_base_url = api_base_url
self.organization = organization

self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))

self.client = None

def warm_up(self) -> None:
"""
Warm up the OpenAI client.
"""
if self.client is None:
self.client = OpenAI(
api_key=self.api_key.resolve_value(),
organization=self.organization,
base_url=self.api_base_url,
timeout=self.timeout,
max_retries=self.max_retries,
)

@component.output_types(images=List[str], revised_prompt=str)
def run(
self,
prompt: str,
size: Optional[str] = None,
quality: Optional[str] = None,
response_format: Optional[str] = None,
) -> Dict[str, str]:
"""
Invoke the image generation inference based on the provided prompt and generation parameters.
:param prompt: The prompt to generate the image.
:param size: If provided, overrides the size provided during initialization.
:param quality: If provided, overrides the quality provided during initialization.
:param response_format: If provided, overrides the response format provided during initialization.
:returns:
A dictionary containing the generated list of images and the revised prompt.
Depending on the `response_format` parameter, the list of images can be URLs or base64 encoded JSON strings.
The revised prompt is the prompt that was used to generate the image, if there was any revision
to the prompt made by OpenAI.
"""
if self.client is None:
raise RuntimeError(
"The component DALLEImageGenerator wasn't warmed up. Run 'warm_up()' before calling 'run()'."
)

size = size or self.size
quality = quality or self.quality
response_format = response_format or self.response_format
response = self.client.images.generate(
model=self.model, prompt=prompt, size=size, quality=quality, response_format=response_format, n=1
)
image: Image = response.data[0]
image_str = ""
if image.url is not None:
image_str = image.url
elif image.b64_json is not None:
image_str = image.b64_json
return {"images": [image_str], "revised_prompt": image.revised_prompt or ""}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:returns:
The serialized component as a dictionary.
"""
return default_to_dict( # type: ignore
self,
model=self.model,
quality=self.quality,
size=self.size,
response_format=self.response_format,
api_key=self.api_key.to_dict(),
api_base_url=self.api_base_url,
organization=self.organization,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DALLEImageGenerator":
"""
Deserialize this component from a dictionary.
:param data:
The dictionary representation of this component.
:returns:
The deserialized component instance.
"""
init_params = data.get("init_parameters", {})
deserialize_secrets_inplace(init_params, keys=["api_key"])
return default_from_dict(cls, data) # type: ignore
39 changes: 39 additions & 0 deletions test/components/generators/test_openai_dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from haystack.components.generators.openai_dalle import DALLEImageGenerator


class TestDALLEImageGenerator:
def test_to_dict(self) -> None:
generator = DALLEImageGenerator()
data = generator.to_dict()
assert data == {
"type": "dc_custom_component.components.generators.image_generator.DALLEImageGenerator",
"init_parameters": {
"model": "dall-e-3",
"quality": "standard",
"size": "1024x1024",
"response_format": "url",
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
"api_base_url": None,
"organization": None,
},
}

def test_from_dict(self) -> None:
data = {
"type": "dc_custom_component.components.generators.image_generator.DALLEImageGenerator",
"init_parameters": {
"model": "dall-e-3",
"quality": "standard",
"size": "1024x1024",
"response_format": "url",
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
"api_base_url": None,
"organization": None,
},
}
generator = DALLEImageGenerator.from_dict(data)
assert generator.model == "dall-e-3"
assert generator.quality == "standard"
assert generator.size == "1024x1024"
assert generator.response_format == "url"
assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}

0 comments on commit ebfaf7b

Please sign in to comment.