diff --git a/library/train_util.py b/library/train_util.py
index 27910dc90..3df01e3e7 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -19,6 +19,7 @@
     Sequence,
     Tuple,
     Union,
+    Callable,
 )
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import glob
@@ -3476,6 +3477,25 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=0.1,
         help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
     )
+    parser.add_argument(
+        "--timestep_sampling",
+        choices=["uniform", "sigmoid", "shift", "flux_shift"],
+        default="uniform",
+        help="Method to sample timesteps: uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
+        " / タイムステップをサンプリングする方法:random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
+    )
+    parser.add_argument(
+        "--sigmoid_scale",
+        type=float,
+        default=1.0,
+        help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
+    )
+    parser.add_argument(
+        "--discrete_flow_shift",
+        type=float,
+        default=1.0,
+        help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
+    )
 
     parser.add_argument(
         "--lowram",
@@ -5198,9 +5218,31 @@ def save_sd_model_on_train_end_common(
     if args.huggingface_repo_id is not None:
         huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
 
+def time_shift(mu: float, sigma: float, t: torch.Tensor):
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
 
-def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
-    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
+def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
+    m = (y2 - y1) / (x2 - x1)
+    b = y1 - m * x1
+    return lambda x: m * x + b
+
+def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, latents, device):
+    # Sample a random timestep for each image
+    b_size, _, h, w = latents.shape
+
+    if args.timestep_sampling != "uniform":
+        shift = args.discrete_flow_shift
+        logits_norm = torch.randn(b_size, device="cpu")
+        logits_norm = logits_norm * args.sigmoid_scale
+        timesteps = logits_norm.sigmoid()
+        if args.timestep_sampling == "flux_shift":
+            mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
+            timesteps = time_shift(mu, 1.0, timesteps)
+        else:
+            timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
+        timesteps = min_timestep + (timesteps * (max_timestep - min_timestep))
+    else:
+        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
 
     if args.loss_type == "huber" or args.loss_type == "smooth_l1":
         if args.huber_schedule == "exponential":
@@ -5223,7 +5265,6 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler,
     timesteps = timesteps.long().to(device)
     return timesteps, huber_c
 
-
 def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     # Sample noise that we'll add to the latents
     noise = torch.randn_like(latents, device=latents.device)
@@ -5238,12 +5279,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
             noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
         )
 
-    # Sample a random timestep for each image
-    b_size = latents.shape[0]
     min_timestep = 0 if args.min_timestep is None else args.min_timestep
     max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
 
-    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
+    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, latents, latents.device)
 
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
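
The following is not part of the patch: a minimal standalone sketch that mirrors the time_shift / get_lin_function helpers added above, so the effect of the "shift" and "flux_shift" options on sampled timesteps can be printed in isolation. The latent shape, the shift value of 3.0, and the 1000-step range are hypothetical example inputs, not values taken from the diff.

# sketch.py -- illustrative only; helper names mirror the patch above
import math
import torch


def time_shift(mu: float, sigma: float, t: torch.Tensor):
    # Push sigmoid-sampled t toward later timesteps as mu grows (FLUX.1-style shift).
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
    # Linear map from latent token count to mu.
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


b_size, h, w = 4, 64, 64              # hypothetical latent batch: 4 images, 64x64 latents
min_timestep, max_timestep = 0, 1000  # e.g. noise_scheduler.config.num_train_timesteps

# sigmoid of a scaled normal sample, shared by the "sigmoid"/"shift"/"flux_shift" branches
sigmoid_scale = 1.0
timesteps = (torch.randn(b_size) * sigmoid_scale).sigmoid()

# "shift": rational shift controlled by --discrete_flow_shift (3.0 here is arbitrary)
shift = 3.0
shifted = (timesteps * shift) / (1 + (shift - 1) * timesteps)

# "flux_shift": mu depends on the number of 2x2 latent patches
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
flux_shifted = time_shift(mu, 1.0, timesteps)

# map [0, 1) onto the scheduler's integer timestep range, as the patch does before .long()
print((min_timestep + shifted * (max_timestep - min_timestep)).long())
print((min_timestep + flux_shifted * (max_timestep - min_timestep)).long())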