diff --git a/main.py b/main.py
index d78de2a..0dd16f3 100644
--- a/main.py
+++ b/main.py
@@ -55,35 +55,38 @@ def main():
         print(f"Create {output_dir}")
 
     # #### 2.Load pipeline and scheduler ####
-    if config.task == "t2i":
+    task = config.get("task", None)
+    if task == "t2i":
         pipeline = load_text2image_pipeline(config)
-    if config.task == "t2i_accelerate":
+    elif task == "t2i_accelerate":
         pipeline = load_text2image_lcm_lora_pipeline(config)
-    if config.task == "controlnet":
+    elif task == "controlnet":
         pipeline = load_controlnet_pipeline(config)
-    if config.task == "ip_adapter":
+    elif task == "ip_adapter":
         pipeline = load_ip_adapter_pipeline(config)
+    else:
+        raise NotImplementedError
 
-    device = torch.device(f"cuda:{config.device}")
+    device = torch.device(f"cuda:{config.get('device', 0)}")
     pipeline = pipeline.to(device)
-    if config.enable_xformers:
+    if config.get("enable_xformers", None):
         print("Enable xformers successfully.")
         pipeline.enable_xformers_memory_efficient_attention()
 
     # #### 3.Get prompts and other condition ####
-    p_prompts = config.prompts
-    n_prompt = config.n_prompt
+    p_prompts = config.get("prompts", [])
+    n_prompt = config.get("n_prompt", "")
 
-    if config.task == "controlnet":
+    if task == "controlnet":
         condition_images = []
         source_images = []
-        for image_path in config.source_images:
+        for image_path in config.get("source_images", []):
             source_image = Image.open(image_path)
-            if config.scale_ratio:
+            if config.get("scale_ratio", None):
                 width, height = int(source_image.size[0]*config.scale_ratio), int(source_image.size[1]*config.scale_ratio)
             else:
-                width, height = config.width, config.height
+                width, height = config.get("width", 512), config.get("height", 512)
             source_image = source_image.resize((width, height))
             source_images.append(source_image)
             np_condition = np.array(source_image)
@@ -94,42 +97,42 @@ def main():
             condition_image.save(os.path.join(output_dir, f"condition_{Path(image_path).stem}.jpg"))
             condition_images.append(condition_image)
 
-    if config.task == "t2i_adapter":
+    if task == "t2i_adapter":
         condition_images = []
-        for condition_path in config.condition_images:
+        for condition_path in config.get("condition_images", []):
             condition_image = Image.open(condition_path)
-            if config.scale_ratio:
+            if config.get("scale_ratio", None):
                 width, height = int(condition_image.size[0]*config.scale_ratio), int(condition_image.size[1]*config.scale_ratio)
             else:
-                width, height = config.width, config.height
+                width, height = config.get("width", 512), config.get("height", 512)
             condition_image = condition_image.resize((width, height)).convert("L")
             condition_images.append(condition_image)
 
-    if config.task == "ip_adapter":
+    if task == "ip_adapter":
+        sub_task = config.get("sub_task", None)
         # Image Variation
-        if config.sub_task == "image_variation":
+        if sub_task == "image_variation":
             ip_adapter_images = []
-            for ip_image_path in config.ip_adapter_images:
+            for ip_image_path in config.get("ip_adapter_images", []):
                 ip_adpater_image = Image.open(ip_image_path)
-                if config.scale_ratio:
+                if config.get("scale_ratio", None):
                     width, height = int(ip_adpater_image.size[0]*config.scale_ratio), int(ip_adpater_image.size[1]*config.scale_ratio)
                 else:
-                    width, height = config.width, config.height
+                    width, height = config.get("width", 512), config.get("height", 512)
                 ip_adpater_image = ip_adpater_image.resize((width, height))
-                ip_adapter_images.append(ip_adpater_image)
 
         # Image to Image
-        if config.sub_task == "image_to_image":
+        elif sub_task == "image_to_image":
             ip_adapter_images = []
             source_images = []
-            for ip_image_path, image_path in zip(config.ip_adapter_images, config.source_images):
+            for ip_image_path, image_path in zip(config.get("ip_adapter_images", []), config.get("source_images", [])):
                 source_image = Image.open(image_path)
-                if config.scale_ratio:
+                if config.get("scale_ratio", None):
                     width, height = int(source_image.size[0]*config.scale_ratio), int(source_image.size[1]*config.scale_ratio)
                 else:
-                    width, height = config.width, config.height
+                    width, height = config.get("width", 512), config.get("height", 512)
                 source_image = source_image.resize((width, height))
                 source_images.append(source_image)
@@ -138,17 +141,17 @@ def main():
                 ip_adapter_images.append(ip_adapter_image)
 
         # Image Inpainting
-        if config.sub_task == "inpaint":
+        elif sub_task == "inpaint":
             source_images = []
             mask_images = []
             ip_adapter_images = []
-            for ip_image_path, image_path, mask_path in zip(config.ip_adapter_images, config.source_images, config.mask_images):
+            for ip_image_path, image_path, mask_path in zip(config.get("ip_adapter_images", []), config.get("source_images", []), config.get("mask_images", [])):
                 source_image = Image.open(image_path)
-                if config.scale_ratio:
+                if config.get("scale_ratio", None):
                     width, height = int(source_image.size[0]*config.scale_ratio), int(source_image.size[1]*config.scale_ratio)
                 else:
-                    width, height = config.width, config.height
+                    width, height = config.get("width", 512), config.get("height", 512)
                 source_image = source_image.resize((width, height))
                 source_images.append(source_image)
@@ -160,15 +163,17 @@ def main():
                 ip_adapter_image = ip_adapter_image.resize((width, height))
                 ip_adapter_images.append(ip_adapter_image)
+        else:
+            raise NotImplementedError
 
     # #### 4.Inference pipeline ####
-    if config.seed:
+    if config.get("seed", None):
         generator = torch.Generator(device=device).manual_seed(config.seed)
     else:
         generator = None
 
-    if config.res_adapter_model == "":
+    if config.get("res_adapter_model", "") == "":
         enable_compare = False
     else:
         enable_compare = config.enable_compare
@@ -177,76 +182,76 @@ def main():
     # Inference baseline
     original_images = []
     for i, prompt in tqdm(enumerate(p_prompts), total=len(p_prompts), desc="[Baselines]: "):
-        if config.task == "t2i" or config.task == "t2i_accelerate":
+        if task == "t2i" or task == "t2i_accelerate":
             kwargs = {}
-        if config.task == "controlnet":
-            if config.sub_task == "text_to_image":
+        if task == "controlnet":
+            if sub_task == "text_to_image":
                 kwargs = {"image": condition_images[i]}
-            if config.sub_task == "image_to_image":
+            if sub_task == "image_to_image":
                 kwargs = {"control_image": condition_images[i], "image": source_images[i]}
-        if config.task == "t2i_adapter":
+        if task == "t2i_adapter":
             kwargs = {"image": condition_images[i]}
-        if config.task == "ip_adapter":
-            if config.sub_task == "image_variation":
+        if task == "ip_adapter":
+            if sub_task == "image_variation":
                 kwargs = {"ip_adapter_image": ip_adapter_images[i]}
-            if config.sub_task == "image_to_image":
+            if sub_task == "image_to_image":
                 kwargs = {"image": source_images[i], "ip_adapter_image": ip_adapter_images[i], "strength": 0.6}
-            if config.sub_task == "inpaint":
+            if sub_task == "inpaint":
                 kwargs = {"image": source_images[i], "mask_image": mask_images[i], "ip_adapter_image": ip_adapter_images[i], "strength": 0.5}
 
         images = pipeline(
             prompt=prompt,
-            height=config.height,
-            width=config.width,
+            height=config.get("height", 512),
+            width=config.get("width", 512),
             negative_prompt=n_prompt,
-            num_inference_steps=config.num_inference_steps,
-            num_images_per_prompt=config.num_images_per_prompt,
+            num_inference_steps=config.get("num_inference_steps", 25),
+            num_images_per_prompt=config.get("num_images_per_prompt", 2),
             generator=generator,
             output_type="pt",
-            guidance_scale=config.guidance_scale,
+            guidance_scale=config.get("guidance_scale", 7.5),
             **kwargs,
         ).images
         original_images.append(images)
 
     # Load res-adapter
-    if config.res_adapter_model != "":
+    if config.get("res_adapter_model", "") != "":
         pipeline = load_resadapter(pipeline, config)
         print(f"Load res-adapter from {config.res_adapter_model}")
-        pipeline.set_adapters(["res_adapter"], adapter_weights=[config.res_adapter_alpha])
+        pipeline.set_adapters(["res_adapter"], adapter_weights=[config.get("res_adapter_alpha", 1.0)])
         if config.task == "t2i_accelerate":
-            pipeline.set_adapters(["res_adapter", "lcm_lora"], adapter_weights=[config.res_adapter_alpha, config.lcm_lora_alpha])
+            pipeline.set_adapters(["res_adapter", "lcm_lora"], adapter_weights=[config.get("res_adapter_alpha", 1.0), config.get("lcm_lora_alpha", 1.0)])
 
     # Inference with res-adapter
     resadapter_images = []
     for i, prompt in tqdm(enumerate(p_prompts), total=len(p_prompts), desc="[ResAdapter]: "):
-        if config.task == "t2i" or config.task == "t2i_accelerate":
+        if task == "t2i" or task == "t2i_accelerate":
             kwargs = {}
-        if config.task == "controlnet":
-            if config.sub_task == "text_to_image":
+        if task == "controlnet":
+            if sub_task == "text_to_image":
                 kwargs = {"image": condition_images[i]}
-            if config.sub_task == "image_to_image":
+            if sub_task == "image_to_image":
                 kwargs = {"control_image": condition_images[i], "image": source_images[i]}
-        if config.task == "t2i_adapter":
+        if task == "t2i_adapter":
             kwargs = {"image": condition_images[i]}
-        if config.task == "ip_adapter":
-            if config.sub_task == "image_variation":
+        if task == "ip_adapter":
+            if sub_task == "image_variation":
                 kwargs = {"ip_adapter_image": ip_adapter_images[i]}
-            if config.sub_task == "image_to_image":
+            if sub_task == "image_to_image":
                 kwargs = {"image": source_images[i], "ip_adapter_image": ip_adapter_images[i], "strength": 0.6}
-            if config.sub_task == "inpaint":
+            if sub_task == "inpaint":
                 kwargs = {"image": source_images[i], "mask_image": mask_images[i], "ip_adapter_image": ip_adapter_images[i], "strength": 0.5}
 
         images = pipeline(
             prompt=prompt,
-            height=config.height,
-            width=config.width,
+            height=config.get("height", 512),
+            width=config.get("width", 512),
             negative_prompt=n_prompt,
-            num_inference_steps=config.num_inference_steps,
-            num_images_per_prompt=config.num_images_per_prompt,
+            num_inference_steps=config.get("num_inference_steps", 25),
+            num_images_per_prompt=config.get("num_images_per_prompt", 2),
             generator=generator,
             output_type="pt",
-            guidance_scale=config.guidance_scale,
+            guidance_scale=config.get("guidance_scale", 7.5),
             **kwargs,
         ).images
         resadapter_images.append(images)
@@ -254,13 +259,13 @@ def main():
         # Save images
         texts = ["ResAdapter", "Baseline"]
         if enable_compare:
-            for j in range(config.num_images_per_prompt):
+            for j in range(config.get("num_images_per_prompt", 2)):
                 compare_image = torch.stack([resadapter_images[i][j], original_images[i][j]])
-                if config.draw_text:
+                if config.get("draw_text", None):
                     for k in range(len(texts)):
                         compare_image[k] = draw_text_on_images(compare_image[k], texts[k])
 
-                if config.split_images:
+                if config.get("split_images", None):
                     for q in range(len(texts)):
                         save_image(
                             compare_image[q], os.path.join(output_dir, f"{prompt[:100]}_{j}_{texts[q]}.jpg"), normalize=True, value_range=(0, 1),
                             nrow=2, padding=0,
@@ -271,7 +276,7 @@ def main():
                         )
         else:
             compare_image = resadapter_images[i]
-            for m in range(config.num_images_per_prompt):
+            for m in range(config.get("num_images_per_prompt", 2)):
                 save_image(
                     compare_image[m], os.path.join(output_dir, f"{prompt[:100]}_{m}.jpg"), normalize=True, value_range=(0, 1), nrow=2, padding=0,
                 )
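
Note: the main.py changes above consistently swap bare attribute access (`config.task`, `config.width`) for `config.get(key, default)`, so optional keys can be left out of the config file, and unknown `task`/`sub_task` values now fail fast with `raise NotImplementedError` instead of falling through. A minimal sketch of the lookup pattern, assuming an OmegaConf `DictConfig` (the diff does not show how `config` is constructed; any mapping with a `.get(key, default)` method behaves the same):

```python
# Sketch only: OmegaConf is an assumption, not shown in this diff.
from omegaconf import OmegaConf

config = OmegaConf.create({"task": "t2i"})  # "width"/"height" omitted on purpose

task = config.get("task", None)     # -> "t2i"
width = config.get("width", 512)    # -> 512: a missing optional key falls back
height = config.get("height", 512)  #    to an explicit default instead of erroring
```
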
diff --git a/resadapter/model_loader.py b/resadapter/model_loader.py
index ef75853..d910477 100644
--- a/resadapter/model_loader.py
+++ b/resadapter/model_loader.py
@@ -18,8 +18,9 @@
 
 # Load resadapter for scripts
 def load_resadapter(pipeline, config):
-    NORM_WEIGHTS_NAME = "resolution_normalization.safetensors"
-    LORA_WEIGHTS_NAME = "resolution_lora.safetensors"
+
+    NORM_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
+    LORA_WEIGHTS_NAME = "pytorch_lora_weights.safetensors"
 
     # Load resolution normalization
     try:
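
The model_loader.py hunk above only renames the expected weight files to the standard diffusers names (`diffusion_pytorch_model.safetensors` for the resolution normalization, `pytorch_lora_weights.safetensors` for the resolution LoRA); the body of `load_resadapter` is not part of this diff. A hedged sketch of a loader built around those filenames (the helper name, the `resadapter_dir` argument, and the `strict=False` merge are illustrative assumptions, not the repo's actual implementation):

```python
from safetensors.torch import load_file

NORM_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
LORA_WEIGHTS_NAME = "pytorch_lora_weights.safetensors"

def load_resadapter_sketch(pipeline, resadapter_dir):
    # Merge the resolution-normalization weights into the UNet.
    norm_state_dict = load_file(f"{resadapter_dir}/{NORM_WEIGHTS_NAME}")
    pipeline.unet.load_state_dict(norm_state_dict, strict=False)
    # Register the resolution LoRA under the adapter name that main.py
    # later references via pipeline.set_adapters(["res_adapter"], ...).
    pipeline.load_lora_weights(
        resadapter_dir, weight_name=LORA_WEIGHTS_NAME, adapter_name="res_adapter"
    )
    return pipeline
```
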
diff --git a/resadapter/pipeline_loader.py b/resadapter/pipeline_loader.py
index 2e6e637..aaef3b9 100644
--- a/resadapter/pipeline_loader.py
+++ b/resadapter/pipeline_loader.py
@@ -38,17 +38,18 @@
     UniPCMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     DDIMScheduler,
+    EulerDiscreteScheduler,
 )
 
 
 def load_text2image_pipeline(config):
-    if config.personalized_model.endswith(
+    if config.diffusion_model.endswith(
         ".safetensors"
-    ) or config.personalized_model.endswith(".ckpt"):
-        print(f"Load pipeline from civitai: {config.personalized_model}")
+    ) or config.diffusion_model.endswith(".ckpt"):
+        print(f"Load pipeline from civitai: {config.diffusion_model}")
         if config.model_type == "sd1.5":
             pipeline = StableDiffusionPipeline.from_single_file(
-                config.personalized_model,
+                config.diffusion_model,
                 torch_dtype=torch.float16,
                 variant="fp16",
                 load_safety_checker=False,
@@ -56,17 +57,17 @@ def load_text2image_pipeline(config):
             )
         else:
             pipeline = StableDiffusionXLPipeline.from_single_file(
-                config.personalized_model,
+                config.diffusion_model,
                 torch_dtype=torch.float16,
                 variant="fp16",
                 load_safety_checker=False,
                 requires_safety_checker=False,
             )
     else:
-        print(f"Load pipeline from huggingface: {config.personalized_model}")
+        print(f"Load pipeline from huggingface: {config.diffusion_model}")
         if config.model_type == "sd1.5":
             pipeline = StableDiffusionPipeline.from_pretrained(
-                config.personalized_model,
+                config.diffusion_model,
                 torch_dtype=torch.float16,
                 variant="fp16",
                 load_safety_checker=False,
@@ -74,7 +75,7 @@ def load_text2image_pipeline(config):
             )
         else:
             pipeline = StableDiffusionXLPipeline.from_pretrained(
-                config.personalized_model,
+                config.diffusion_model,
                 torch_dtype=torch.float16,
                 variant="fp16",
                 load_safety_checker=False,
@@ -87,17 +88,21 @@ def load_text2image_pipeline(config):
             algorithm_type="sde-dpmsolver++",
         )
 
+    # if config.timestep_spacing == "trailing":
+    #     print("Detect timestep_spacing == trailing")
+    #     pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
+
     return pipeline
 
 
 def load_text2image_lcm_lora_pipeline(config):
-    if config.personalized_model.endswith(
+    if config.diffusion_model.endswith(
         ".safetensors"
-    ) or config.personalized_model.endswith(".ckpt"):
-        print(f"Load pipeline from civitai: {config.personalized_model}")
+    ) or config.diffusion_model.endswith(".ckpt"):
+        print(f"Load pipeline from civitai: {config.diffusion_model}")
         if config.model_type == "sd1.5":
             pipeline = StableDiffusionPipeline.from_single_file(
-                config.personalized_model,
+                config.diffusion_model,
                 torch_dtype=torch.float16,
                 variant="fp16",
                 load_safety_checker=False,
@@ -105,17 +110,17 @@ def load_text2image_lcm_lora_pipeline(config):
             )
         else:
             pipeline = StableDiffusionXLPipeline.from_single_file(
-                config.personalized_model,
+                config.diffusion_model,
                 torch_dtype=torch.float16,
                 variant="fp16",
                 load_safety_checker=False,
                 requires_safety_checker=False,
             )
     else:
-        print(f"Load pipeline from huggingface: {config.personalized_model}")
+        print(f"Load pipeline from huggingface: {config.diffusion_model}")
         if config.model_type == "sd1.5":
             pipeline = StableDiffusionPipeline.from_pretrained(
-                config.personalized_model,
+                config.diffusion_model,
                 torch_dtype=torch.float16,
                 variant="fp16",
                 load_safety_checker=False,
@@ -123,7 +128,7 @@ def load_text2image_lcm_lora_pipeline(config):
             )
         else:
             pipeline = StableDiffusionXLPipeline.from_pretrained(
-                config.personalized_model,
+                config.diffusion_model,
                 torch_dtype=torch.float16,
                 variant="fp16",
                 load_safety_checker=False,
@@ -144,43 +149,43 @@ def load_controlnet_pipeline(config):
 
     if config.model_type == "sd1.5":
         if config.sub_task == "image_to_image":
-            if config.personalized_model.endswith(".safetensors") or config.personalized_model.endswith(".ckpt"):
+            if config.diffusion_model.endswith(".safetensors") or config.diffusion_model.endswith(".ckpt"):
                 pipeline = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
-                    config.personalized_model,
+                    config.diffusion_model,
                     controlnet=controlnet,
                     torch_dtype=torch.float16,
                 )
             else:
                 pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
-                    config.personalized_model,
+                    config.diffusion_model,
                     controlnet=controlnet,
                     torch_dtype=torch.float16,
                 )
         if config.sub_task == "text_to_image":
             pipeline = StableDiffusionControlNetPipeline.from_pretrained(
-                config.personalized_model,
+                config.diffusion_model,
                 controlnet=controlnet,
                 torch_dtype=torch.float16,
                 use_safetensors=True,
             )
     elif config.model_type == "sdxl":
         if config.sub_task == "image_to_image":
-            if config.personalized_model.endswith(".safetensors") or config.personalized_model.endswith(".ckpt"):
+            if config.diffusion_model.endswith(".safetensors") or config.diffusion_model.endswith(".ckpt"):
                 pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
-                    config.personalized_model,
+                    config.diffusion_model,
                     controlnet=controlnet,
                     torch_dtype=torch.float16,
                 )
             else:
                 pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
-                    config.personalized_model,
+                    config.diffusion_model,
                     controlnet=controlnet,
                     torch_dtype=torch.float16,
                 )
         if config.sub_task == "text_to_image":
             pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
-                config.personalized_model,
+                config.diffusion_model,
                 controlnet=controlnet,
                 torch_dtype=torch.float16,
                 use_safetensors=True,
@@ -188,7 +193,7 @@ def load_controlnet_pipeline(config):
 
     pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
     print(f"Load controlnet form {config.controlnet_model}")
-    print(f"Load model form {config.personalized_model}")
+    print(f"Load model form {config.diffusion_model}")
 
     return pipeline
 
@@ -196,15 +201,15 @@ def load_controlnet_pipeline(config):
 def load_ip_adapter_pipeline(config):
     if config.sub_task == "image_variation":
         pipeline = AutoPipelineForText2Image.from_pretrained(
-            config.personalized_model, torch_dtype=torch.float16, safety_checker=None,
+            config.diffusion_model, torch_dtype=torch.float16, safety_checker=None,
         )
     if config.sub_task == "image_to_image":
         pipeline = AutoPipelineForImage2Image.from_pretrained(
-            config.personalized_model, torch_dtype=torch.float16, safety_checker=None,
+            config.diffusion_model, torch_dtype=torch.float16, safety_checker=None,
         )
     if config.sub_task == "inpaint":
         pipeline = AutoPipelineForInpainting.from_pretrained(
-            config.personalized_model, torch_dtype=torch.float16, safety_checker=None,
+            config.diffusion_model, torch_dtype=torch.float16, safety_checker=None,
         )
 
     if config.model_type == "sd1.5":
         if config.ip_adapter_weight_name == "general":
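
The pipeline_loader.py changes rename `config.personalized_model` to `config.diffusion_model` throughout and import `EulerDiscreteScheduler` for the commented-out trailing-timestep-spacing block in `load_text2image_pipeline`. If that block is ever enabled, a version guarded with the same defensive `.get()` lookup used in main.py might look like this (a sketch, not part of the diff; `timestep_spacing="trailing"` is a real `EulerDiscreteScheduler` config option in diffusers):

```python
from diffusers import EulerDiscreteScheduler

def maybe_use_trailing_spacing(pipeline, config):
    # Swap in EulerDiscreteScheduler with trailing timestep spacing when the
    # config asks for it; otherwise leave the pipeline's scheduler untouched.
    if config.get("timestep_spacing", None) == "trailing":
        print("Detect timestep_spacing == trailing")
        pipeline.scheduler = EulerDiscreteScheduler.from_config(
            pipeline.scheduler.config, timestep_spacing="trailing"
        )
    return pipeline
```
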