-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
First pass at adding DALLE image generator
- Loading branch information
Showing
4 changed files
with
195 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import os | ||
from typing import Any, Dict, List, Optional | ||
|
||
from openai import OpenAI | ||
from openai.types.image import Image | ||
|
||
from haystack import component, default_from_dict, default_to_dict, logging | ||
from haystack.utils import Secret, deserialize_secrets_inplace | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@component
class DALLEImageGenerator:
    """
    A component to generate images using OpenAI's DALL-E model.

    The OpenAI client is created lazily: call `warm_up()` once before `run()`.
    """

    def __init__(
        self,
        model: str = "dall-e-3",
        quality: str = "standard",
        size: str = "1024x1024",
        response_format: str = "url",
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
    ):
        """
        Initialize the DALLEImageGenerator component.

        :param model: The model to use for image generation. Can be "dall-e-2" or "dall-e-3".
        :param quality: The quality of the generated image. Can be "standard" or "hd".
        :param size: The size of the generated images.
            Must be one of 256x256, 512x512, or 1024x1024 for dall-e-2.
            Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models.
        :param response_format: The format of the response. Can be "url" or "b64_json".
        :param api_key: The OpenAI API key to connect to OpenAI.
        :param api_base_url: An optional base URL.
        :param organization: The Organization ID, defaults to `None`.
        :param timeout:
            Timeout for OpenAI Client calls, if not set it is inferred from the `OPENAI_TIMEOUT` environment variable
            or set to 30.
        :param max_retries:
            Maximum retries to establish contact with OpenAI if it returns an internal error, if not set it is inferred
            from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
        """
        self.model = model
        self.quality = quality
        self.size = size
        self.response_format = response_format
        self.api_key = api_key
        self.api_base_url = api_base_url
        self.organization = organization

        # Compare against None rather than relying on truthiness: an explicit
        # `max_retries=0` (retries disabled) or `timeout=0.0` must be respected
        # instead of being silently replaced by the environment/default value.
        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
        self.timeout = timeout
        self.max_retries = max_retries

        # Created in warm_up(), so constructing the component never resolves the API key.
        self.client: Optional[OpenAI] = None

    def warm_up(self) -> None:
        """
        Warm up the OpenAI client.

        Does nothing if the client was already created by a previous call.
        """
        if self.client is None:
            self.client = OpenAI(
                api_key=self.api_key.resolve_value(),
                organization=self.organization,
                base_url=self.api_base_url,
                timeout=self.timeout,
                max_retries=self.max_retries,
            )

    @component.output_types(images=List[str], revised_prompt=str)
    def run(
        self,
        prompt: str,
        size: Optional[str] = None,
        quality: Optional[str] = None,
        response_format: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Invoke the image generation inference based on the provided prompt and generation parameters.

        :param prompt: The prompt to generate the image.
        :param size: If provided, overrides the size provided during initialization.
        :param quality: If provided, overrides the quality provided during initialization.
        :param response_format: If provided, overrides the response format provided during initialization.
        :returns:
            A dictionary containing the generated list of images and the revised prompt.
            Depending on the `response_format` parameter, the list of images can be URLs or base64 encoded JSON strings.
            The revised prompt is the prompt that was used to generate the image, if there was any revision
            to the prompt made by OpenAI.
        :raises RuntimeError: If `warm_up()` was not called before `run()`.
        """
        if self.client is None:
            raise RuntimeError(
                "The component DALLEImageGenerator wasn't warmed up. Run 'warm_up()' before calling 'run()'."
            )

        # Per-call arguments take precedence over the values given at init time.
        size = size or self.size
        quality = quality or self.quality
        response_format = response_format or self.response_format
        response = self.client.images.generate(
            model=self.model, prompt=prompt, size=size, quality=quality, response_format=response_format, n=1
        )
        # Exactly one image is requested (n=1), so only the first entry is read.
        image: Image = response.data[0]
        # Prefer the URL when present, otherwise use the base64 payload;
        # an empty string is returned if the response carries neither.
        image_str = ""
        if image.url is not None:
            image_str = image.url
        elif image.b64_json is not None:
            image_str = image.b64_json
        return {"images": [image_str], "revised_prompt": image.revised_prompt or ""}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        Note that `timeout` and `max_retries` are not part of the serialized
        parameters; a deserialized instance derives them again from the
        environment variables or the built-in defaults.

        :returns:
            The serialized component as a dictionary.
        """
        return default_to_dict(  # type: ignore
            self,
            model=self.model,
            quality=self.quality,
            size=self.size,
            response_format=self.response_format,
            api_key=self.api_key.to_dict(),
            api_base_url=self.api_base_url,
            organization=self.organization,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DALLEImageGenerator":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        # Convert the serialized `api_key` entry back into a Secret, in place,
        # before the init parameters are handed to the constructor.
        deserialize_secrets_inplace(init_params, keys=["api_key"])
        return default_from_dict(cls, data)  # type: ignore
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from haystack.components.generators.openai_dalle import DALLEImageGenerator | ||
|
||
|
||
class TestDALLEImageGenerator:
    """Serialization round-trip tests for DALLEImageGenerator."""

    def test_to_dict(self) -> None:
        """to_dict() emits the component's qualified name and its default init parameters."""
        generator = DALLEImageGenerator()
        data = generator.to_dict()
        assert data == {
            # The type must be the module path the class is imported from,
            # otherwise the comparison (and from_dict's type check) fails.
            "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
            "init_parameters": {
                "model": "dall-e-3",
                "quality": "standard",
                "size": "1024x1024",
                "response_format": "url",
                "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                "api_base_url": None,
                "organization": None,
            },
        }

    def test_from_dict(self) -> None:
        """from_dict() restores every init parameter, including the env-var backed API key."""
        data = {
            "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
            "init_parameters": {
                "model": "dall-e-3",
                "quality": "standard",
                "size": "1024x1024",
                "response_format": "url",
                "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                "api_base_url": None,
                "organization": None,
            },
        }
        generator = DALLEImageGenerator.from_dict(data)
        assert generator.model == "dall-e-3"
        assert generator.quality == "standard"
        assert generator.size == "1024x1024"
        assert generator.response_format == "url"
        assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}
        assert generator.api_base_url is None
        assert generator.organization is None