simplified pitch guidance method
added low pass filter in an attempt to suppress aliasing
AznamirWoW committed Jan 24, 2025
1 parent 7a6ad4f commit 8d6398b
94 changes: 53 additions & 41 deletions rvc/lib/algorithm/generators/refinegan.py
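
The low-pass filter this commit introduces is a fixed depthwise Conv1d: kernel size 15, groups=channels, and every weight initialized to 1/15, i.e. a per-channel 15-tap moving average applied after each linear upsample, before the pitch features are concatenated back in. A moving average is a simple FIR low-pass, which is what makes it plausible as an anti-aliasing measure. Below is a minimal standalone sketch of the same construction; the function name and tensor shapes are illustrative, not from the repository.

import torch
from torch import nn

def make_box_lowpass(channels: int, taps: int = 15) -> nn.Conv1d:
    # Depthwise conv: groups=channels keeps every channel independent, and
    # constant weights of 1/taps turn the conv into a moving average,
    # i.e. a crude FIR low-pass filter.
    lowpass = nn.Conv1d(
        channels,
        channels,
        kernel_size=taps,
        padding=taps // 2,  # "same" output length for odd tap counts
        groups=channels,
        bias=False,
    )
    lowpass.weight.data.fill_(1.0 / taps)
    return lowpass

x = torch.randn(1, 4, 256)  # (batch, channels, time)
y = make_box_lowpass(4)(x)
print(y.shape)  # torch.Size([1, 4, 256])

Note that the weights are ordinary parameters: the commit initializes the filter as a box filter but does not freeze it, so training is free to reshape it.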
@@ -1,3 +1,4 @@
+import math
 import numpy as np
 import torch
 from torch import nn
@@ -311,71 +312,89 @@ def __init__(
         start_channels: int = 16,
         gin_channels: int = 256,
         checkpointing: bool = False,
+        upsample_initial_channel = 512
     ):
         super().__init__()
-        self.downsample_rates = downsample_rates
 
         self.upsample_rates = upsample_rates
         self.leaky_relu_slope = leaky_relu_slope
         self.checkpointing = checkpointing
 
         self.upp = np.prod(upsample_rates)
         self.m_source = SineGenerator(sample_rate)
 
-        # expands
-        self.source_conv = weight_norm(
+        # expanded f0 sinegen -> match mel_conv
+        self.pre_conv = weight_norm(
             nn.Conv1d(
                 in_channels=1,
-                out_channels=start_channels,
+                out_channels=upsample_initial_channel // 2,
                 kernel_size=7,
                 stride=1,
                 padding=3,
                 bias=False,
             )
         )
 
-        channels = start_channels
-        self.downsample_blocks = nn.ModuleList([])
-        for rate in downsample_rates:
-            new_channels = channels * 2
+        stride_f0s = [
+            math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1
+            for i in range(len(upsample_rates))
+        ]
+
+        channels = upsample_initial_channel
+
+        self.downsample_blocks = nn.ModuleList([])
+        for i, u in enumerate(upsample_rates):
+            # handling odd upsampling rates
+            stride = stride_f0s[i]
+            kernel = 1 if stride == 1 else stride * 2 - stride % 2
+            padding = 0 if stride == 1 else (kernel - stride) // 2
+
+            # f0 input gets upscaled to full segment size, then downscaled back to match each upscale step
+
             self.downsample_blocks.append(
-                nn.Sequential(
-                    nn.Upsample(scale_factor=1 / rate, mode="linear"),
-                    ResBlock(
-                        in_channels=channels,
-                        out_channels=new_channels,
-                        kernel_size=7,
-                        dilation=(1, 3, 5),
-                        leaky_relu_slope=leaky_relu_slope,
-                    ),
+                nn.Conv1d(
+                    in_channels=1,
+                    out_channels=channels // 2 ** (i + 2),
+                    kernel_size=kernel,
+                    stride=stride,
+                    padding = padding
                 )
             )
 
-            channels = new_channels
-
         self.mel_conv = weight_norm(
             nn.Conv1d(
                 in_channels=num_mels,
-                out_channels=channels,
+                out_channels=channels // 2,
                 kernel_size=7,
                 stride=1,
                 padding=3,
             )
         )
 
         if gin_channels != 0:
-            self.cond = nn.Conv1d(256, channels, 1)
-
-        channels *= 2
+            self.cond = nn.Conv1d(256, channels // 2, 1)
 
         self.upsample_blocks = nn.ModuleList([])
         self.upsample_conv_blocks = nn.ModuleList([])
+        self.filters = nn.ModuleList([])
 
         for rate in upsample_rates:
             new_channels = channels // 2
 
             self.upsample_blocks.append(nn.Upsample(scale_factor=rate, mode="linear"))
 
+            low_pass = nn.Conv1d(
+                channels,
+                channels,
+                kernel_size=15,
+                padding=7,
+                groups=channels,
+                bias=False)
+
+            low_pass.weight.data.fill_(1.0 / 15)
+
+            self.filters.append(low_pass)
+
             self.upsample_conv_blocks.append(
                 ParallelResBlock(
                     in_channels=channels + channels // 4,
@@ -397,6 +416,7 @@ def __init__(
                 padding=3,
             )
         )
+
 
     def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
 
@@ -405,20 +425,8 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
         )
         har_source = self.m_source(f0.transpose(1, 2)).transpose(1, 2)
 
-        # expanding pitch source to 16 channels
-        # new tensor
-        x = self.source_conv(har_source)
-        # making a downscaled version to match upscaler stages
-        downs = []
-        for i, block in enumerate(self.downsample_blocks):
-            # in-place call
-            x = F.leaky_relu_(x, self.leaky_relu_slope)
-            downs.append(x)
-            if self.training and self.checkpointing:
-                x = checkpoint(block, x, use_reentrant=False)
-            else:
-                x = block(x)
-
+        x = self.pre_conv(har_source)
+        x = F.interpolate(x, size=mel.shape[-1], mode="linear")
         # expanding spectrogram from 192 to 256 channels
         mel = self.mel_conv(mel)
 
@@ -427,22 +435,26 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
             mel += self.cond(g)
         x = torch.cat([mel, x], dim=1)
 
-        for ups, res, down in zip(
+        for ups, res, down, flt in zip(
             self.upsample_blocks,
             self.upsample_conv_blocks,
-            reversed(downs),
+            self.downsample_blocks,
+            self.filters,
         ):
             # in-place call
             x = F.leaky_relu_(x, self.leaky_relu_slope)
+
             if self.training and self.checkpointing:
                 x = checkpoint(ups, x, use_reentrant=False)
-                x = torch.cat([x, down], dim=1)
+                x = checkpoint(flt, x, use_reentrant=False)
+                x = torch.cat([x, down(har_source)], dim=1)
                 x = checkpoint(res, x, use_reentrant=False)
             else:
                 x = ups(x)
-                x = torch.cat([x, down], dim=1)
+                x = flt(x)
+                x = torch.cat([x, down(har_source)], dim=1)
                 x = res(x)
 
         # in-place call
         x = F.leaky_relu_(x, self.leaky_relu_slope)
         x = self.conv_post(x)
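For reference, the simplified pitch-guidance path works as follows: pre_conv embeds the full-rate sine excitation and F.interpolate shrinks it to mel resolution for the initial concatenation; after that, every upsampler stage re-reads har_source through its own strided Conv1d whose stride equals the product of the upsample rates not yet applied, so each downsampled copy lands exactly on that stage's time resolution. The kernel and padding formulas keep the strided conv length-exact even when the stride is odd. A small sketch that checks this arithmetic; the upsample rates and segment length are illustrative, not taken from any particular config.

import math
import torch
from torch import nn

upsample_rates = [10, 6, 2, 2, 2]  # illustrative; real configs vary
segment = 64                       # mel frames (illustrative)
upp = math.prod(upsample_rates)    # total upsampling factor

# stride for stage i = product of the rates still to be applied after it
stride_f0s = [
    math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1
    for i in range(len(upsample_rates))
]

har_source = torch.randn(1, 1, segment * upp)  # full-rate harmonic source
length = segment
for i, (rate, stride) in enumerate(zip(upsample_rates, stride_f0s)):
    length *= rate  # time resolution after this stage's upsample
    kernel = 1 if stride == 1 else stride * 2 - stride % 2
    padding = 0 if stride == 1 else (kernel - stride) // 2
    down = nn.Conv1d(1, 4, kernel_size=kernel, stride=stride, padding=padding)
    # the strided conv maps the full-rate source onto this stage's length
    print(i, down(har_source).shape[-1] == length)  # True at every stage

This replaces the old downsample path (source_conv plus a stack of ResBlocks) with plain strided convolutions, which is the simplification the commit title refers to.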
