Skip to content

Commit

Permalink
add target_x flag (not sure this impl is correct)
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Dec 3, 2023
1 parent 2952bca commit 7a4e507
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 36 deletions.
38 changes: 26 additions & 12 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2569,20 +2569,23 @@ def __getattr__(self, item):
# Gradual Latent
if args.gradual_latent_timesteps is not None:
if args.gradual_latent_unsharp_params:
ksize, sigma, strength = [float(v) for v in args.gradual_latent_unsharp_params.split(",")]
ksize = int(ksize)
us_params = args.gradual_latent_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]]
us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3]))
us_ksize = int(us_ksize)
else:
ksize, sigma, strength = None, None, None
us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None

gradual_latent = GradualLatent(
args.gradual_latent_ratio,
args.gradual_latent_timesteps,
args.gradual_latent_every_n_steps,
args.gradual_latent_ratio_step,
args.gradual_latent_s_noise,
ksize,
sigma,
strength,
us_ksize,
us_sigma,
us_strength,
us_target_x,
)
pipe.set_gradual_latent(gradual_latent)

Expand Down Expand Up @@ -3348,12 +3351,23 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
if gl_timesteps < 0:
gl_timesteps = args.gradual_latent_timesteps or 650
if gl_unsharp_params is not None:
ksize, sigma, strength = [float(v) for v in gl_unsharp_params.split(",")]
ksize = int(ksize)
unsharp_params = gl_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
print(unsharp_params)
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
us_ksize = int(us_ksize)
else:
ksize, sigma, strength = None, None, None
us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None
gradual_latent = GradualLatent(
gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step, gl_s_noise, ksize, sigma, strength
gl_ratio,
gl_timesteps,
gl_every_n_steps,
gl_ratio_step,
gl_s_noise,
us_ksize,
us_sigma,
us_strength,
us_target_x,
)
pipe.set_gradual_latent(gradual_latent)

Expand Down Expand Up @@ -3765,8 +3779,8 @@ def setup_parser() -> argparse.ArgumentParser:
"--gradual_latent_unsharp_params",
type=str,
default=None,
help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength. `3,0.5,0.5` is recommended /"
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength. `3,0.5,0.5` が推奨",
help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /"
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
)

return parser
Expand Down
45 changes: 28 additions & 17 deletions library/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
gaussian_blur_ksize=None,
gaussian_blur_sigma=0.5,
gaussian_blur_strength=0.5,
unsharp_target_x=True,
):
self.ratio = ratio
self.start_timesteps = start_timesteps
Expand All @@ -37,12 +38,14 @@ def __init__(
self.gaussian_blur_ksize = gaussian_blur_ksize
self.gaussian_blur_sigma = gaussian_blur_sigma
self.gaussian_blur_strength = gaussian_blur_strength
self.unsharp_target_x = unsharp_target_x

def __str__(self) -> str:
return (
f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
+ f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength})"
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
+ f"unsharp_target_x={self.unsharp_target_x})"
)

def apply_unshark_mask(self, x: torch.Tensor):
Expand All @@ -54,6 +57,19 @@ def apply_unshark_mask(self, x: torch.Tensor):
sharpened = x + mask
return sharpened

def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.float()

x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)

# apply unsharp mask / アンシャープマスクを適用する
if unsharp and self.gaussian_blur_ksize:
x = self.apply_unshark_mask(x)

return x


class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -140,37 +156,32 @@ def step(

dt = sigma_down - sigma

prev_sample = sample + derivative * dt

device = model_output.device
if self.resized_size is None:
prev_sample = sample + derivative * dt

noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
model_output.shape, dtype=model_output.dtype, device=device, generator=generator
)
s_noise = 1.0
else:
print(
"resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape
)
org_dtype = prev_sample.dtype
if org_dtype == torch.bfloat16:
prev_sample = prev_sample.float()

prev_sample = torch.nn.functional.interpolate(
prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False
).to(dtype=org_dtype)
print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
s_noise = self.gradual_latent.s_noise

# apply unsharp mask / アンシャープマスクを適用する
if self.gradual_latent.gaussian_blur_ksize:
prev_sample = self.gradual_latent.apply_unshark_mask(prev_sample)
if self.gradual_latent.unsharp_target_x:
prev_sample = sample + derivative * dt
prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
else:
sample = self.gradual_latent.interpolate(sample, self.resized_size)
derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
prev_sample = sample + derivative * dt

noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
dtype=model_output.dtype,
device=device,
generator=generator,
)
s_noise = self.gradual_latent.s_noise

prev_sample = prev_sample + noise * sigma_up * s_noise

Expand Down
27 changes: 20 additions & 7 deletions sdxl_gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -1829,10 +1829,12 @@ def __getattr__(self, item):
# Gradual Latent
if args.gradual_latent_timesteps is not None:
if args.gradual_latent_unsharp_params:
us_ksize, us_sigma, us_strength = [float(v) for v in args.gradual_latent_unsharp_params.split(",")]
us_params = args.gradual_latent_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]]
us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3]))
us_ksize = int(us_ksize)
else:
us_ksize, us_sigma, us_strength = None, None, None
us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None

gradual_latent = GradualLatent(
args.gradual_latent_ratio,
Expand All @@ -1843,6 +1845,7 @@ def __getattr__(self, item):
us_ksize,
us_sigma,
us_strength,
us_target_x,
)
pipe.set_gradual_latent(gradual_latent)

Expand Down Expand Up @@ -2650,12 +2653,22 @@ def scale_and_round(x):
if gl_timesteps < 0:
gl_timesteps = args.gradual_latent_timesteps or 650
if gl_unsharp_params is not None:
us_ksize, us_sigma, us_strength = [float(v) for v in gl_unsharp_params.split(",")]
unsharp_params = gl_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
us_ksize = int(us_ksize)
else:
us_ksize, us_sigma, us_strength = None, None, None
us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None
gradual_latent = GradualLatent(
gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step, gl_s_noise, us_ksize, us_sigma, us_strength
gl_ratio,
gl_timesteps,
gl_every_n_steps,
gl_ratio_step,
gl_s_noise,
us_ksize,
us_sigma,
us_strength,
us_target_x,
)
pipe.set_gradual_latent(gradual_latent)

Expand Down Expand Up @@ -3056,8 +3069,8 @@ def setup_parser() -> argparse.ArgumentParser:
"--gradual_latent_unsharp_params",
type=str,
default=None,
help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength. `3,0.5,0.5` is recommended /"
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength. `3,0.5,0.5` が推奨",
help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /"
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
)

# # parser.add_argument(
Expand Down

0 comments on commit 7a4e507

Please sign in to comment.