Skip to content

Commit

Permalink
compare add pixart
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Jul 5, 2024
1 parent 42bf351 commit c2a6679
Showing 1 changed file with 40 additions and 5 deletions.
45 changes: 40 additions & 5 deletions discord_tron_master/cogs/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,25 @@ def retrieve_vlm_caption(image_url) -> str:
client.close()
return str(result)

def generate_cascade_via_hub(prompt: str, user_id: int = None):
async def generate_pixart_via_hub(prompt:str, user_id: int = None):
from gradio_client import Client

client = Client("ptx0/PixArt-900M")
result = client.predict(
prompt="Hello!!",
guidance_scale=3.4,
num_inference_steps=28,
resolution="1152x960",
negative_prompt="underexposed, blurry, ugly, washed-out",
api_name="/predict"
)
# close the connection
client.close()
split_pieces = result.split('/')
return f"https://ptx0-pixart-900m.hf.space/file=/tmp/gradio/{split_pieces[-2]}/image.webp"


async def generate_cascade_via_hub(prompt: str, user_id: int = None):
from gradio_client import Client

client = Client("multimodalart/stable-cascade")
Expand All @@ -40,7 +58,7 @@ def generate_cascade_via_hub(prompt: str, user_id: int = None):
return f"https://multimodalart-stable-cascade.hf.space/file=/tmp/gradio/{split_pieces[-2]}/image.png"


def generate_sd3_via_hub(prompt: str, model: str = None, user_id: int = None):
async def generate_sd3_via_hub(prompt: str, model: str = None, user_id: int = None):
from gradio_client import Client
user_config = AppConfig().get_user_config(user_id=user_id)

Expand All @@ -63,7 +81,7 @@ def generate_sd3_via_hub(prompt: str, model: str = None, user_id: int = None):
return f"https://ameerazam08-sd-3-medium-gpu.hf.space/file=/tmp/gradio/{split_pieces[-2]}/image.webp"


def generate_terminus_via_hub(prompt: str, model: str = "velocity", user_id: int = None):
async def generate_terminus_via_hub(prompt: str, model: str = "velocity", user_id: int = None):
from gradio_client import Client
available_models = {
"velocity": "ptx0/ptx0-terminus-xl-velocity-v2",
Expand Down Expand Up @@ -150,16 +168,23 @@ async def generate_image(ctx, prompt, user_id: int = None, extra_image: dict = N
# Retrieve https://pollinations.ai/prompt/{prompt}?seed={seed}&width={user_config['resolution']['width']}&height={user_config['resolution']['height']}
# pollinations_image = Image.open(BytesIO(requests.get(f"https://pollinations.ai/prompt/{prompt}?seed={user_config['seed']}&width={user_config['resolution']['width']}&height={user_config['resolution']['height']}").content))
try:
pollinations_image = Image.open(BytesIO(requests.get(generate_sd3_via_hub(prompt, user_id=user_id)).content))
pollinations_image = Image.open(BytesIO(requests.get(await generate_sd3_via_hub(prompt, user_id=user_id)).content))
except:
pollinations_image = Image.new('RGB', dalle_image.size, (0, 0, 0))
try:
extra_image = {
"label": "Terminus XL Velocity V2 (WIP)",
"data": Image.open(BytesIO(requests.get(generate_terminus_via_hub(prompt, user_id=user_id)).content))
"data": Image.open(BytesIO(requests.get(await generate_terminus_via_hub(prompt, user_id=user_id)).content))
}
except:
extra_image = None
try:
extra_image_2 = {
"label": "PixArt 900M (WIP)",
"data": Image.open(BytesIO(requests.get(await generate_pixart_via_hub(prompt, user_id=user_id)).content))
}
except:
extra_image_2 = None
draw = ImageDraw.Draw(pollinations_image)
draw.text((10, 10), "SD3 2B", (255, 255, 255), font=font, stroke_fill=(0,0,0), stroke_width=4)

Expand All @@ -171,6 +196,11 @@ async def generate_image(ctx, prompt, user_id: int = None, extra_image: dict = N
extra_image_vertical_position = int((height - extra_image["data"].size[1]) / 2)
extra_image_position = (new_width, extra_image_vertical_position)
new_width = new_width + extra_image["data"].size[0]
if extra_image_2 is not None:
extra_image_vertical_position = int((height - extra_image_2["data"].size[1]) / 2)
extra_image_position = (new_width, extra_image_vertical_position)
new_width = new_width + extra_image_2["data"].size[0]

new_image = Image.new('RGB', (new_width, height))
new_image.paste(pollinations_image, (0, 0))
new_image.paste(dalle_image, (width, 0))
Expand All @@ -180,6 +210,11 @@ async def generate_image(ctx, prompt, user_id: int = None, extra_image: dict = N
draw.text((10, 10), extra_image["label"], (255, 255, 255), font=font, stroke_fill=(0,0,0), stroke_width=4)
width, height = extra_image["data"].size
new_image.paste(extra_image["data"], extra_image_position)
if extra_image_2 is not None:
draw = ImageDraw.Draw(extra_image_2["data"])
draw.text((10, 10), extra_image_2["label"], (255, 255, 255), font=font, stroke_fill=(0,0,0), stroke_width=4)
width, height = extra_image_2["data"].size
new_image.paste(extra_image_2["data"], extra_image_position)
# Save the new image to a BytesIO object.
output = BytesIO()
new_image.save(output, format="PNG")
Expand Down

0 comments on commit c2a6679

Please sign in to comment.