Skip to content

Commit

Permalink
Adds streamlit
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Dec 14, 2023
1 parent e723d6d commit 885d27b
Show file tree
Hide file tree
Showing 8 changed files with 1,641 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import datetime
import logging
from typing import Any, Dict, List, Optional
from typing import IO, Any, Dict, List, Optional, Union

import boto3

from hamilton.function_modifiers import config

Expand All @@ -21,10 +23,16 @@ def openai_client() -> openai.OpenAI:
return openai.OpenAI()


def _encode_image(image_path):
def _encode_image(image_path_or_file: Union[str, IO], ext: str):
"""Helper fn to return a base-64 encoded image"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
file_like_object = (
image_path_or_file
if hasattr(image_path_or_file, "read")
else open(image_path_or_file, "rb")
)
with file_like_object as image_file:
out = base64.b64encode(image_file.read()).decode("utf-8")
return f"data:image/{ext};base64,{out}"


def core_prompt() -> str:
Expand All @@ -36,19 +44,26 @@ def processed_image_url(image_url: str) -> str:
"""Returns a processed image URL -- base-64 encoded if it is local,
otherwise remote if it is a URL"""
is_local = urllib.parse.urlparse(image_url).scheme == ""
is_s3 = urllib.parse.urlparse(image_url).scheme == "s3"
ext = image_url.split(".")[-1]
if is_local:
# In this case we load up/encode
encoded_image = _encode_image(image_url)
extension = image_url.split(".")[-1]
return f"data:image/{extension};base64,{encoded_image}"
return _encode_image(image_url, ext)
elif is_s3:
# In this case we just return the URL
client = boto3.client("s3")
bucket = urllib.parse.urlparse(image_url).netloc
key = urllib.parse.urlparse(image_url).path[1:]
obj = client.get_object(Bucket=bucket, Key=key)
return _encode_image(obj["Body"], ext)
# In this case we just return the URL
return image_url


def prompt(
core_prompt: str,
additional_prompt: Optional[str] = None,
descriptiveness: Optional[str] = None,
def caption_prompt(
core_prompt: str,
additional_prompt: Optional[str] = None,
descriptiveness: Optional[str] = None,
) -> str:
"""Returns the prompt used to describe an image"""
out = core_prompt
Expand All @@ -64,11 +79,11 @@ def prompt(


def generated_caption(
processed_image_url: str,
prompt: str,
openai_client: openai.OpenAI,
model: str = DEFAULT_MODEL,
max_tokens: int = 2000,
processed_image_url: str,
caption_prompt: str,
openai_client: openai.OpenAI,
model: str = DEFAULT_MODEL,
max_tokens: int = 2000,
) -> str:
"""Returns the response to a given chat"""
messages = [
Expand All @@ -77,7 +92,7 @@ def generated_caption(
"content": [
{
"type": "text",
"text": prompt,
"text": caption_prompt,
},
{"type": "image_url", "image_url": {"url": f"{processed_image_url}"}},
],
Expand All @@ -93,13 +108,13 @@ def generated_caption(

@config.when(include_embeddings=True)
def caption_embeddings(
client: openai.OpenAI,
embeddings_model: str = DEFAULT_EMBEDDINGS_MODEL,
generated_caption: str = None,
openai_client: openai.OpenAI,
embeddings_model: str = DEFAULT_EMBEDDINGS_MODEL,
generated_caption: str = None,
) -> List[float]:
"""Returns the embeddings for a generated caption"""
data = (
client.embeddings.create(
openai_client.embeddings.create(
input=[generated_caption],
model=embeddings_model,
)
Expand All @@ -110,24 +125,24 @@ def caption_embeddings(


def caption_metadata(
image_url: str,
generated_caption: str,
prompt: str,
model: str = DEFAULT_MODEL,
image_url: str,
generated_caption: str,
caption_prompt: str,
model: str = DEFAULT_MODEL,
) -> dict:
"""Returns metadata for the caption portion of the workflow"""
return {
"original_image_url": image_url,
"generated_caption": generated_caption,
"caption_model": model,
"caption_prompt": prompt,
"caption_prompt": caption_prompt,
}


@config.when(include_embeddings=True)
def embeddings_metadata(
caption_embeddings: List[float],
embeddings_model: str = DEFAULT_EMBEDDINGS_MODEL,
caption_embeddings: List[float],
embeddings_model: str = DEFAULT_EMBEDDINGS_MODEL,
) -> dict:
"""Returns metadata for the embeddings portion of the workflow"""
return {
Expand All @@ -137,9 +152,9 @@ def embeddings_metadata(


def metadata(
embeddings_metadata: dict,
caption_metadata: Optional[dict] = None,
additional_metadata: Optional[dict] = None,
embeddings_metadata: dict,
caption_metadata: Optional[dict] = None,
additional_metadata: Optional[dict] = None,
) -> Dict[str, Any]:
"""Returns the response to a given chat"""
out = embeddings_metadata
Expand All @@ -159,7 +174,7 @@ def metadata(
from hamilton import base, driver

dr = driver.Driver(
{"include_embeddings" : True}, # CONFIG: fill as appropriate
{"include_embeddings": True}, # CONFIG: fill as appropriate
image_captioning,
adapter=base.DefaultAdapter(),
)
Expand Down
3 changes: 3 additions & 0 deletions contrib/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
boto3
openai
sf-hamilton
tenacity
16 changes: 16 additions & 0 deletions examples/LLM_Workflows/image_telephone/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Image Telephone

See the [streamlit app](image-telephone.streamlit.app) for documentation.

This example uses dataflows from the hub to do something fun with image captioning and generation.
Note that hamilton code is used rather than defined here.


# Contents

There are two files in this:

1. generate_images.ipynb
2. streamlit_app.py

The first is a notebook that generates images and captions, and the second is a streamlit app that displays them.
100 changes: 100 additions & 0 deletions examples/LLM_Workflows/image_telephone/adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import dataclasses
import io
import json
import logging
from typing import Any, Collection, Dict, Type
from urllib import parse

import boto3
import requests
from PIL import Image

from hamilton.io.data_adapters import DataSaver
from hamilton.registry import register_adapter

client = boto3.client("s3")

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class JSONS3DataSaver(DataSaver):
bucket: str
key: str

def save_data(self, data: dict) -> Dict[str, Any]:
data = json.dumps(data).encode()
client.put_object(Body=data, Bucket=self.bucket, Key=self.key)

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [dict]

@classmethod
def name(cls) -> str:
return "json_s3"


def _load_image(uri: str, format: str) -> Image:
parsed = parse.urlparse(uri)
if parsed.scheme.strip() == "": # local file to upload
with open(uri, "rb") as f:
data = f.read()
elif parsed.scheme.strip() in ("https", "http"): # URL to copy over
response = requests.get(uri)
data = response.content
image = Image.open(io.BytesIO(data))
if format in ("jpeg", "jpg"): # TODO -- add more formats if they don't support it
if image.mode in ("RGBA", "P"):
image = image.convert("RGB")
return image


@dataclasses.dataclass
class ImageS3DataSaver(DataSaver):
bucket: str
key: str
format: str
# image_convert_params: Optional[Dict[str, Any]] = None

def save_data(self, data: str) -> Dict[str, Any]:
image = _load_image(data, self.format)
in_mem_file = io.BytesIO()
image.save(in_mem_file, format=self.format)
in_mem_file.seek(0)
client.put_object(Body=in_mem_file, Bucket=self.bucket, Key=self.key)
return {"key": self.key, "bucket": self.bucket}

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [str] # URL or local path

@classmethod
def name(cls) -> str:
return "image_s3"


@dataclasses.dataclass
class LocalImageSaver(DataSaver):
path: str
format: str
# image_convert_params: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)

def save_data(self, data: str) -> Dict[str, Any]:
image = _load_image(data, self.format)
image.save(self.path, format=self.format)
return {"path": self.path}

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [str] # URL or local path

@classmethod
def name(cls) -> str:
return "image"


adapters = [JSONS3DataSaver, ImageS3DataSaver, LocalImageSaver]

for adapter in adapters:
register_adapter(adapter)
Loading

0 comments on commit 885d27b

Please sign in to comment.