Skip to content

Commit

Permalink
added a learnable upscaler with a filter to remove mirored harmonics
Browse files Browse the repository at this point in the history
  • Loading branch information
AznamirWoW committed Jan 1, 2025
1 parent 79489ab commit bf2e6fd
Showing 1 changed file with 61 additions and 7 deletions.
68 changes: 61 additions & 7 deletions rvc/lib/algorithm/generators/refinegan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import numpy as np
import torch
from torch.nn.utils.parametrizations import weight_norm
Expand All @@ -6,6 +7,55 @@

from rvc.lib.algorithm.commons import get_padding

def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
even = kernel_size % 2 == 0
half_size = kernel_size // 2

delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.0:
beta = 0.1102 * (A - 8.7)
elif A >= 21.0:
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

if even:
time = torch.arange(-half_size, half_size) + 0.5
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * torch.sinc(2 * cutoff * time)
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter


class UpSample1d(torch.nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=kernel_size
)
self.register_buffer("filter", filter)

def forward(self, x):
_, C, _ = x.shape
x = torch.nn.functional.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * torch.nn.functional.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = x[..., self.pad_left : -self.pad_right] # noqa
return x

class ResBlock(torch.nn.Module):
"""
Expand Down Expand Up @@ -114,19 +164,22 @@ def __init__(
*,
channels: int,
leaky_relu_slope: float = 0.2,
use_noise_gen = True,
):
super().__init__()

self.use_noise_gen = use_noise_gen
self.weight = torch.nn.Parameter(torch.ones(channels))
# safe to use in-place as it is used on a new x+gaussian tensor
self.activation = torch.nn.LeakyReLU(leaky_relu_slope, inplace=True)

def forward(self, x: torch.Tensor):
gaussian = torch.randn_like(x) * self.weight[None, :, None]

return self.activation(x + gaussian)


if self.use_noise_gen:
gaussian = torch.randn_like(x) * self.weight[None, :, None]
return self.activation(x + gaussian)
else:
return self.activation(x)

class ParallelResBlock(torch.nn.Module):
"""
Parallel residual block that applies multiple residual blocks with different kernel sizes in parallel.
Expand Down Expand Up @@ -172,7 +225,8 @@ def __init__(
dilation=dilation,
leaky_relu_slope=leaky_relu_slope,
),
AdaIN(channels=out_channels),
# disabled a second noise inductor as one is enough
AdaIN(channels=out_channels, use_noise_gen = False),
)
for kernel_size in kernel_sizes
]
Expand Down Expand Up @@ -409,7 +463,7 @@ def __init__(
new_channels = channels // 2

self.upsample_blocks.append(
torch.nn.Upsample(scale_factor=rate, mode="linear")
UpSample1d(rate) # upsampler borrowed from BigVGAN, filters out mirrored harmonics
)

self.upsample_conv_blocks.append(
Expand Down

0 comments on commit bf2e6fd

Please sign in to comment.