Commit: Fix latent config information

jiaxiangc committed Mar 31, 2024
1 parent 6908f1a · commit c54228d
Showing 3 changed files with 108 additions and 97 deletions.
139 changes: 72 additions & 67 deletions main.py
@@ -55,35 +55,38 @@ def main():
         print(f"Create {output_dir}")
 
     # #### 2.Load pipeline and scheduler ####
-    if config.task == "t2i":
+    task = config.get("task", None)
+    if task == "t2i":
         pipeline = load_text2image_pipeline(config)
-    if config.task == "t2i_accelerate":
+    elif task == "t2i_accelerate":
         pipeline = load_text2image_lcm_lora_pipeline(config)
-    if config.task == "controlnet":
+    elif task == "controlnet":
         pipeline = load_controlnet_pipeline(config)
-    if config.task == "ip_adapter":
+    elif task == "ip_adapter":
         pipeline = load_ip_adapter_pipeline(config)
+    else:
+        raise NotImplementedError
 
-    device = torch.device(f"cuda:{config.device}")
+    device = torch.device(f"cuda:{config.get('device', 0)}")
     pipeline = pipeline.to(device)
 
-    if config.enable_xformers:
+    if config.get("enable_xformers", None):
         print("Enable xformers successfully.")
         pipeline.enable_xformers_memory_efficient_attention()
 
     # #### 3.Get prompts and other condition ####
-    p_prompts = config.prompts
-    n_prompt = config.n_prompt
+    p_prompts = config.get("prompts", [])
+    n_prompt = config.get("n_prompt", "")
 
-    if config.task == "controlnet":
+    if task == "controlnet":
         condition_images = []
         source_images = []
-        for image_path in config.source_images:
+        for image_path in config.get("source_images", []):
             source_image = Image.open(image_path)
-            if config.scale_ratio:
+            if config.get("scale_ratio", None):
                 width, height = int(source_image.size[0]*config.scale_ratio), int(source_image.size[1]*config.scale_ratio)
             else:
-                width, height = config.width, config.height
+                width, height = config.get("width", 512), config.get("height", 512)
             source_image = source_image.resize((width, height))
             source_images.append(source_image)
             np_condition = np.array(source_image)
@@ -94,42 +97,42 @@ def main():
             condition_image.save(os.path.join(output_dir, f"condition_{Path(image_path).stem}.jpg"))
             condition_images.append(condition_image)
 
-    if config.task == "t2i_adapter":
+    if task == "t2i_adapter":
         condition_images = []
-        for condition_path in config.condition_images:
+        for condition_path in config.get("condition_images", []):
             condition_image = Image.open(condition_path)
-            if config.scale_ratio:
+            if config.get("scale_ratio", None):
                 width, height = int(condition_image.size[0]*config.scale_ratio), int(condition_image.size[1]*config.scale_ratio)
             else:
-                width, height = config.width, config.height
+                width, height = config.get("width", 512), config.get("height", 512)
             condition_image = condition_image.resize((width, height)).convert("L")
             condition_images.append(condition_image)
 
-    if config.task == "ip_adapter":
+    if task == "ip_adapter":
+        sub_task = config.get("sub_task", None)
         # Image Variation
-        if config.sub_task == "image_variation":
+        if sub_task == "image_variation":
            ip_adapter_images = []
-            for ip_image_path in config.ip_adapter_images:
+            for ip_image_path in config.get("ip_adapter_images", []):
                ip_adpater_image = Image.open(ip_image_path)
-                if config.scale_ratio:
+                if config.get("scale_ratio", None):
                    width, height = int(ip_adpater_image.size[0]*config.scale_ratio), int(ip_adpater_image.size[1]*config.scale_ratio)
                 else:
-                    width, height = config.width, config.height
+                    width, height = config.get("width", 512), config.get("height", 512)
 
                 ip_adpater_image = ip_adpater_image.resize((width, height))
 
                 ip_adapter_images.append(ip_adpater_image)
 
         # Image to Image
-        if config.sub_task == "image_to_image":
+        elif sub_task == "image_to_image":
             ip_adapter_images = []
             source_images = []
-            for ip_image_path, image_path in zip(config.ip_adapter_images, config.source_images):
+            for ip_image_path, image_path in zip(config.get("ip_adapter_images", []), config.get("source_images", [])):
                 source_image = Image.open(image_path)
-                if config.scale_ratio:
+                if config.get("scale_ratio", None):
                     width, height = int(source_image.size[0]*config.scale_ratio), int(source_image.size[1]*config.scale_ratio)
                 else:
-                    width, height = config.width, config.height
+                    width, height = config.get("width", 512), config.get("height", 512)
                 source_image = source_image.resize((width, height))
                 source_images.append(source_image)
 
@@ -138,17 +141,17 @@ def main():
                 ip_adapter_images.append(ip_adapter_image)
 
         # Image Inpainting
-        if config.sub_task == "inpaint":
+        elif sub_task == "inpaint":
             source_images = []
             mask_images = []
             ip_adapter_images = []
 
-            for ip_image_path, image_path, mask_path in zip(config.ip_adapter_images, config.source_images, config.mask_images):
+            for ip_image_path, image_path, mask_path in zip(config.get("ip_adapter_images", []), config.get("source_images", []), config.get("mask_images", [])):
                 source_image = Image.open(image_path)
-                if config.scale_ratio:
+                if config.get("scale_ratio", None):
                     width, height = int(source_image.size[0]*config.scale_ratio), int(source_image.size[1]*config.scale_ratio)
                 else:
-                    width, height = config.width, config.height
+                    width, height = config.get("width", 512), config.get("height", 512)
                 source_image = source_image.resize((width, height))
                 source_images.append(source_image)
 
@@ -160,15 +163,17 @@ def main():
                 ip_adapter_image = ip_adapter_image.resize((width, height))
                 ip_adapter_images.append(ip_adapter_image)
 
+        else:
+            raise NotImplementedError
 
     # #### 4.Inference pipeline ####
 
-    if config.seed:
+    if config.get("seed", None):
         generator = torch.Generator(device=device).manual_seed(config.seed)
     else:
         generator = None
 
-    if config.res_adapter_model == "":
+    if config.get("res_adapter_model", "") == "":
         enable_compare = False
     else:
         enable_compare = config.enable_compare
@@ -177,90 +182,90 @@ def main():
     # Inference baseline
     original_images = []
     for i, prompt in tqdm(enumerate(p_prompts), total=len(p_prompts), desc="[Baselines]: "):
-        if config.task == "t2i" or config.task == "t2i_accelerate":
+        if task == "t2i" or task == "t2i_accelerate":
            kwargs = {}
-        if config.task == "controlnet":
-            if config.sub_task == "text_to_image":
+        if task == "controlnet":
+            if sub_task == "text_to_image":
                 kwargs = {"image": condition_images[i]}
-            if config.sub_task == "image_to_image":
+            if sub_task == "image_to_image":
                 kwargs = {"control_image": condition_images[i], "image": source_images[i]}
-        if config.task == "t2i_adapter":
+        if task == "t2i_adapter":
             kwargs = {"image": condition_images[i]}
-        if config.task == "ip_adapter":
-            if config.sub_task == "image_variation":
+        if task == "ip_adapter":
+            if sub_task == "image_variation":
                 kwargs = {"ip_adapter_image": ip_adapter_images[i]}
-            if config.sub_task == "image_to_image":
+            if sub_task == "image_to_image":
                 kwargs = {"image": source_images[i], "ip_adapter_image": ip_adapter_images[i], "strength": 0.6}
-            if config.sub_task == "inpaint":
+            if sub_task == "inpaint":
                 kwargs = {"image": source_images[i], "mask_image": mask_images[i], "ip_adapter_image": ip_adapter_images[i], "strength": 0.5}
 
         images = pipeline(
             prompt=prompt,
-            height=config.height,
-            width=config.width,
+            height=config.get("height", 512),
+            width=config.get("width", 512),
             negative_prompt=n_prompt,
-            num_inference_steps=config.num_inference_steps,
-            num_images_per_prompt=config.num_images_per_prompt,
+            num_inference_steps=config.get("num_inference_steps", 25),
+            num_images_per_prompt=config.get("num_images_per_prompt", 2),
             generator=generator,
             output_type="pt",
-            guidance_scale=config.guidance_scale,
+            guidance_scale=config.get("guidance_scale", 7.5),
             **kwargs,
         ).images
         original_images.append(images)
 
     # Load res-adapter
-    if config.res_adapter_model != "":
+    if config.get("res_adapter_model", "") != "":
         pipeline = load_resadapter(pipeline, config)
         print(f"Load res-adapter from {config.res_adapter_model}")
-        pipeline.set_adapters(["res_adapter"], adapter_weights=[config.res_adapter_alpha])
+        pipeline.set_adapters(["res_adapter"], adapter_weights=[config.get("res_adapter_alpha", 1.0)])
 
         if config.task == "t2i_accelerate":
-            pipeline.set_adapters(["res_adapter", "lcm_lora"], adapter_weights=[config.res_adapter_alpha, config.lcm_lora_alpha])
+            pipeline.set_adapters(["res_adapter", "lcm_lora"], adapter_weights=[config.get("res_adapter_alpha", 1.0), config.get("lcm_lora_alpha", 1.0)])
 
     # Inference with res-adapter
     resadapter_images = []
     for i, prompt in tqdm(enumerate(p_prompts), total=len(p_prompts), desc="[ResAdapter]: "):
-        if config.task == "t2i" or config.task == "t2i_accelerate":
+        if task == "t2i" or task == "t2i_accelerate":
            kwargs = {}
-        if config.task == "controlnet":
-            if config.sub_task == "text_to_image":
+        if task == "controlnet":
+            if sub_task == "text_to_image":
                 kwargs = {"image": condition_images[i]}
-            if config.sub_task == "image_to_image":
+            if sub_task == "image_to_image":
                 kwargs = {"control_image": condition_images[i], "image": source_images[i]}
-        if config.task == "t2i_adapter":
+        if task == "t2i_adapter":
             kwargs = {"image": condition_images[i]}
-        if config.task == "ip_adapter":
-            if config.sub_task == "image_variation":
+        if task == "ip_adapter":
+            if sub_task == "image_variation":
                 kwargs = {"ip_adapter_image": ip_adapter_images[i]}
-            if config.sub_task == "image_to_image":
+            if sub_task == "image_to_image":
                 kwargs = {"image": source_images[i], "ip_adapter_image": ip_adapter_images[i], "strength": 0.6}
-            if config.sub_task == "inpaint":
+            if sub_task == "inpaint":
                 kwargs = {"image": source_images[i], "mask_image": mask_images[i], "ip_adapter_image": ip_adapter_images[i], "strength": 0.5}
 
         images = pipeline(
             prompt=prompt,
-            height=config.height,
-            width=config.width,
+            height=config.get("height", 512),
+            width=config.get("width", 512),
            negative_prompt=n_prompt,
-            num_inference_steps=config.num_inference_steps,
-            num_images_per_prompt=config.num_images_per_prompt,
+            num_inference_steps=config.get("num_inference_steps", 25),
+            num_images_per_prompt=config.get("num_images_per_prompt", 2),
             generator=generator,
             output_type="pt",
-            guidance_scale=config.guidance_scale,
+            guidance_scale=config.get("guidance_scale", 7.5),
             **kwargs,
         ).images
         resadapter_images.append(images)
 
         # Save images
         texts = ["ResAdapter", "Baseline"]
         if enable_compare:
-            for j in range(config.num_images_per_prompt):
+            for j in range(config.get("num_images_per_prompt", 2)):
                 compare_image = torch.stack([resadapter_images[i][j], original_images[i][j]])
-                if config.draw_text:
+                if config.get("draw_text", None):
                     for k in range(len(texts)):
                         compare_image[k] = draw_text_on_images(compare_image[k], texts[k])
 
-                if config.split_images:
+                if config.get("split_images", None):
                     for q in range(len(texts)):
                         save_image(
                             compare_image[q], os.path.join(output_dir, f"{prompt[:100]}_{j}_{texts[q]}.jpg"), normalize=True, value_range=(0, 1), nrow=2, padding=0,
@@ -271,7 +276,7 @@ def main():
                        )
         else:
             compare_image = resadapter_images[i]
-            for m in range(config.num_images_per_prompt):
+            for m in range(config.get("num_images_per_prompt", 2)):
                 save_image(
                     compare_image[m], os.path.join(output_dir, f"{prompt[:100]}_{m}.jpg"), normalize=True, value_range=(0, 1), nrow=2, padding=0,
                 )
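Every change in main.py follows one pattern: direct attribute access (config.task, config.height) becomes config.get(key, default), so a config file that omits an optional key falls back to a default instead of raising. A minimal sketch of the difference, assuming an OmegaConf-style DictConfig (the project's config loader is not shown in this diff; any mapping with a .get method behaves the same way):

# Sketch of the pattern this commit adopts, assuming an OmegaConf config.
from omegaconf import OmegaConf

config = OmegaConf.create({"task": "t2i", "prompts": ["a photo of a cat"]})

# Old style: bare attribute access fails when the YAML omits the key.
try:
    height = config.height
except Exception as err:
    print(f"attribute access failed: {type(err).__name__}")

# New style: a missing key falls back to a default instead of crashing.
height = config.get("height", 512)  # -> 512
task = config.get("task", None)     # -> "t2i"
print(height, task)

The defaults chosen in the diff (512 for width/height, 25 inference steps, guidance scale 7.5, 2 images per prompt) thereby become the fallback behavior of main.py when a config omits those keys.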
5 changes: 3 additions & 2 deletions resadapter/model_loader.py
@@ -18,8 +18,9 @@
 
 # Load resadapter for scripts
 def load_resadapter(pipeline, config):
-    NORM_WEIGHTS_NAME = "resolution_normalization.safetensors"
-    LORA_WEIGHTS_NAME = "resolution_lora.safetensors"
+
+    NORM_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
+    LORA_WEIGHTS_NAME = "pytorch_lora_weights.safetensors"
 
     # Load resolution normalization
     try:
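The renamed constants point load_resadapter at the diffusers-standard filenames (diffusion_pytorch_model.safetensors for the normalization weights, pytorch_lora_weights.safetensors for the LoRA). A hedged sketch of fetching and inspecting those two files; the repo id below is a placeholder, since the checkpoint's Hub location is not part of this commit:

# Hypothetical sketch: download the renamed weight files and load them as
# plain state dicts. REPO_ID is a placeholder, not taken from this commit.
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

REPO_ID = "your-org/res-adapter"  # substitute the real checkpoint repo

norm_path = hf_hub_download(repo_id=REPO_ID, filename="diffusion_pytorch_model.safetensors")
lora_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_lora_weights.safetensors")

norm_state_dict = load_file(norm_path)  # resolution-normalization weights
lora_state_dict = load_file(lora_path)  # resolution LoRA weights
print(len(norm_state_dict), "normalization tensors,", len(lora_state_dict), "LoRA tensors")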