diff --git a/discord_tron_master/cogs/image/__init__.py b/discord_tron_master/cogs/image/__init__.py index 6740baa..8a78e1a 100644 --- a/discord_tron_master/cogs/image/__init__.py +++ b/discord_tron_master/cogs/image/__init__.py @@ -17,7 +17,25 @@ def retrieve_vlm_caption(image_url) -> str: client.close() return str(result) -def generate_cascade_via_hub(prompt: str, user_id: int = None): +async def generate_pixart_via_hub(prompt:str, user_id: int = None): + from gradio_client import Client + + client = Client("ptx0/PixArt-900M") + result = client.predict( + prompt="Hello!!", + guidance_scale=3.4, + num_inference_steps=28, + resolution="1152x960", + negative_prompt="underexposed, blurry, ugly, washed-out", + api_name="/predict" + ) + # close the connection + client.close() + split_pieces = result.split('/') + return f"https://ptx0-pixart-900m.hf.space/file=/tmp/gradio/{split_pieces[-2]}/image.webp" + + +async def generate_cascade_via_hub(prompt: str, user_id: int = None): from gradio_client import Client client = Client("multimodalart/stable-cascade") @@ -40,7 +58,7 @@ def generate_cascade_via_hub(prompt: str, user_id: int = None): return f"https://multimodalart-stable-cascade.hf.space/file=/tmp/gradio/{split_pieces[-2]}/image.png" -def generate_sd3_via_hub(prompt: str, model: str = None, user_id: int = None): +async def generate_sd3_via_hub(prompt: str, model: str = None, user_id: int = None): from gradio_client import Client user_config = AppConfig().get_user_config(user_id=user_id) @@ -63,7 +81,7 @@ def generate_sd3_via_hub(prompt: str, model: str = None, user_id: int = None): return f"https://ameerazam08-sd-3-medium-gpu.hf.space/file=/tmp/gradio/{split_pieces[-2]}/image.webp" -def generate_terminus_via_hub(prompt: str, model: str = "velocity", user_id: int = None): +async def generate_terminus_via_hub(prompt: str, model: str = "velocity", user_id: int = None): from gradio_client import Client available_models = { "velocity": "ptx0/ptx0-terminus-xl-velocity-v2", @@ -150,16 +168,23 @@ async def generate_image(ctx, prompt, user_id: int = None, extra_image: dict = N # Retrieve https://pollinations.ai/prompt/{prompt}?seed={seed}&width={user_config['resolution']['width']}&height={user_config['resolution']['height']} # pollinations_image = Image.open(BytesIO(requests.get(f"https://pollinations.ai/prompt/{prompt}?seed={user_config['seed']}&width={user_config['resolution']['width']}&height={user_config['resolution']['height']}").content)) try: - pollinations_image = Image.open(BytesIO(requests.get(generate_sd3_via_hub(prompt, user_id=user_id)).content)) + pollinations_image = Image.open(BytesIO(requests.get(await generate_sd3_via_hub(prompt, user_id=user_id)).content)) except: pollinations_image = Image.new('RGB', dalle_image.size, (0, 0, 0)) try: extra_image = { "label": "Terminus XL Velocity V2 (WIP)", - "data": Image.open(BytesIO(requests.get(generate_terminus_via_hub(prompt, user_id=user_id)).content)) + "data": Image.open(BytesIO(requests.get(await generate_terminus_via_hub(prompt, user_id=user_id)).content)) } except: extra_image = None + try: + extra_image_2 = { + "label": "PixArt 900M (WIP)", + "data": Image.open(BytesIO(requests.get(await generate_pixart_via_hub(prompt, user_id=user_id)).content)) + } + except: + extra_image_2 = None draw = ImageDraw.Draw(pollinations_image) draw.text((10, 10), "SD3 2B", (255, 255, 255), font=font, stroke_fill=(0,0,0), stroke_width=4) @@ -171,6 +196,11 @@ async def generate_image(ctx, prompt, user_id: int = None, extra_image: dict = N extra_image_vertical_position = int((height - extra_image["data"].size[1]) / 2) extra_image_position = (new_width, extra_image_vertical_position) new_width = new_width + extra_image["data"].size[0] + if extra_image_2 is not None: + extra_image_vertical_position = int((height - extra_image_2["data"].size[1]) / 2) + extra_image_position = (new_width, extra_image_vertical_position) + new_width = new_width + extra_image_2["data"].size[0] + new_image = Image.new('RGB', (new_width, height)) new_image.paste(pollinations_image, (0, 0)) new_image.paste(dalle_image, (width, 0)) @@ -180,6 +210,11 @@ async def generate_image(ctx, prompt, user_id: int = None, extra_image: dict = N draw.text((10, 10), extra_image["label"], (255, 255, 255), font=font, stroke_fill=(0,0,0), stroke_width=4) width, height = extra_image["data"].size new_image.paste(extra_image["data"], extra_image_position) + if extra_image_2 is not None: + draw = ImageDraw.Draw(extra_image_2["data"]) + draw.text((10, 10), extra_image_2["label"], (255, 255, 255), font=font, stroke_fill=(0,0,0), stroke_width=4) + width, height = extra_image_2["data"].size + new_image.paste(extra_image_2["data"], extra_image_position) # Save the new image to a BytesIO object. output = BytesIO() new_image.save(output, format="PNG")