Skip to content

Commit

Permalink
Merge branch 'dev' into gradual_latent_hires_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Nov 26, 2023
2 parents 610566f + 764e333 commit 2897a89
Show file tree
Hide file tree
Showing 15 changed files with 457 additions and 141 deletions.
2 changes: 1 addition & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
Expand Down
12 changes: 10 additions & 2 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.original_unet import FlashAttentionFunction

from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
Expand Down Expand Up @@ -378,7 +378,7 @@ def __init__(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
unet: InferUNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
clip_skip: int,
clip_model: CLIPModel,
Expand Down Expand Up @@ -2365,6 +2365,7 @@ def main(args):
)
original_unet.load_state_dict(unet.state_dict())
unet = original_unet
unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)

# VAEを読み込む
if args.vae is not None:
Expand Down Expand Up @@ -2521,13 +2522,20 @@ def __getattr__(self, item):
vae = sli_vae
del sli_vae
vae.to(dtype).to(device)
vae.eval()

text_encoder.to(dtype).to(device)
unet.to(dtype).to(device)

text_encoder.eval()
unet.eval()

if clip_model is not None:
clip_model.to(dtype).to(device)
clip_model.eval()
if vgg16_model is not None:
vgg16_model.to(dtype).to(device)
vgg16_model.eval()

# networkを組み込む
if args.network_module:
Expand Down
9 changes: 6 additions & 3 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,13 @@ def enforce_zero_terminal_snr(betas):
noise_scheduler.alphas_cumprod = alphas_cumprod


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    """Scale a loss by the Min-SNR-gamma weight of each sampled timestep.

    Args:
        loss: Loss tensor; it is multiplied elementwise by the per-timestep
            weights (presumably one loss value per sample -- confirm against
            the caller).
        timesteps: Iterable of indices into ``noise_scheduler.all_snr``.
        noise_scheduler: Object exposing ``all_snr``, indexable by timestep
            and yielding tensors that can be stacked.
        gamma: Upper clamp applied to the SNR.
        v_prediction: When true, the weight is ``min(snr, gamma) / (snr + 1)``;
            otherwise it is ``min(snr, gamma) / snr``.

    Returns:
        ``loss`` multiplied by the weights (cast to float32 and moved to
        ``loss.device``).
    """
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    if v_prediction:
        min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
        snr_weight = torch.div(min_snr_gamma, snr + 1)
    else:
        # min(snr, gamma) / snr is mathematically min(gamma / snr, 1).
        # Computing it in the second form gives weight 1 when snr == 0
        # (gamma / 0 = inf, clamped to 1) instead of the NaN produced by 0 / 0.
        gamma_over_snr = torch.div(torch.full_like(snr, gamma), snr)
        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr))
    snr_weight = snr_weight.float().to(loss.device)
    return loss * snr_weight

Expand Down
Loading

0 comments on commit 2897a89

Please sign in to comment.