From 9f6cbb6b307e925c442577f51979b003e0e8037d Mon Sep 17 00:00:00 2001
From: AznamirWoW <101997116+AznamirWoW@users.noreply.github.com>
Date: Tue, 14 Jan 2025 19:46:10 -0500
Subject: [PATCH] fix to correct sinegen's waveform phase

---
 rvc/lib/algorithm/generators/refinegan.py | 137 ++++++++--------------
 1 file changed, 51 insertions(+), 86 deletions(-)

diff --git a/rvc/lib/algorithm/generators/refinegan.py b/rvc/lib/algorithm/generators/refinegan.py
index a25df4ee..b27db8ed 100644
--- a/rvc/lib/algorithm/generators/refinegan.py
+++ b/rvc/lib/algorithm/generators/refinegan.py
@@ -1,5 +1,7 @@
 import numpy as np
 import torch
+from torch import nn
+from torch.nn import functional as F
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 from torch.utils.checkpoint import checkpoint
@@ -7,7 +9,7 @@
 from rvc.lib.algorithm.commons import get_padding
 
 
-class ResBlock(torch.nn.Module):
+class ResBlock(nn.Module):
     """
     Residual block with multiple dilated convolutions.
 
@@ -37,10 +39,10 @@ def __init__(
         self.in_channels = in_channels
         self.out_channels = out_channels
 
-        self.convs1 = torch.nn.ModuleList(
+        self.convs1 = nn.ModuleList(
             [
                 weight_norm(
-                    torch.nn.Conv1d(
+                    nn.Conv1d(
                         in_channels=in_channels if idx == 0 else out_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
@@ -54,10 +56,10 @@ def __init__(
         )
         self.convs1.apply(self.init_weights)
 
-        self.convs2 = torch.nn.ModuleList(
+        self.convs2 = nn.ModuleList(
             [
                 weight_norm(
-                    torch.nn.Conv1d(
+                    nn.Conv1d(
                         in_channels=out_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
@@ -74,10 +76,10 @@ def __init__(
     def forward(self, x: torch.Tensor):
         for idx, (c1, c2) in enumerate(zip(self.convs1, self.convs2)):
             # new tensor
-            xt = torch.nn.functional.leaky_relu(x, self.leaky_relu_slope)
+            xt = F.leaky_relu(x, self.leaky_relu_slope)
             xt = c1(xt)
             # in-place call
-            xt = torch.nn.functional.leaky_relu_(xt, self.leaky_relu_slope)
+            xt = F.leaky_relu_(xt, self.leaky_relu_slope)
             xt = c2(xt)
 
             if idx != 0 or self.in_channels == self.out_channels:
@@ -93,12 +95,12 @@ def remove_parametrizations(self):
             remove_parametrizations(c2)
 
     def init_weights(self, m):
-        if type(m) == torch.nn.Conv1d:
+        if type(m) == nn.Conv1d:
             m.weight.data.normal_(0, 0.01)
             m.bias.data.fill_(0.0)
 
 
-class AdaIN(torch.nn.Module):
+class AdaIN(nn.Module):
     """
     Adaptive Instance Normalization layer.
 
@@ -117,9 +119,9 @@ def __init__(
     ):
         super().__init__()
 
-        self.weight = torch.nn.Parameter(torch.ones(channels))
+        self.weight = nn.Parameter(torch.ones(channels))
         # safe to use in-place as it is used on a new x+gaussian tensor
-        self.activation = torch.nn.LeakyReLU(leaky_relu_slope, inplace=True)
+        self.activation = nn.LeakyReLU(leaky_relu_slope, inplace=True)
 
     def forward(self, x: torch.Tensor):
         gaussian = torch.randn_like(x) * self.weight[None, :, None]
@@ -127,7 +129,7 @@ def forward(self, x: torch.Tensor):
 
         return self.activation(x + gaussian)
 
 
-class ParallelResBlock(torch.nn.Module):
+class ParallelResBlock(nn.Module):
     """
     Parallel residual block that applies multiple residual blocks with different kernel sizes in parallel.
@@ -153,7 +155,7 @@ def __init__(
         self.in_channels = in_channels
         self.out_channels = out_channels
 
-        self.input_conv = torch.nn.Conv1d(
+        self.input_conv = nn.Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=7,
@@ -161,9 +163,9 @@ def __init__(
             padding=3,
         )
 
-        self.blocks = torch.nn.ModuleList(
+        self.blocks = nn.ModuleList(
             [
-                torch.nn.Sequential(
+                nn.Sequential(
                     AdaIN(channels=out_channels),
                     ResBlock(
                         in_channels=out_channels,
@@ -190,7 +192,7 @@ def remove_parametrizations(self):
             block[1].remove_parametrizations()
 
 
-class SineGenerator(torch.nn.Module):
+class SineGenerator(nn.Module):
     """
     Definition of sine generator
 
@@ -220,6 +222,11 @@ def __init__(
         self.dim = self.harmonic_num + 1
         self.sampling_rate = samp_rate
         self.voiced_threshold = voiced_threshold
+
+        self.merge = nn.Sequential(
+            nn.Linear(self.dim, 1, bias=False),
+            nn.Tanh(),
+        )
 
     def _f02uv(self, f0):
         # generate uv signal
@@ -231,7 +238,7 @@ def _f02sine(self, f0_values):
         """f0_values: (batchsize, length, dim)
         where dim indicates fundamental tone and overtones
         """
-        # convert to F0 in rad. The interger part n can be ignored
+        # convert to F0 in rad. The integer part n can be ignored
         # because 2 * np.pi * n doesn't affect phase
         rad_values = (f0_values / self.sampling_rate) % 1
 
@@ -267,55 +274,13 @@ def forward(self, f0):
             noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
             noise = noise_amp * torch.randn_like(sine_waves)
 
-            sine_waves = sine_waves * uv + noise * (1 - uv)
-        return sine_waves, uv, noise
-
-
-class SourceModuleHnNSF(torch.nn.Module):
-    """
-    Source Module for generating harmonic and noise signals.
-
-    This module uses a SineGenerator to produce harmonic signals based on the fundamental frequency (F0).
-
-    Args:
-        sampling_rate (int): Sampling rate of the audio.
-        harmonic_num (int, optional): Number of harmonics to generate. Defaults to 0.
-        sine_amp (float, optional): Amplitude of the sine wave. Defaults to 0.1.
-        add_noise_std (float, optional): Standard deviation of the additive noise. Defaults to 0.003.
-        voiced_threshold (int, optional): F0 threshold for voiced/unvoiced classification. Defaults to 0.
-    """
-
-    def __init__(
-        self,
-        sampling_rate,
-        harmonic_num=0,
-        sine_amp=0.1,
-        add_noise_std=0.003,
-        voiced_threshold=0,
-    ):
-        super(SourceModuleHnNSF, self).__init__()
-
-        self.sine_amp = sine_amp
-        self.noise_std = add_noise_std
-
-        # to produce sine waveforms
-        self.l_sin_gen = SineGenerator(
-            sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold
-        )
-
-        # to merge source harmonics into a single excitation
-        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
-        self.l_tanh = torch.nn.Tanh()
-
-    def forward(self, x: torch.Tensor):
-        sine_wavs, uv, _ = self.l_sin_gen(x)
-        sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
-        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
-
-        return sine_merge, None, None
-
-
-class RefineGANGenerator(torch.nn.Module):
+            sine_waves = sine_waves * uv + noise
+            # correct DC offset
+            sine_waves = sine_waves - sine_waves.mean(dim=1, keepdim=True)
+        # merge with grad
+        return self.merge(sine_waves)
+
+class RefineGANGenerator(nn.Module):
     """
     RefineGAN generator for audio synthesis.
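Reviewer note (not part of the diff): with SourceModuleHnNSF removed, the Linear + Tanh
merge of the fundamental and its overtones now lives inside SineGenerator itself (the new
self.merge head), and each harmonic channel is re-centered along time ("correct DC offset")
before merging. A minimal usage sketch of the refactored module, assuming the repo is on the
import path and passing harmonic_num=8 explicitly (the value the removed SourceModuleHnNSF
call used):

    import torch
    from rvc.lib.algorithm.generators.refinegan import SineGenerator

    sine_gen = SineGenerator(48000, harmonic_num=8)  # samp_rate, 8 overtones above the fundamental
    f0 = torch.full((1, 400, 1), 220.0)              # (batch, frames, 1): a steady 220 Hz contour
    excitation = sine_gen(f0)                        # (batch, frames, 1): merged excitation signal
    print(excitation.shape)                          # torch.Size([1, 400, 1])

Callers now receive a single excitation channel directly instead of the old
(sine_waves, uv, noise) triple.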
@@ -344,37 +309,37 @@ def __init__(
         num_mels: int = 128,
         start_channels: int = 16,
         gin_channels: int = 256,
-        checkpointing=False,
+        checkpointing: bool = False,
     ):
         super().__init__()
 
-        self.downsample_rates = downsample_rates
         self.upsample_rates = upsample_rates
         self.leaky_relu_slope = leaky_relu_slope
         self.checkpointing = checkpointing
 
-        self.f0_upsample = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
-        self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num=8)
+        self.upp = np.prod(upsample_rates)
+        self.m_source = SineGenerator(sample_rate)
 
         # expands
         self.source_conv = weight_norm(
-            torch.nn.Conv1d(
+            nn.Conv1d(
                 in_channels=1,
                 out_channels=start_channels,
                 kernel_size=7,
                 stride=1,
                 padding=3,
+                bias=False,
             )
         )
 
         channels = start_channels
 
-        self.downsample_blocks = torch.nn.ModuleList([])
+        self.downsample_blocks = nn.ModuleList([])
         for rate in downsample_rates:
             new_channels = channels * 2
 
             self.downsample_blocks.append(
-                torch.nn.Sequential(
-                    torch.nn.Upsample(scale_factor=1 / rate, mode="linear"),
+                nn.Sequential(
+                    nn.Upsample(scale_factor=1 / rate, mode="linear"),
                     ResBlock(
                         in_channels=channels,
                         out_channels=new_channels,
@@ -388,7 +353,7 @@ def __init__(
             channels = new_channels
 
         self.mel_conv = weight_norm(
-            torch.nn.Conv1d(
+            nn.Conv1d(
                 in_channels=num_mels,
                 out_channels=channels,
                 kernel_size=7,
@@ -398,18 +363,18 @@ def __init__(
         )
 
         if gin_channels != 0:
-            self.cond = torch.nn.Conv1d(256, channels, 1)
+            self.cond = nn.Conv1d(256, channels, 1)
 
         channels *= 2
 
-        self.upsample_blocks = torch.nn.ModuleList([])
-        self.upsample_conv_blocks = torch.nn.ModuleList([])
+        self.upsample_blocks = nn.ModuleList([])
+        self.upsample_conv_blocks = nn.ModuleList([])
 
         for rate in upsample_rates:
            new_channels = channels // 2
 
             self.upsample_blocks.append(
-                torch.nn.Upsample(scale_factor=rate, mode="linear")
+                nn.Upsample(scale_factor=rate, mode="linear")
             )
 
             self.upsample_conv_blocks.append(
@@ -425,7 +390,7 @@ def __init__(
             channels = new_channels
 
         self.conv_post = weight_norm(
-            torch.nn.Conv1d(
+            nn.Conv1d(
                 in_channels=channels,
                 out_channels=1,
                 kernel_size=7,
@@ -435,9 +400,9 @@ def __init__(
         )
 
     def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
-        f0 = self.f0_upsample(f0[:, None, :]).transpose(-1, -2)
-        har_source, _, _ = self.m_source(f0)
-        har_source = har_source.transpose(-1, -2)
+
+        f0 = F.interpolate(f0.unsqueeze(1), size=mel.shape[-1] * self.upp, mode="linear")
+        har_source = self.m_source(f0.transpose(1, 2)).transpose(1, 2)
 
         # expanding pitch source to 16 channels
         # new tensor
@@ -446,7 +411,7 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
         downs = []
         for i, block in enumerate(self.downsample_blocks):
             # in-place call
-            x = torch.nn.functional.leaky_relu_(x, self.leaky_relu_slope)
+            x = F.leaky_relu_(x, self.leaky_relu_slope)
             downs.append(x)
             if self.training and self.checkpointing:
                 x = checkpoint(block, x, use_reentrant=False)
@@ -467,7 +432,7 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
             reversed(downs),
         ):
             # in-place call
-            x = torch.nn.functional.leaky_relu_(x, self.leaky_relu_slope)
+            x = F.leaky_relu_(x, self.leaky_relu_slope)
             if self.training and self.checkpointing:
                 x = checkpoint(ups, x, use_reentrant=False)
 
@@ -478,7 +443,7 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
             x = torch.cat([x, down], dim=1)
             x = res(x)
             # in-place call
-            x = torch.nn.functional.leaky_relu_(x, self.leaky_relu_slope)
+            x = F.leaky_relu_(x, self.leaky_relu_slope)
         x = self.conv_post(x)
         # in-place call
         x = torch.tanh_(x)
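Reviewer note (not part of the diff): in the rewired forward pass the F0 contour is stretched
with F.interpolate to exactly mel.shape[-1] * prod(upsample_rates) samples (self.upp) and fed
to the SineGenerator, whose single-channel output goes straight into source_conv. A standalone
shape walk-through of those two lines; upsample_rates=(8, 8, 2, 2) is only an assumed example
value, since this path depends solely on their product:

    import numpy as np
    import torch
    from torch.nn import functional as F
    from rvc.lib.algorithm.generators.refinegan import SineGenerator

    upp = int(np.prod((8, 8, 2, 2)))         # audio samples generated per mel frame
    mel_frames = 50
    m_source = SineGenerator(44100)

    f0 = torch.full((1, mel_frames), 220.0)  # (batch, mel_frames) in Hz, as forward() receives it
    f0 = F.interpolate(f0.unsqueeze(1), size=mel_frames * upp, mode="linear")  # (1, 1, T)
    har_source = m_source(f0.transpose(1, 2)).transpose(1, 2)                  # (1, 1, T)
    print(har_source.shape)                  # torch.Size([1, 1, 12800])

Tying the source length to mel.shape[-1], rather than upsampling f0 by a fixed factor of its
own length, keeps the excitation aligned with the mel branch even if the two inputs disagree
slightly in length.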