some algorithm cleanup
blaisewf committed Dec 3, 2024
1 parent e8482f0 commit 2a763e1
Showing 4 changed files with 75 additions and 100 deletions.
13 changes: 10 additions & 3 deletions rvc/lib/algorithm/encoders.py
@@ -71,6 +71,7 @@ def __init__(
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask

for i in range(self.n_layers):
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
@@ -79,8 +80,8 @@ def forward(self, x, x_mask):
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x

return x * x_mask


class TextEncoder(torch.nn.Module):
@@ -196,11 +197,17 @@ def forward(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
):
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)

stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask

logs_exp = torch.exp(logs)
z = m + torch.randn_like(m) * logs_exp
z = z * x_mask

return z, m, logs, x_mask

def remove_weight_norm(self):
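
A note on the sampling refactor in the last hunk above: the math is unchanged, it is still the reparameterization trick z = m + eps * exp(logs) followed by masking, only split into explicit steps. A minimal standalone sketch of that step (the helper name and signature are illustrative, not part of the repository):

    import torch

    def sample_latent(m: torch.Tensor, logs: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
        # Reparameterization trick: draw eps ~ N(0, I), then shift and scale it.
        logs_exp = torch.exp(logs)               # scale, predicted in log space
        z = m + torch.randn_like(m) * logs_exp   # sample from N(m, exp(logs)^2)
        return z * x_mask                        # zero out padded frames
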
20 changes: 9 additions & 11 deletions rvc/lib/algorithm/nsf.py
@@ -5,7 +5,7 @@
from typing import Optional

from rvc.lib.algorithm.generators import SineGenerator
from rvc.lib.algorithm.residuals import LRELU_SLOPE, ResBlock1, ResBlock2
from rvc.lib.algorithm.residuals import LRELU_SLOPE, ResBlock
from rvc.lib.algorithm.commons import init_weights


@@ -92,7 +92,6 @@ def __init__(
self.conv_pre = torch.nn.Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
resblock_cls = ResBlock1 if resblock == "1" else ResBlock2

self.ups = torch.nn.ModuleList()
self.noise_convs = torch.nn.ModuleList()
@@ -131,7 +130,7 @@ def __init__(

self.resblocks = torch.nn.ModuleList(
[
resblock_cls(channels[i], k, d)
ResBlock(channels[i], k, d)
for i in range(len(self.ups))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)
]
@@ -149,27 +148,26 @@ def __init__(
def forward(self, x, f0, g: Optional[torch.Tensor] = None):
har_source, _, _ = self.m_source(f0, self.upp)
har_source = har_source.transpose(1, 2)

x = self.conv_pre(x)

if g is not None:
x = x + self.cond(g)
x += self.cond(g)

for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
x = ups(x)
x = x + noise_convs(har_source)
x += noise_convs(har_source)

xs = sum(
[
resblock(x)
for j, resblock in enumerate(self.resblocks)
if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)
]
self.resblocks[j](x)
for j in range(i * self.num_kernels, (i + 1) * self.num_kernels)
)
x = xs / self.num_kernels
x = xs / self.num_kernels

x = torch.nn.functional.leaky_relu(x)
x = torch.tanh(self.conv_post(x))

return x

def remove_weight_norm(self):
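
The generator change above drops the ResBlock1/ResBlock2 split in favour of a single ResBlock class and replaces the filtered list comprehension with direct indexing into the kernel group of the current upsampling stage. Roughly, each stage now does the following (illustrative helper, not repository code):

    import torch

    def fuse_stage_resblocks(resblocks, x: torch.Tensor, i: int, num_kernels: int) -> torch.Tensor:
        # Sum the outputs of the num_kernels residual blocks assigned to stage i,
        # then average them, as in the diff above.
        xs = sum(
            resblocks[j](x)
            for j in range(i * num_kernels, (i + 1) * num_kernels)
        )
        return xs / num_kernels
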
101 changes: 40 additions & 61 deletions rvc/lib/algorithm/synthesizers.py
@@ -1,6 +1,5 @@
import torch
from typing import Optional

from rvc.lib.algorithm.nsf import GeneratorNSF
from rvc.lib.algorithm.generators import Generator
from rvc.lib.algorithm.commons import slice_segments, rand_slice_segments
@@ -58,26 +57,11 @@ def __init__(
sr,
use_f0,
text_enc_hidden_dim=768,
**kwargs
**kwargs,
):
super(Synthesizer, self).__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = float(p_dropout)
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
super().__init__()
self.segment_size = segment_size
self.gin_channels = gin_channels
self.spk_embed_dim = spk_embed_dim
self.use_f0 = use_f0

self.enc_p = TextEncoder(
@@ -87,7 +71,7 @@ def __init__(
n_heads,
n_layers,
kernel_size,
float(p_dropout),
p_dropout,
text_enc_hidden_dim,
f0=use_f0,
)
@@ -127,47 +111,38 @@ def __init__(
gin_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
inter_channels,
hidden_channels,
5,
1,
3,
gin_channels=gin_channels,
)
self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)
self.emb_g = torch.nn.Embedding(spk_embed_dim, gin_channels)

def _remove_weight_norm_from(self, module):
"""Utility to remove weight normalization from a module."""
for hook in module._forward_pre_hooks.values():
if getattr(hook, "__class__", None).__name__ == "WeightNorm":
torch.nn.utils.remove_weight_norm(module)

def remove_weight_norm(self):
"""Removes weight normalization from the model."""
self.dec.remove_weight_norm()
self.flow.remove_weight_norm()
self.enc_q.remove_weight_norm()
for module in [self.dec, self.flow, self.enc_q]:
self._remove_weight_norm_from(module)

def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
self.remove_weight_norm()
return self

@torch.jit.ignore
def forward(
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
pitch: Optional[torch.Tensor] = None,
pitchf: Optional[torch.Tensor] = None,
y: torch.Tensor = None,
y_lengths: torch.Tensor = None,
y: Optional[torch.Tensor] = None,
y_lengths: Optional[torch.Tensor] = None,
ds: Optional[torch.Tensor] = None,
):
"""
@@ -180,22 +155,25 @@ def forward(
pitchf (torch.Tensor, optional): Fine-grained pitch sequence.
y (torch.Tensor, optional): Target spectrogram.
y_lengths (torch.Tensor, optional): Lengths of the target spectrograms.
ds (torch.Tensor, optional): Speaker embedding. Defaults to None.
ds (torch.Tensor, optional): Speaker embedding.
"""
g = self.emb_g(ds).unsqueeze(-1)
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)

if y is not None:
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
if self.use_f0:

if self.use_f0 and pitchf is not None:
pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2)
o = self.dec(z_slice, pitchf, g=g)
else:
o = self.dec(z_slice, g=g)

return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
else:
return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

@torch.jit.export
def infer(
@@ -216,22 +194,23 @@ def infer(
pitch (torch.Tensor, optional): Pitch sequence.
nsff0 (torch.Tensor, optional): Fine-grained pitch sequence.
sid (torch.Tensor): Speaker embedding.
rate (torch.Tensor, optional): Rate for time-stretching. Defaults to None.
rate (torch.Tensor, optional): Rate for time-stretching.
"""
g = self.emb_g(sid).unsqueeze(-1)
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask

if rate is not None:
assert isinstance(rate, torch.Tensor)
head = int(z_p.shape[2] * (1.0 - rate.item()))
z_p = z_p[:, :, head:]
x_mask = x_mask[:, :, head:]
if self.use_f0:
z_p, x_mask = z_p[:, :, head:], x_mask[:, :, head:]
if self.use_f0 and nsff0 is not None:
nsff0 = nsff0[:, head:]
if self.use_f0:
z = self.flow(z_p, x_mask, g=g, reverse=True)
o = self.dec(z * x_mask, nsff0, g=g)
else:
z = self.flow(z_p, x_mask, g=g, reverse=True)
o = self.dec(z * x_mask, g=g)

z = self.flow(z_p, x_mask, g=g, reverse=True)
o = (
self.dec(z * x_mask, nsff0, g=g)
if self.use_f0
else self.dec(z * x_mask, g=g)
)

return o, x_mask, (z, z_p, m_p, logs_p)
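
Two details of infer above are easy to miss. The 0.66666 factor scales exp(logs_p) and acts as a fixed sampling temperature on the prior. The optional rate tensor trims a leading fraction of frames so that only the tail of the utterance is decoded; the prior sample, its mask, and the fine pitch track are all cut at the same offset. A rough sketch of that trimming step (helper name is illustrative; the repository does it inline):

    import torch

    def trim_to_tail(z_p: torch.Tensor, x_mask: torch.Tensor, rate: torch.Tensor):
        # Keep only the last `rate` fraction of time frames.
        head = int(z_p.shape[2] * (1.0 - rate.item()))   # frames dropped from the front
        return z_p[:, :, head:], x_mask[:, :, head:]
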
41 changes: 16 additions & 25 deletions rvc/train/losses.py
@@ -1,6 +1,5 @@
import torch


def feature_loss(fmap_r, fmap_g):
"""
Compute the feature loss between reference and generated feature maps.
@@ -9,13 +8,11 @@ def feature_loss(fmap_r, fmap_g):
fmap_r (list of torch.Tensor): List of reference feature maps.
fmap_g (list of torch.Tensor): List of generated feature maps.
"""
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))

loss = sum(
torch.mean(torch.abs(rl - gl))
for dr, dg in zip(fmap_r, fmap_g)
for rl, gl in zip(dr, dg)
)
return loss * 2


@@ -31,13 +28,12 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
dr = dr.float()
dg = dg.float()
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
r_loss = torch.mean((1 - dr.float()) ** 2)
g_loss = torch.mean(dg.float() ** 2)

r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
loss += r_loss + g_loss

return loss, r_losses, g_losses

@@ -49,12 +45,11 @@ def generator_loss(disc_outputs):
Args:
disc_outputs (list of torch.Tensor): List of discriminator outputs for generated samples.
"""
loss = 0
gen_losses = []
loss = 0
for dg in disc_outputs:
dg = dg.float()
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
l = torch.mean((1 - dg.float()) ** 2)
gen_losses.append(l.item())
loss += l

return loss, gen_losses
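
For context (not part of the commit): the two functions above are the least-squares GAN objectives. In the notation of the code,

    L_D = mean((1 - D(x))^2) + mean(D(G(z))^2)
    L_G = mean((1 - D(G(z)))^2)

The refactor only inlines the .float() casts and stores the per-output terms as Python scalars via .item().
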
@@ -71,14 +66,10 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
logs_p (torch.Tensor): Log variance of p [b, h, t_t].
z_mask (torch.Tensor): Mask for the latent variables [b, h, t_t].
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()

kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)

kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l
loss = kl / torch.sum(z_mask)

return loss
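
Reading the rewritten kl_loss: it is the masked mean of the per-element KL term between the flow-mapped posterior sample and the prior N(m_p, exp(logs_p)^2),

    loss = sum((logs_p - logs_q - 0.5 + (z_p - m_p)^2 / (2 * exp(2 * logs_p))) * z_mask) / sum(z_mask)

identical to the removed version except that the explicit .float() casts are dropped and the result variable is renamed.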
