diff --git a/predict.py b/predict.py index d354611..944d90f 100644 --- a/predict.py +++ b/predict.py @@ -7,18 +7,24 @@ import time from cog import BasePredictor, Input, Path, BaseModel import torch -from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, EulerDiscreteScheduler, UNet2DConditionModel, StableDiffusionXLPipeline +from diffusers import ( + AutoPipelineForText2Image, + DPMSolverMultistepScheduler, + EulerDiscreteScheduler, + UNet2DConditionModel, + StableDiffusionXLPipeline, +) from huggingface_hub import hf_hub_download from safetensors.torch import load_file os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" -MODEL_URL = "https://weights.replicate.delivery/default/res-adapter/Lykon/dreamshaper-xl-1-0.tar" -MODEL_WEIGHTS = "pretrained/Lykon/dreamshaper-xl-1-0" +SDXL_MODEL_URL = "https://weights.replicate.delivery/default/res-adapter/Lykon/dreamshaper-xl-1-0.tar" +SDXL_MODEL_WEIGHTS = "pretrained/Lykon/dreamshaper-xl-1-0" +SD15_MODEL_URL = "https://weights.replicate.delivery/default/res-adapter/dreamlike-art/dreamlike-diffusion-1.0.tar" +SD15_MODEL_WEIGHTS = "pretrained/dreamlike-art/dreamlike-diffusion-1.0" -# For SDXL, SDXL-Lightning, dreamshaper-xl-1-0, -# For SDv1.5, dreamlike-diffusion-1.0 class ModelOutput(BaseModel): without_res_adapter: Optional[Path] @@ -39,30 +45,60 @@ def download_weights(url, dest, extract=True): class Predictor(BasePredictor): def setup(self) -> None: """Load the model into memory to make running multiple predictions efficient""" - if not os.path.exists(MODEL_WEIGHTS): - download_weights(MODEL_URL, MODEL_WEIGHTS) - self.default_pipe = AutoPipelineForText2Image.from_pretrained( - MODEL_WEIGHTS, torch_dtype=torch.float16, variant="fp16" + if not os.path.exists(SDXL_MODEL_WEIGHTS): + download_weights(SDXL_MODEL_URL, SDXL_MODEL_WEIGHTS) + if not os.path.exists(SD15_MODEL_WEIGHTS): + download_weights(SD15_MODEL_URL, SD15_MODEL_WEIGHTS) + + # load "Lykon/dreamshaper-xl-1-0" + self.sdxl_pipe = AutoPipelineForText2Image.from_pretrained( + SDXL_MODEL_WEIGHTS, torch_dtype=torch.float16, variant="fp16" ) - self.default_pipe.scheduler = DPMSolverMultistepScheduler.from_config( - self.default_pipe.scheduler.config, + self.sdxl_pipe.scheduler = DPMSolverMultistepScheduler.from_config( + self.sdxl_pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++", ) - self.default_pipe = self.default_pipe.to("cuda") + self.sdxl_pipe = self.sdxl_pipe.to("cuda") + + # load "ByteDance/SDXL-Lightning" + self.sdxl_lightning_pipe = AutoPipelineForText2Image.from_pretrained( + SDXL_MODEL_WEIGHTS, torch_dtype=torch.float16, variant="fp16" + ) + repo = "ByteDance/SDXL-Lightning" + ckpt = "sdxl_lightning_4step_unet.safetensors" + # Load SDXL-Lightning to UNet + unet = self.sdxl_lightning_pipe.unet + unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda")) + # Change UNet to pipeline + self.sdxl_lightning_pipe.unet = unet + self.sdxl_lightning_pipe.scheduler = EulerDiscreteScheduler.from_config( + self.sdxl_lightning_pipe.scheduler.config, timestep_spacing="trailing" + ) + self.sdxl_lightning_pipe = self.sdxl_lightning_pipe.to("cuda") + + # load "dreamlike-art/dreamlike-diffusion-1.0" + self.sd15_pipe = AutoPipelineForText2Image.from_pretrained( + SD15_MODEL_WEIGHTS + ) # fp16 not available for "dreamlike-art/dreamlike-diffusion-1.0" + self.sd15_pipe.scheduler = DPMSolverMultistepScheduler.from_config( + self.sd15_pipe.scheduler.config, + use_karras_sigmas=True, + algorithm_type="sde-dpmsolver++", + ) + self.sd15_pipe = self.sd15_pipe.to("cuda") @torch.inference_mode() def predict( self, - base_model: str = Input( - description="Choose a stable diffusion architecture, supporint sd1.5 and sdxl.", - default="sdxl", - choices=["sd1.5", "sdxl"], - ), model_name: str = Input( - description="Name of a stable diffusion model, should have either sd1.5 or sdxl architecture.", + description="Choose a stable diffusion model.", default="ByteDance/SDXL-Lightning", - choice=["Lykon/dreamshaper-xl-1-0", "ByteDance/SDXL-Lightning", "dreamlike-art/dreamlike-diffusion-1.0"] + choices=[ + "Lykon/dreamshaper-xl-1-0", + "ByteDance/SDXL-Lightning", + "dreamlike-art/dreamlike-diffusion-1.0", + ], ), prompt: str = Input( description="Input prompt", @@ -72,14 +108,8 @@ def predict( description="Specify things to not see in the output", default="ugly, deformed, noisy, blurry, nsfw, low contrast, text, BadDream, 3d, cgi, render, fake, anime, open mouth, big forehead, long neck", ), - width: int = Input( - description="Width of output image", - default=512, - ), - height: int = Input( - description="Height of output image", - default=512, - ), + width: int = Input(description="Width of output image", default=512), + height: int = Input(description="Height of output image", default=512), num_inference_steps: int = Input( description="Number of denoising steps", default=4 ), @@ -101,44 +131,25 @@ def predict( generator = torch.Generator("cuda").manual_seed(seed) + base_model = ( + "sd1.5" if model_name == "dreamlike-art/dreamlike-diffusion-1.0" else "sdxl" + ) + if model_name == "Lykon/dreamshaper-xl-1-0": - self.pipe = self.default_pipe + pipe = self.sdxl_pipe elif model_name == "ByteDance/SDXL-Lightning": - self.pipe = self.default_pipe - repo = "ByteDance/SDXL-Lightning" - ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting! - - # Load SDXL-Lightning to UNet - unet = self.default_pipe.unet - unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda")) - - # Change UNet to pipeline - self.pipe.unet = unet - self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config, timestep_spacing="trailing") + pipe = self.sdxl_lightning_pipe else: - try: - self.pipe = AutoPipelineForText2Image.from_pretrained( - model_name, torch_dtype=torch.float16, variant="fp16" - ) - except: - print("fp16 not available.") - self.pipe = AutoPipelineForText2Image.from_pretrained(model_name) - - self.pipe.scheduler = DPMSolverMultistepScheduler.from_config( - self.pipe.scheduler.config, - use_karras_sigmas=True, - algorithm_type="sde-dpmsolver++", - ) - self.pipe = self.pipe.to("cuda") + pipe = self.sd15_pipe if show_baseline: - if len(self.pipe.get_active_adapters()) > 0: + if len(pipe.get_active_adapters()) > 0: print("Unloading LoRA weights...") - self.pipe.unload_lora_weights() + pipe.unload_lora_weights() print("Generating images without res_adapter...") - baseline_image = self.pipe( + baseline_image = pipe( prompt, negative_prompt=negative_prompt, width=width, @@ -150,40 +161,43 @@ def predict( baseline_path = "/tmp/baseline.png" baseline_image.save(baseline_path) - if len(self.pipe.get_active_adapters()) == 0: + if len(pipe.get_active_adapters()) == 0: if base_model == "sd1.5": print("Loading Resolution LoRA weights...") - self.pipe.load_lora_weights( + pipe.load_lora_weights( hf_hub_download( repo_id="jiaxiangc/res-adapter", - subfolder=f"sd1.5", + subfolder="sd1.5", filename="resolution_lora.safetensors", ), adapter_name="res_adapter", ) - self.pipe.set_adapters(["res_adapter"], adapter_weights=[1.0]) + pipe.set_adapters(["res_adapter"], adapter_weights=[1.0]) print("Load Resolution Norm weights") - self.pipe.unet.load_state_dict(load_file( - hf_hub_download( - repo_id="jiaxiangc/res-adapter", - subfolder="sd1.5", - filename="resolution_normalization.safetensors" + pipe.unet.load_state_dict( + load_file( + hf_hub_download( + repo_id="jiaxiangc/res-adapter", + subfolder="sd1.5", + filename="resolution_normalization.safetensors", + ) ), - ), strict=False) + strict=False, + ) elif base_model == "sdxl": print("Loading Resolution LoRA weights...") - self.pipe.load_lora_weights( + pipe.load_lora_weights( hf_hub_download( repo_id="jiaxiangc/res-adapter", - subfolder=f"sdxl-i", + subfolder="sdxl-i", filename="resolution_lora.safetensors", ), adapter_name="res_adapter", ) - self.pipe.set_adapters(["res_adapter"], adapter_weights=[1.0]) + pipe.set_adapters(["res_adapter"], adapter_weights=[1.0]) print("Generating images with res_adapter...") - image = self.pipe( + image = pipe( prompt, negative_prompt=negative_prompt, width=width,