fix to correct sinegen's waveform phase
AznamirWoW committed Jan 15, 2025
1 parent 9f7b7c1 commit 9f6cbb6
Showing 1 changed file with 51 additions and 86 deletions.
rvc/lib/algorithm/generators/refinegan.py (137 changes: 51 additions & 86 deletions)
@@ -1,13 +1,15 @@
 import numpy as np
 import torch
+from torch import nn
+from torch.nn import functional as F
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 from torch.utils.checkpoint import checkpoint
 
 from rvc.lib.algorithm.commons import get_padding
 
 
-class ResBlock(torch.nn.Module):
+class ResBlock(nn.Module):
     """
     Residual block with multiple dilated convolutions.
@@ -37,10 +39,10 @@ def __init__(
         self.in_channels = in_channels
         self.out_channels = out_channels
 
-        self.convs1 = torch.nn.ModuleList(
+        self.convs1 = nn.ModuleList(
             [
                 weight_norm(
-                    torch.nn.Conv1d(
+                    nn.Conv1d(
                         in_channels=in_channels if idx == 0 else out_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
@@ -54,10 +56,10 @@
         )
         self.convs1.apply(self.init_weights)
 
-        self.convs2 = torch.nn.ModuleList(
+        self.convs2 = nn.ModuleList(
             [
                 weight_norm(
-                    torch.nn.Conv1d(
+                    nn.Conv1d(
                         in_channels=out_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
@@ -74,10 +76,10 @@ def forward(self, x: torch.Tensor):
     def forward(self, x: torch.Tensor):
         for idx, (c1, c2) in enumerate(zip(self.convs1, self.convs2)):
             # new tensor
-            xt = torch.nn.functional.leaky_relu(x, self.leaky_relu_slope)
+            xt = F.leaky_relu(x, self.leaky_relu_slope)
             xt = c1(xt)
             # in-place call
-            xt = torch.nn.functional.leaky_relu_(xt, self.leaky_relu_slope)
+            xt = F.leaky_relu_(xt, self.leaky_relu_slope)
             xt = c2(xt)
 
             if idx != 0 or self.in_channels == self.out_channels:
@@ -93,12 +95,12 @@ def remove_parametrizations(self):
             remove_parametrizations(c2)
 
     def init_weights(self, m):
-        if type(m) == torch.nn.Conv1d:
+        if type(m) == nn.Conv1d:
             m.weight.data.normal_(0, 0.01)
             m.bias.data.fill_(0.0)
 
 
-class AdaIN(torch.nn.Module):
+class AdaIN(nn.Module):
     """
     Adaptive Instance Normalization layer.
@@ -117,17 +119,17 @@ def __init__(
     ):
         super().__init__()
 
-        self.weight = torch.nn.Parameter(torch.ones(channels))
+        self.weight = nn.Parameter(torch.ones(channels))
         # safe to use in-place as it is used on a new x+gaussian tensor
-        self.activation = torch.nn.LeakyReLU(leaky_relu_slope, inplace=True)
+        self.activation = nn.LeakyReLU(leaky_relu_slope, inplace=True)
 
     def forward(self, x: torch.Tensor):
         gaussian = torch.randn_like(x) * self.weight[None, :, None]
 
         return self.activation(x + gaussian)
 
 
-class ParallelResBlock(torch.nn.Module):
+class ParallelResBlock(nn.Module):
     """
     Parallel residual block that applies multiple residual blocks with different kernel sizes in parallel.
@@ -153,17 +155,17 @@ def __init__(
         self.in_channels = in_channels
         self.out_channels = out_channels
 
-        self.input_conv = torch.nn.Conv1d(
+        self.input_conv = nn.Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=7,
             stride=1,
             padding=3,
         )
 
-        self.blocks = torch.nn.ModuleList(
+        self.blocks = nn.ModuleList(
             [
-                torch.nn.Sequential(
+                nn.Sequential(
                     AdaIN(channels=out_channels),
                     ResBlock(
                         in_channels=out_channels,
@@ -190,7 +192,7 @@ def remove_parametrizations(self):
             block[1].remove_parametrizations()
 
 
-class SineGenerator(torch.nn.Module):
+class SineGenerator(nn.Module):
     """
     Definition of sine generator
@@ -220,6 +222,11 @@ def __init__(
         self.dim = self.harmonic_num + 1
         self.sampling_rate = samp_rate
         self.voiced_threshold = voiced_threshold
 
+        self.merge = nn.Sequential(
+            nn.Linear(self.dim, 1, bias=False),
+            nn.Tanh(),
+        )
+
     def _f02uv(self, f0):
         # generate uv signal
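
Note: the new merge head folds the fundamental and its overtones into a single excitation channel inside SineGenerator itself, replacing the SourceModuleHnNSF class removed further down; unlike the old l_linear it uses no bias. A minimal shape sketch (batch size, length, and harmonic count here are illustrative, not from the commit):

    import torch
    from torch import nn

    dim = 8 + 1                              # harmonic_num = 8 -> fundamental + 8 overtones
    merge = nn.Sequential(nn.Linear(dim, 1, bias=False), nn.Tanh())

    sine_bank = torch.randn(4, 16000, dim)   # (batch, samples, dim) harmonic sine bank
    excitation = merge(sine_bank)            # (batch, samples, 1) merged excitation in (-1, 1)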
@@ -231,7 +238,7 @@ def _f02sine(self, f0_values):
         """f0_values: (batchsize, length, dim)
         where dim indicates fundamental tone and overtones
         """
-        # convert to F0 in rad. The interger part n can be ignored
+        # convert to F0 in rad. The integer part n can be ignored
         # because 2 * np.pi * n doesn't affect phase
         rad_values = (f0_values / self.sampling_rate) % 1
 
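
The comment above relies on the standard NSF identity sin(2π(n + r)) = sin(2πr): only the fractional part of the per-sample frequency f0/sr contributes to the sine's phase, which is then accumulated with a cumulative sum. A hedged sketch of that computation (the full _f02sine body is collapsed in this diff, and the real implementation also adds a random initial phase per harmonic):

    import numpy as np
    import torch

    sampling_rate = 44100
    f0_values = torch.full((1, 100, 1), 440.0)            # (batch, length, dim): a steady 440 Hz tone

    rad_values = (f0_values / sampling_rate) % 1          # cycles per sample, fractional part only
    phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi   # accumulated phase in radians
    sines = torch.sin(phase)                              # phase-continuous sine wave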
@@ -267,55 +274,13 @@ def forward(self, f0):
             noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
             noise = noise_amp * torch.randn_like(sine_waves)
 
-            sine_waves = sine_waves * uv + noise * (1 - uv)
-        return sine_waves, uv, noise
-
-
-class SourceModuleHnNSF(torch.nn.Module):
-    """
-    Source Module for generating harmonic and noise signals.
-
-    This module uses a SineGenerator to produce harmonic signals based on the fundamental frequency (F0).
-
-    Args:
-        sampling_rate (int): Sampling rate of the audio.
-        harmonic_num (int, optional): Number of harmonics to generate. Defaults to 0.
-        sine_amp (float, optional): Amplitude of the sine wave. Defaults to 0.1.
-        add_noise_std (float, optional): Standard deviation of the additive noise. Defaults to 0.003.
-        voiced_threshold (int, optional): F0 threshold for voiced/unvoiced classification. Defaults to 0.
-    """
-
-    def __init__(
-        self,
-        sampling_rate,
-        harmonic_num=0,
-        sine_amp=0.1,
-        add_noise_std=0.003,
-        voiced_threshold=0,
-    ):
-        super(SourceModuleHnNSF, self).__init__()
-
-        self.sine_amp = sine_amp
-        self.noise_std = add_noise_std
-
-        # to produce sine waveforms
-        self.l_sin_gen = SineGenerator(
-            sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold
-        )
-
-        # to merge source harmonics into a single excitation
-        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
-        self.l_tanh = torch.nn.Tanh()
-
-    def forward(self, x: torch.Tensor):
-        sine_wavs, uv, _ = self.l_sin_gen(x)
-        sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
-        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
-
-        return sine_merge, None, None
-
-
-class RefineGANGenerator(torch.nn.Module):
+            sine_waves = sine_waves * uv + noise
+            # correct DC offset
+            sine_waves = sine_waves - sine_waves.mean(dim=1, keepdim=True)
+        # merge with grad
+        return self.merge(sine_waves)
+
+class RefineGANGenerator(nn.Module):
     """
     RefineGAN generator for audio synthesis.
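
Two details in the rewritten tail of SineGenerator.forward: subtracting the per-waveform mean removes the DC component that the uv-gated noise can introduce, and the merge runs outside the torch.no_grad() block (hence "merge with grad") so its Linear weights receive gradients during training. A small standalone sketch of the DC correction, with made-up numbers:

    import torch

    t = torch.linspace(0, 2 * torch.pi, 400)
    sine_waves = torch.sin(t).view(1, 400, 1) + 0.25          # a sine riding on a 0.25 DC offset
    sine_waves = sine_waves - sine_waves.mean(dim=1, keepdim=True)
    print(sine_waves.mean(dim=1))                             # ~0: the offset is gone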
@@ -344,37 +309,37 @@ def __init__(
         num_mels: int = 128,
         start_channels: int = 16,
         gin_channels: int = 256,
-        checkpointing=False,
+        checkpointing: bool = False,
     ):
         super().__init__()
 
         self.downsample_rates = downsample_rates
         self.upsample_rates = upsample_rates
         self.leaky_relu_slope = leaky_relu_slope
         self.checkpointing = checkpointing
 
-        self.f0_upsample = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
-        self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num=8)
+        self.upp = np.prod(upsample_rates)
+        self.m_source = SineGenerator(sample_rate)
 
         # expands
         self.source_conv = weight_norm(
-            torch.nn.Conv1d(
+            nn.Conv1d(
                 in_channels=1,
                 out_channels=start_channels,
                 kernel_size=7,
                 stride=1,
                 padding=3,
+                bias=False,
             )
         )
 
         channels = start_channels
-        self.downsample_blocks = torch.nn.ModuleList([])
+        self.downsample_blocks = nn.ModuleList([])
         for rate in downsample_rates:
             new_channels = channels * 2
 
             self.downsample_blocks.append(
-                torch.nn.Sequential(
-                    torch.nn.Upsample(scale_factor=1 / rate, mode="linear"),
+                nn.Sequential(
+                    nn.Upsample(scale_factor=1 / rate, mode="linear"),
                     ResBlock(
                         in_channels=channels,
                         out_channels=new_channels,
@@ -388,7 +353,7 @@ def __init__(
             channels = new_channels
 
         self.mel_conv = weight_norm(
-            torch.nn.Conv1d(
+            nn.Conv1d(
                 in_channels=num_mels,
                 out_channels=channels,
                 kernel_size=7,
@@ -398,18 +363,18 @@ def __init__(
         )
 
         if gin_channels != 0:
-            self.cond = torch.nn.Conv1d(256, channels, 1)
+            self.cond = nn.Conv1d(256, channels, 1)
 
         channels *= 2
 
-        self.upsample_blocks = torch.nn.ModuleList([])
-        self.upsample_conv_blocks = torch.nn.ModuleList([])
+        self.upsample_blocks = nn.ModuleList([])
+        self.upsample_conv_blocks = nn.ModuleList([])
 
         for rate in upsample_rates:
             new_channels = channels // 2
 
             self.upsample_blocks.append(
-                torch.nn.Upsample(scale_factor=rate, mode="linear")
+                nn.Upsample(scale_factor=rate, mode="linear")
             )
 
             self.upsample_conv_blocks.append(
@@ -425,7 +390,7 @@ def __init__(
             channels = new_channels
 
         self.conv_post = weight_norm(
-            torch.nn.Conv1d(
+            nn.Conv1d(
                 in_channels=channels,
                 out_channels=1,
                 kernel_size=7,
@@ -435,9 +400,9 @@ def __init__(
         )
 
     def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
-        f0 = self.f0_upsample(f0[:, None, :]).transpose(-1, -2)
-        har_source, _, _ = self.m_source(f0)
-        har_source = har_source.transpose(-1, -2)
-
+        f0 = F.interpolate(f0.unsqueeze(1), size=mel.shape[-1] * self.upp, mode="linear")
+        har_source = self.m_source(f0.transpose(1, 2)).transpose(1, 2)
 
         # expanding pitch source to 16 channels
         # new tensor
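
The forward pass now stretches the frame-rate f0 track to sample-rate length with a functional call instead of the removed self.f0_upsample module; since self.upp is the product of the upsample rates, mel.shape[-1] * self.upp is exactly the output waveform length. A shape sketch under assumed rates (the rates below are illustrative, not from the commit):

    import numpy as np
    import torch
    import torch.nn.functional as F

    upsample_rates = [8, 8, 2, 2]                 # assumed; upp = 256
    upp = int(np.prod(upsample_rates))

    mel_frames = 50
    f0 = 200.0 * torch.rand(1, mel_frames)        # (batch, frames) frame-rate pitch contour
    f0 = F.interpolate(f0.unsqueeze(1), size=mel_frames * upp, mode="linear")
    print(f0.shape)                               # torch.Size([1, 1, 12800])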
@@ -446,7 +411,7 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
         downs = []
         for i, block in enumerate(self.downsample_blocks):
             # in-place call
-            x = torch.nn.functional.leaky_relu_(x, self.leaky_relu_slope)
+            x = F.leaky_relu_(x, self.leaky_relu_slope)
             downs.append(x)
             if self.training and self.checkpointing:
                 x = checkpoint(block, x, use_reentrant=False)
@@ -467,7 +432,7 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
             reversed(downs),
         ):
             # in-place call
-            x = torch.nn.functional.leaky_relu_(x, self.leaky_relu_slope)
+            x = F.leaky_relu_(x, self.leaky_relu_slope)
 
             if self.training and self.checkpointing:
                 x = checkpoint(ups, x, use_reentrant=False)
@@ -478,7 +443,7 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
             x = torch.cat([x, down], dim=1)
             x = res(x)
         # in-place call
-        x = torch.nn.functional.leaky_relu_(x, self.leaky_relu_slope)
+        x = F.leaky_relu_(x, self.leaky_relu_slope)
         x = self.conv_post(x)
         # in-place call
         x = torch.tanh_(x)
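
With this change, SineGenerator returns one merged, tanh-squashed excitation instead of the old (sine_waves, uv, noise) tuple, which is why the tuple unpacking disappeared from RefineGANGenerator.forward. A hedged end-to-end sketch against the classes in this file (input values are arbitrary; the constructor call mirrors the SineGenerator(sample_rate) usage in the diff):

    import torch
    from rvc.lib.algorithm.generators.refinegan import SineGenerator

    gen = SineGenerator(44100)
    f0 = torch.full((1, 44100, 1), 220.0)    # (batch, samples, 1) per-sample pitch, fully voiced
    excitation = gen(f0)                     # (batch, samples, 1), roughly zero-mean
    print(excitation.shape, excitation.mean().item())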
