Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Jul 5, 2024
1 parent c2a6679 commit bf0769b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
11 changes: 7 additions & 4 deletions discord_tron_master/cogs/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@ def retrieve_vlm_caption(image_url) -> str:

async def generate_pixart_via_hub(prompt:str, user_id: int = None):
from gradio_client import Client
user_config = AppConfig().get_user_config(user_id=user_id)
resolution = user_config.get("resolution", {"width": 1024, "height": 1024})
res_str = f"{resolution['width']}x{resolution['height']}"

client = Client("ptx0/PixArt-900M")
result = client.predict(
prompt="Hello!!",
guidance_scale=3.4,
prompt=prompt,
guidance_scale=user_config.get("guidance_scale", 4.4),
num_inference_steps=28,
resolution="1152x960",
negative_prompt="underexposed, blurry, ugly, washed-out",
resolution=res_str,
negative_prompt=user_config.get("negative_prompt", "blurry, ugly, cropped"),
api_name="/predict"
)
# close the connection
Expand Down
12 changes: 11 additions & 1 deletion discord_tron_master/cogs/image/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from discord_tron_master.classes.jobs.image_generation_job import ImageGenerationJob
from discord_tron_master.bot import clean_traceback
# For queue manager, etc.
from threading import Thread
discord = DiscordBot.get_instance()

from discord_tron_master.classes.guilds import Guilds
Expand Down Expand Up @@ -86,7 +87,16 @@ async def generate_range(self, ctx, count, *, prompt):
async def generate_sd3_dalle_comparison(self, ctx, *, prompt):
if guild_config.is_channel_banned(ctx.guild.id, ctx.channel.id):
return
await generate_image(ctx, prompt)
# await generate_image(ctx, prompt)
# instead, run generate_image in a thread and don't block the main thread.
try:
thread = Thread(target=generate_image, args=(ctx, prompt))
thread.start()
except Exception as e:
await ctx.send(
f"Error generating image: {e}\n\nStack trace:\n{await clean_traceback(traceback.format_exc())}"
)


@commands.command(name="dalle", help="Generates an image based on the given prompt using DALL-E.")
async def generate_dalle(self, ctx, *, prompt):
Expand Down

0 comments on commit bf0769b

Please sign in to comment.