forked from faraday/runway-stable-diffusion-inpainting
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add vqgan loss with codebook statistic eval
- Loading branch information
Showing
1 changed file
with
167 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
from einops import repeat | ||
|
||
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init | ||
from taming.modules.losses.lpips import LPIPS | ||
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss | ||
|
||
|
||
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    """Hinge discriminator loss where each batch element has its own weight.

    The hinge terms ``relu(1 - logits_real)`` and ``relu(1 + logits_fake)``
    are averaged over every non-batch dimension, giving one value per
    example. Those values are then combined as a weighted average using
    ``weights``, and the real and fake halves are averaged.

    Args:
        logits_real: Discriminator outputs for real inputs; the first
            dimension is the batch, and at least one further dimension is
            required.
        logits_fake: Discriminator outputs for generated inputs, with the
            same batch size as ``logits_real``.
        weights: One weight per batch element. Any shape holding exactly
            ``batch`` values is accepted (e.g. ``(B,)`` or ``(B, 1)``).

    Returns:
        A 0-dim tensor holding the weighted hinge loss. If the weights sum
        to zero the division yields NaN/inf, exactly as the plain formula
        dictates; no clamping is applied.

    Raises:
        ValueError: If the batch sizes disagree or ``weights`` does not
            contain exactly one value per batch element.
    """
    batch = logits_real.shape[0]
    if logits_fake.shape[0] != batch or weights.shape[0] != batch:
        raise ValueError(
            "logits_real, logits_fake and weights must share the same batch size"
        )
    if weights.numel() != batch:
        raise ValueError("weights must hold exactly one value per batch element")

    # Flatten to (B,) so a (B, 1)-shaped weight cannot broadcast against the
    # per-example losses into a (B, B) matrix.
    weights = weights.reshape(batch)

    # Reduce over all non-batch dimensions instead of the fixed [1, 2, 3], so
    # logits of any rank >= 2 work.
    loss_real = F.relu(1. - logits_real).flatten(start_dim=1).mean(dim=1)
    loss_fake = F.relu(1. + logits_fake).flatten(start_dim=1).mean(dim=1)

    weight_total = weights.sum()
    loss_real = (weights * loss_real).sum() / weight_total
    loss_fake = (weights * loss_fake).sum() / weight_total
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss
|
||
def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step`` is below ``threshold``, else ``weight``.

    Args:
        weight: The weight to use once the threshold step has been reached.
        global_step: Current step counter.
        threshold: First step at which ``weight`` is returned unchanged.
        value: Substitute returned for every step strictly below ``threshold``.
    """
    return value if global_step < threshold else weight
|
||
|
||
def measure_perplexity(predicted_indices, n_embed):
    """Evaluate how evenly a set of codebook indices uses the codebook.

    Args:
        predicted_indices: Integer tensor of codebook indices, any shape.
            Values must lie in ``[0, n_embed)``.
        n_embed: Number of entries in the codebook.

    Returns:
        A tuple ``(perplexity, cluster_use)`` of 0-dim tensors.
        ``perplexity`` is the exponential of the entropy of the empirical
        code distribution; it equals ``n_embed`` when all codes are used
        exactly equally. ``cluster_use`` is the number of codes that occur
        at least once.
    """
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # F.one_hot only accepts int64 indices; cast so that index tensors stored
    # in narrower integer dtypes (e.g. int32) do not raise.
    flat_indices = predicted_indices.reshape(-1).long()
    encodings = F.one_hot(flat_indices, n_embed).float()
    # Fraction of positions assigned to each code.
    avg_probs = encodings.mean(0)
    # The 1e-10 keeps log() finite for codes that never occur.
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use
|
||
def l1(x, y):
    """Element-wise absolute difference ``|x - y|`` (no reduction applied)."""
    difference = x - y
    return torch.abs(difference)
|
||
|
||
def l2(x, y):
    """Element-wise squared difference ``(x - y) ** 2`` (no reduction applied)."""
    difference = x - y
    return torch.pow(difference, 2)
|
||
|
||
class VQLPIPSWithDiscriminator(nn.Module):
    """Training loss for a VQ autoencoder with a patch discriminator.

    The generator objective combines a pixel reconstruction loss (L1 or L2),
    an optional LPIPS perceptual loss, a codebook loss supplied by the
    caller, and an adversarial term from an ``NLayerDiscriminator``. The
    discriminator objective is a hinge or vanilla GAN loss. ``forward``
    selects between the two objectives through ``optimizer_idx``.
    """

    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        """Build the perceptual network, the discriminator and store weights.

        Args:
            disc_start: Global step before which the adversarial factor is
                replaced by zero (see ``adopt_weight``).
            codebook_weight: Multiplier on the codebook loss.
            pixelloss_weight: Stored as ``self.pixel_weight``; it is not read
                anywhere else in this class.
            disc_num_layers: ``n_layers`` passed to the discriminator.
            disc_in_channels: ``input_nc`` passed to the discriminator.
            disc_factor: Global multiplier on both adversarial terms.
            disc_weight: Multiplier applied to the adaptive weight.
            perceptual_weight: Multiplier on the perceptual loss; values
                ``<= 0`` skip the perceptual network in ``forward``.
            use_actnorm: Forwarded to the discriminator.
            disc_conditional: Whether ``forward`` expects a ``cond`` tensor
                to concatenate to the discriminator input.
            disc_ndf: ``ndf`` passed to the discriminator.
            disc_loss: ``"hinge"`` or ``"vanilla"``.
            n_classes: Codebook size used for the perplexity statistics;
                required when ``predicted_indices`` is given to ``forward``.
            perceptual_loss: Only ``"lpips"`` is implemented.
            pixel_loss: ``"l1"`` or ``"l2"``.

        Raises:
            AssertionError: If an option is outside the accepted names.
            ValueError: If ``perceptual_loss`` is an accepted name other
                than ``"lpips"`` (not implemented).
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            # "clips" and "dists" pass the assertion above but have no
            # implementation here.
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the adversarial term against the reconstruction term.

        The weight is the ratio of the gradient norms of ``nll_loss`` and
        ``g_loss`` with respect to ``last_layer``, clamped to ``[0, 1e4]``,
        detached, and scaled by ``self.discriminator_weight``.

        Args:
            nll_loss: Scalar reconstruction loss still attached to the graph.
            g_loss: Scalar generator adversarial loss attached to the graph.
            last_layer: Tensor to differentiate with respect to. When
                ``None``, ``self.last_layer[0]`` is used; this class never
                sets that attribute itself, so it must be provided
                externally in that case.

        Returns:
            A detached 0-dim tensor.
        """
        if last_layer is None:
            last_layer = self.last_layer[0]
        # retain_graph=True: the same graph is differentiated again for the
        # second loss and later by the caller's backward pass.
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        # 1e-4 guards against division by a vanishing generator gradient.
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        """Compute the generator or the discriminator loss.

        Args:
            codebook_loss: Quantization loss tensor, or ``None`` to use zero.
            inputs: Target images.
            reconstructions: Autoencoder outputs, same shape as ``inputs``.
            optimizer_idx: ``0`` for the generator objective, ``1`` for the
                discriminator objective.
            global_step: Current training step, used to gate the adversarial
                factor until ``disc_start``.
            last_layer: Tensor used for the adaptive weight (see
                ``calculate_adaptive_weight``).
            cond: Optional conditioning, concatenated to the discriminator
                input along dim 1.
            split: Prefix for the keys of the returned log dictionary.
            predicted_indices: Optional codebook indices; when given,
                perplexity and cluster usage are added to the generator log.

        Returns:
            ``(loss, log)`` where ``log`` maps ``"<split>/<name>"`` to
            detached tensors.

        Raises:
            ValueError: If ``optimizer_idx`` is neither 0 nor 1.
            RuntimeError: If the adaptive weight cannot be computed while the
                module is in training mode.
        """
        # `exists` is not defined or imported in this file; test for None
        # directly.
        if codebook_loss is None:
            codebook_loss = torch.zeros(1, device=inputs.device)

        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            # Placeholder so the log below always has a p_loss entry.
            p_loss = torch.zeros(1, device=inputs.device)

        nll_loss = torch.mean(rec_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                # Outside training (e.g. under no_grad) the gradients are
                # unavailable, so the adversarial term is dropped. During
                # training the original error is re-raised rather than
                # masked.
                if self.training:
                    raise
                d_weight = torch.tensor(0.0, device=nll_loss.device)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {f"{split}/total_loss": loss.clone().detach().mean(),
                   f"{split}/quant_loss": codebook_loss.detach().mean(),
                   f"{split}/nll_loss": nll_loss.detach().mean(),
                   f"{split}/rec_loss": rec_loss.detach().mean(),
                   f"{split}/p_loss": p_loss.detach().mean(),
                   f"{split}/d_weight": d_weight.detach(),
                   f"{split}/disc_factor": torch.tensor(disc_factor),
                   f"{split}/g_loss": g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update; inputs are detached so no
            # gradient reaches the autoencoder.
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {f"{split}/disc_loss": d_loss.clone().detach().mean(),
                   f"{split}/logits_real": logits_real.detach().mean(),
                   f"{split}/logits_fake": logits_fake.detach().mean()
                   }
            return d_loss, log

        # Previously fell through and returned None implicitly.
        raise ValueError(f"optimizer_idx must be 0 or 1, got {optimizer_idx!r}")