diff --git a/rvc/train/losses.py b/rvc/train/losses.py index 2e0d4dc9..14beaec5 100644 --- a/rvc/train/losses.py +++ b/rvc/train/losses.py @@ -9,12 +9,11 @@ def feature_loss(fmap_r, fmap_g): fmap_r (list of torch.Tensor): List of reference feature maps. fmap_g (list of torch.Tensor): List of generated feature maps. """ - loss = sum( + return 2 * sum( torch.mean(torch.abs(rl - gl)) for dr, dg in zip(fmap_r, fmap_g) for rl, gl in zip(dr, dg) ) - return loss * 2 def discriminator_loss(disc_real_outputs, disc_generated_outputs): @@ -24,18 +23,10 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs): Args: disc_real_outputs (list of torch.Tensor): List of discriminator outputs for real samples. disc_generated_outputs (list of torch.Tensor): List of discriminator outputs for generated samples. - """ - loss = 0 - r_losses = [] - g_losses = [] - for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - r_loss = torch.mean((1 - dr.float()) ** 2) - g_loss = torch.mean(dg.float() ** 2) - - r_losses.append(r_loss.item()) - g_losses.append(g_loss.item()) - loss += r_loss + g_loss - + """ + r_losses = [(1 - dr).pow(2).mean() for dr in disc_real_outputs] + g_losses = [dg.pow(2).mean() for dg in disc_generated_outputs] + loss = sum(r_losses) + sum(g_losses) return loss, r_losses, g_losses @@ -45,15 +36,51 @@ def generator_loss(disc_outputs): Args: disc_outputs (list of torch.Tensor): List of discriminator outputs for generated samples. + """ + gen_losses = [(1 - dg).pow(2).mean() for dg in disc_outputs] + loss = sum(gen_losses) + return loss, gen_losses + + +def discriminator_loss_scaled(disc_real, disc_fake, scale=1.0): """ - gen_losses = [] - loss = 0 - for dg in disc_outputs: - l = torch.mean((1 - dg.float()) ** 2) - gen_losses.append(l.item()) - loss += l + Compute the scaled discriminator loss for real and generated outputs. - return loss, gen_losses + Args: + disc_real (list of torch.Tensor): List of discriminator outputs for real samples. + disc_fake (list of torch.Tensor): List of discriminator outputs for generated samples. + scale (float, optional): Scaling factor applied to losses beyond the midpoint. Default is 1.0. + """ + midpoint = len(disc_real) // 2 + losses = [] + for i, (d_real, d_fake) in enumerate(zip(disc_real, disc_fake)): + real_loss = (1 - d_real).pow(2).mean() + fake_loss = d_fake.pow(2).mean() + total_loss = real_loss + fake_loss + if i >= midpoint: + total_loss *= scale + losses.append(total_loss) + loss = sum(losses) + return loss, None, None + + +def generator_loss_scaled(disc_outputs, scale=1.0): + """ + Compute the scaled generator loss based on discriminator outputs. + + Args: + disc_outputs (list of torch.Tensor): List of discriminator outputs for generated samples. + scale (float, optional): Scaling factor applied to losses beyond the midpoint. Default is 1.0. + """ + midpoint = len(disc_outputs) // 2 + losses = [] + for i, d_fake in enumerate(disc_outputs): + loss_value = (1 - d_fake).pow(2).mean() + if i >= midpoint: + loss_value *= scale + losses.append(loss_value) + loss = sum(losses) + return loss, None, None def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): @@ -67,10 +94,7 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): logs_p (torch.Tensor): Log variance of p [b, h, t_t]. z_mask (torch.Tensor): Mask for the latent variables [b, h, t_t]. """ - kl = logs_p - logs_q - 0.5 - kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) - - kl = torch.sum(kl * z_mask) - loss = kl / torch.sum(z_mask) - + kl = logs_p - logs_q - 0.5 + 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2 * logs_p) + kl = (kl * z_mask).sum() + loss = kl / z_mask.sum() return loss