diff --git a/wetts/vits/data_utils.py b/wetts/vits/data_utils.py index 1037fbb..a400ecc 100644 --- a/wetts/vits/data_utils.py +++ b/wetts/vits/data_utils.py @@ -5,8 +5,8 @@ import torchaudio import torch.utils.data -from mel_processing import spectrogram_torch -from utils import load_filepaths_and_text +from utils.mel_processing import spectrogram_torch +from utils.task import load_filepaths_and_text class TextAudioSpeakerLoader(torch.utils.data.Dataset): diff --git a/wetts/vits/export_onnx.py b/wetts/vits/export_onnx.py index 110e081..2e9af11 100644 --- a/wetts/vits/export_onnx.py +++ b/wetts/vits/export_onnx.py @@ -17,8 +17,8 @@ import torch -from models import SynthesizerTrn -import utils +from model.models import SynthesizerTrn +from utils import task def get_args(): @@ -43,7 +43,7 @@ def main(): args = get_args() os.environ["CUDA_VISIBLE_DEVICES"] = "0" - hps = utils.get_hparams_from_file(args.cfg) + hps = task.get_hparams_from_file(args.cfg) hps['model']['is_onnx'] = True phone_num = len(open(args.phone_table).readlines()) @@ -56,7 +56,7 @@ def main(): n_speakers=num_speakers, **hps.model ) - utils.load_checkpoint(args.checkpoint, net_g, None) + task.load_checkpoint(args.checkpoint, net_g, None) net_g.flow.remove_weight_norm() net_g.dec.remove_weight_norm() net_g.forward = net_g.export_forward diff --git a/wetts/vits/inference.py b/wetts/vits/inference.py index c7d44fc..c7e6184 100644 --- a/wetts/vits/inference.py +++ b/wetts/vits/inference.py @@ -21,8 +21,8 @@ from scipy.io import wavfile import torch -from models import SynthesizerTrn -import utils +from model.models import SynthesizerTrn +from utils import task def get_args(): @@ -59,7 +59,7 @@ def main(): arr = line.strip().split() assert len(arr) == 2 speaker_dict[arr[0]] = int(arr[1]) - hps = utils.get_hparams_from_file(args.cfg) + hps = task.get_hparams_from_file(args.cfg) net_g = SynthesizerTrn( len(phone_dict), @@ -71,7 +71,7 @@ def main(): net_g = net_g.to(device) net_g.eval() - utils.load_checkpoint(args.checkpoint, net_g, None) + task.load_checkpoint(args.checkpoint, net_g, None) for line in open(args.test_file): audio_path, speaker, text = line.strip().split("|") diff --git a/wetts/vits/inference_onnx.py b/wetts/vits/inference_onnx.py index 10167ec..9f786a2 100644 --- a/wetts/vits/inference_onnx.py +++ b/wetts/vits/inference_onnx.py @@ -19,7 +19,7 @@ from scipy.io import wavfile import torch -import utils +from utils import task def to_numpy(tensor): @@ -61,7 +61,7 @@ def main(): arr = line.strip().split() assert len(arr) == 2 speaker_dict[arr[0]] = int(arr[1]) - hps = utils.get_hparams_from_file(args.cfg) + hps = task.get_hparams_from_file(args.cfg) ort_sess = ort.InferenceSession(args.onnx_model, providers=[args.providers]) scales = torch.FloatTensor([0.667, 1.0, 0.8]) diff --git a/wetts/vits/attentions.py b/wetts/vits/model/attentions.py similarity index 99% rename from wetts/vits/attentions.py rename to wetts/vits/model/attentions.py index 68ca0b4..ecd9ca0 100644 --- a/wetts/vits/attentions.py +++ b/wetts/vits/model/attentions.py @@ -4,8 +4,8 @@ from torch import nn from torch.nn import functional as F -import commons -from modules import LayerNorm +from model.normalization import LayerNorm +from utils import commons class Encoder(nn.Module): diff --git a/wetts/vits/model/decoder.py b/wetts/vits/model/decoder.py new file mode 100644 index 0000000..7b6a1a2 --- /dev/null +++ b/wetts/vits/model/decoder.py @@ -0,0 +1,311 @@ +import torch +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn 
import functional as F +from torch.nn.utils import remove_weight_norm +from torch.nn.utils.parametrizations import weight_norm +from torchaudio.transforms import InverseSpectrogram + +from model.modules import LRELU_SLOPE +from model.normalization import LayerNorm +from utils.commons import init_weights, get_padding +from utils.stft import OnnxSTFT + + +class Generator(nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * 
x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class ConvNeXtLayer(nn.Module): + def __init__(self, channels, h_channels, scale): + super().__init__() + self.dw_conv = nn.Conv1d( + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + ) + self.norm = LayerNorm(channels) + self.pw_conv1 = nn.Conv1d(channels, h_channels, 1) + self.pw_conv2 = nn.Conv1d(h_channels, channels, 1) + self.scale = nn.Parameter( + torch.full(size=(1, channels, 1), fill_value=scale), requires_grad=True + ) + + def forward(self, x): + res = x + x = self.dw_conv(x) + x = self.norm(x) + x = self.pw_conv1(x) + x = F.gelu(x) + x = self.pw_conv2(x) + x = self.scale * x + x = res + x + return x + + +class VocosGenerator(nn.Module): + def __init__( + self, + in_channels, + channels, + h_channels, + out_channels, + num_layers, + istft_config, + gin_channels, + is_onnx=False + ): + super().__init__() + + self.pad = nn.ReflectionPad1d([1, 0]) + self.in_conv = nn.Conv1d(in_channels, channels, kernel_size=1, padding=0) + self.cond = Conv1d(gin_channels, channels, 1) + self.norm_pre = LayerNorm(channels) + scale = 1 / num_layers + self.layers = nn.ModuleList( + [ConvNeXtLayer(channels, h_channels, scale) for _ in range(num_layers)] + ) + self.norm_post = LayerNorm(channels) + self.out_conv = nn.Conv1d(channels, out_channels, kernel_size=1) + self.is_onnx = is_onnx + + if self.is_onnx: + self.stft = OnnxSTFT(filter_length=istft_config['n_fft'], + hop_length=istft_config['hop_length'], + win_length=istft_config['win_length']) + else: + self.istft = InverseSpectrogram(**istft_config) + + def forward(self, x, g=None): + x = self.pad(x) + x = self.in_conv(x) + self.cond(g) + x = self.norm_pre(x) + for layer in self.layers: + x = layer(x) + x = self.norm_post(x) + x = self.out_conv(x) + mag, phase = x.chunk(2, dim=1) + mag = mag.exp().clamp_max(max=1e2) + if self.is_onnx: + o = self.stft.inverse(mag, phase).to(x.device) + else: + s = mag * (phase.cos() + 1j * phase.sin()) + o = self.istft(s).unsqueeze(1) + return o + + def remove_weight_norm(self): + pass diff --git a/wetts/vits/model/discriminators.py b/wetts/vits/model/discriminators.py new file mode 100644 index 0000000..38fd395 --- /dev/null +++ b/wetts/vits/model/discriminators.py @@ -0,0 +1,145 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn import Conv1d, Conv2d +from torch.nn.utils import spectral_norm +from torch.nn.utils.parametrizations import weight_norm + +from model.modules import LRELU_SLOPE +from utils.commons import get_padding + + 
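+# HiFi-GAN style discriminators: DiscriminatorP folds the waveform into a
+# 2D [t / period, period] grid to expose artifacts at one fixed period,
+# DiscriminatorS scans the raw 1D signal, and MultiPeriodDiscriminator
+# combines one DiscriminatorS with DiscriminatorP at periods 2, 3, 5, 7, 11.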
+class DiscriminatorP(nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    Conv2d(
+                        1,
+                        32,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        32,
+                        128,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        128,
+                        512,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        512,
+                        1024,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        1024,
+                        1024,
+                        (kernel_size, 1),
+                        1,
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class DiscriminatorS(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(MultiPeriodDiscriminator, self).__init__()
+        periods = [2, 3, 5, 7, 11]
+
+        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+        discs = discs + [
+            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
+        ]
+        self.discriminators = nn.ModuleList(discs)
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
diff --git a/wetts/vits/model/duration_predictors.py b/wetts/vits/model/duration_predictors.py
new file mode 100644
index 0000000..0c4af6d
--- /dev/null
+++ b/wetts/vits/model/duration_predictors.py
@@ -0,0 +1,296 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from model.modules import Flip
+from model.normalization import LayerNorm
+from utils.transforms import piecewise_rational_quadratic_transform
+
+
+class DDSConv(nn.Module):
+    """
+    Dilated and Depth-Separable Convolution
+    """
+
+    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
+        super().__init__()
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+
+        self.drop = nn.Dropout(p_dropout)
+        self.convs_sep = nn.ModuleList()
+        self.convs_1x1 = nn.ModuleList()
+        self.norms_1 = nn.ModuleList()
+        self.norms_2 = nn.ModuleList()
+        for i in range(n_layers):
+            dilation = kernel_size**i
+            padding = (kernel_size * dilation - dilation) // 2
+            self.convs_sep.append(
+                nn.Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    groups=channels,
+                    dilation=dilation,
+                    padding=padding,
+                )
+            )
+            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+            self.norms_1.append(LayerNorm(channels))
+            self.norms_2.append(LayerNorm(channels))
+
+    def forward(self, x, x_mask, g=None):
+        if g is not None:
+            x = x + g
+        for i in range(self.n_layers):
+            y = self.convs_sep[i](x * x_mask)
+            y = self.norms_1[i](y)
+            y = F.gelu(y)
+            y = self.convs_1x1[i](y)
+            y = self.norms_2[i](y)
+            y = F.gelu(y)
+            y = self.drop(y)
+            x = x + y
+        return x * x_mask
+
+
+class ConvFlow(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        filter_channels,
+        kernel_size,
+        n_layers,
+        num_bins=10,
+        tail_bound=5.0,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.num_bins = num_bins
+        self.tail_bound = tail_bound
+        self.half_channels = in_channels // 2
+
+        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
+        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
+        self.proj = nn.Conv1d(
+            filter_channels, self.half_channels * (num_bins * 3 - 1), 1
+        )
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0)
+        h = self.convs(h, x_mask, g=g)
+        h = self.proj(h) * x_mask
+
+        b, c, t = x0.shape
+        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]
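+        # The projection yields 3 * num_bins - 1 values per half-channel and
+        # frame: num_bins unnormalized bin widths, num_bins unnormalized bin
+        # heights, and num_bins - 1 knot derivatives for the piecewise
+        # rational-quadratic spline applied to x1 below.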
+ + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( + self.filter_channels + ) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class StochasticDurationPredictor(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=256, + ): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = Log() + self.flows = nn.ModuleList() + self.flows.append(ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append( + ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.flows.append(Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + self.post_flows = nn.ModuleList() + self.post_flows.append(ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append( + ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.post_flows.append(Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = ( + torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) + * x_mask + ) + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + 
F.logsigmoid(-z_u)) * x_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) + - logdet_tot_q + ) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = ( + torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) + - logdet_tot + ) + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = ( + torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) + * noise_scale + ) + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_2 = LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask diff --git a/wetts/vits/model/encoders.py b/wetts/vits/model/encoders.py new file mode 100644 index 0000000..6585b7d --- /dev/null +++ b/wetts/vits/model/encoders.py @@ -0,0 +1,94 @@ +import math + +import torch +from torch import nn + +import model.attentions as attentions +from model.modules import WN +from utils import commons + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + ): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + 
dilation_rate, + n_layers, + gin_channels, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask diff --git a/wetts/vits/model/flows.py b/wetts/vits/model/flows.py new file mode 100644 index 0000000..04f8c4e --- /dev/null +++ b/wetts/vits/model/flows.py @@ -0,0 +1,114 @@ +import torch +from torch import nn + +from model.modules import Flip, WN + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=256, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + def remove_weight_norm(self): + for i, l in enumerate(self.flows): + if i % 2 == 0: + l.remove_weight_norm() + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=256, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=p_dropout, + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, 
logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + def remove_weight_norm(self): + self.enc.remove_weight_norm() diff --git a/wetts/vits/model/models.py b/wetts/vits/model/models.py new file mode 100644 index 0000000..3fe9c7b --- /dev/null +++ b/wetts/vits/model/models.py @@ -0,0 +1,262 @@ +import math +import time + +import torch +from torch import nn +import monotonic_align + +from model.decoder import Generator, VocosGenerator +from model.duration_predictors import StochasticDurationPredictor, DurationPredictor +from model.encoders import TextEncoder, PosteriorEncoder +from model.flows import ResidualCouplingBlock +from utils import commons + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + n_vocab, + spec_channels, + segment_size, + inter_channels=192, + hidden_channels=192, + filter_channels=768, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.1, + resblock="1", + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[8, 8, 2, 2], + upsample_initial_channel=512, + upsample_kernel_sizes=[16, 16, 4, 4], + n_speakers=1, + gin_channels=256, + use_sdp=True, + vocoder_type="hifigan", + vocos_channels=512, + vocos_h_channels=1536, + vocos_out_channels=1026, + vocos_num_layers=8, + vocos_istft_config={ + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "center": True, + }, + is_onnx=False, + **kwargs + ): + super().__init__() + self.n_vocab = n_vocab + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + self.use_sdp = use_sdp + + self.enc_p = TextEncoder( + n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + ) + if vocoder_type == "vocos": + self.dec = VocosGenerator( + inter_channels, + vocos_channels, + vocos_h_channels, + vocos_out_channels, + vocos_num_layers, + vocos_istft_config, + gin_channels, + is_onnx, + ) + else: + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels + ) + + if use_sdp: + self.dp = StochasticDurationPredictor( + hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels + ) + else: + self.dp = DurationPredictor( + hidden_channels, 256, 3, 0.5, gin_channels=gin_channels + ) + + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + def forward(self, x, x_lengths, y, y_lengths, sid=None): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + + z, m_q, logs_q, y_mask = self.enc_q(y, 
y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] + neg_cent1 = torch.sum( + -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True + ) # [b, 1, t_s] + neg_cent2 = torch.matmul( + -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r + ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent3 = torch.matmul( + z_p.transpose(1, 2), (m_p * s_p_sq_r) + ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent4 = torch.sum( + -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True + ) # [b, 1, t_s] + neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 + + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = ( + monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)) + .unsqueeze(1) + .detach() + ) + + w = attn.sum(2) + if self.use_sdp: + l_length = self.dp(x, x_mask, w, g=g) + l_length = l_length / torch.sum(x_mask) + else: + logw_ = torch.log(w + 1e-6) * x_mask + logw = self.dp(x, x_mask, g=g) + l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum( + x_mask + ) # for averaging + + # expand prior + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) + + z_slice, ids_slice = commons.rand_slice_segments( + z, y_lengths, self.segment_size + ) + o = self.dec(z_slice, g=g) + return ( + o, + l_length, + attn, + ids_slice, + x_mask, + y_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + ) + + def infer( + self, + x, + x_lengths, + sid=None, + noise_scale=1, + length_scale=1, + noise_scale_w=1.0, + max_len=None, + ): + t1 = time.time() + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + t2 = time.time() + if self.use_sdp: + logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) + else: + logw = self.dp(x, x_mask, g=g) + t3 = time.time() + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to( + x_mask.dtype + ) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + t4 = time.time() + z = self.flow(z_p, y_mask, g=g, reverse=True) + t5 = time.time() + o = self.dec((z * y_mask)[:, :, :max_len], g=g) + t6 = time.time() + print( + "TextEncoder: {}s DurationPredictor: {}s Flow: {}s Decoder: {}s".format( + round(t2 - t1, 3), + round(t3 - t2, 3), + round(t5 - t4, 3), + round(t6 - t5, 3), + ) + ) + return o, attn, y_mask, (z, z_p, m_p, logs_p) + + def export_forward(self, x, x_lengths, scales, sid): + # shape of scales: Bx3, make triton happy + audio, *_ = self.infer( + x, + x_lengths, + sid, + noise_scale=scales[0][0], + length_scale=scales[0][1], + noise_scale_w=scales[0][2], + ) + return audio + + def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): + g_src = self.emb_g(sid_src).unsqueeze(-1) + g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) 
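+        # Voice conversion: encode the audio with the source speaker embedding,
+        # map it to the speaker-independent prior via the flow, then invert the
+        # flow conditioned on the target speaker before decoding.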
+ o_hat = self.dec(z_hat * y_mask, g=g_tgt) + return o_hat, y_mask, (z, z_p, z_hat) diff --git a/wetts/vits/model/modules.py b/wetts/vits/model/modules.py new file mode 100644 index 0000000..06c6221 --- /dev/null +++ b/wetts/vits/model/modules.py @@ -0,0 +1,98 @@ +import torch +from torch import nn +from torch.nn.utils.parametrizations import weight_norm + +from utils import commons + + +LRELU_SLOPE = 0.1 + + +class WN(nn.Module): + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels, + p_dropout=0, + ): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = nn.ModuleList() + self.res_skip_layers = nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + nn.utils.remove_weight_norm(l) + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x diff --git a/wetts/vits/model/normalization.py b/wetts/vits/model/normalization.py new file mode 100644 index 0000000..7fa0a80 --- /dev/null +++ b/wetts/vits/model/normalization.py @@ -0,0 +1,18 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + 
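+        # transpose back from [b, t, c] to channel-first [b, c, t]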
return x.transpose(1, -1) diff --git a/wetts/vits/models.py b/wetts/vits/models.py deleted file mode 100644 index 085ce9e..0000000 --- a/wetts/vits/models.py +++ /dev/null @@ -1,842 +0,0 @@ -import math -import time - -import torch -from torch import nn -from torch.nn import functional as F -from torch.nn import Conv1d, ConvTranspose1d, Conv2d -from torch.nn.utils import remove_weight_norm, spectral_norm -from torch.nn.utils.parametrizations import weight_norm -from torchaudio.transforms import InverseSpectrogram -import monotonic_align - -import commons -import modules -import attentions -from commons import init_weights, get_padding -from stft import OnnxSTFT - - -class StochasticDurationPredictor(nn.Module): - def __init__( - self, - in_channels, - filter_channels, - kernel_size, - p_dropout, - n_flows=4, - gin_channels=256, - ): - super().__init__() - filter_channels = in_channels # it needs to be removed from future version. - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.log_flow = modules.Log() - self.flows = nn.ModuleList() - self.flows.append(modules.ElementwiseAffine(2)) - for i in range(n_flows): - self.flows.append( - modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3) - ) - self.flows.append(modules.Flip()) - - self.post_pre = nn.Conv1d(1, filter_channels, 1) - self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) - self.post_convs = modules.DDSConv( - filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout - ) - self.post_flows = nn.ModuleList() - self.post_flows.append(modules.ElementwiseAffine(2)) - for i in range(4): - self.post_flows.append( - modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3) - ) - self.post_flows.append(modules.Flip()) - - self.pre = nn.Conv1d(in_channels, filter_channels, 1) - self.proj = nn.Conv1d(filter_channels, filter_channels, 1) - self.convs = modules.DDSConv( - filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout - ) - self.cond = nn.Conv1d(gin_channels, filter_channels, 1) - - def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): - x = torch.detach(x) - x = self.pre(x) - g = torch.detach(g) - x = x + self.cond(g) - x = self.convs(x, x_mask) - x = self.proj(x) * x_mask - - if not reverse: - flows = self.flows - assert w is not None - - logdet_tot_q = 0 - h_w = self.post_pre(w) - h_w = self.post_convs(h_w, x_mask) - h_w = self.post_proj(h_w) * x_mask - e_q = ( - torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) - * x_mask - ) - z_q = e_q - for flow in self.post_flows: - z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) - logdet_tot_q += logdet_q - z_u, z1 = torch.split(z_q, [1, 1], 1) - u = torch.sigmoid(z_u) * x_mask - z0 = (w - u) * x_mask - logdet_tot_q += torch.sum( - (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] - ) - logq = ( - torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - - logdet_tot_q - ) - - logdet_tot = 0 - z0, logdet = self.log_flow(z0, x_mask) - logdet_tot += logdet - z = torch.cat([z0, z1], 1) - for flow in flows: - z, logdet = flow(z, x_mask, g=x, reverse=reverse) - logdet_tot = logdet_tot + logdet - nll = ( - torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - - logdet_tot - ) - return nll + logq # [b] - else: - flows = list(reversed(self.flows)) - flows = flows[:-2] + [flows[-1]] # remove a useless vflow - z = ( - 
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) - * noise_scale - ) - for flow in flows: - z = flow(z, x_mask, g=x, reverse=reverse) - z0, z1 = torch.split(z, [1, 1], 1) - logw = z0 - return logw - - -class DurationPredictor(nn.Module): - def __init__( - self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels - ): - super().__init__() - - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.gin_channels = gin_channels - - self.drop = nn.Dropout(p_dropout) - self.conv_1 = nn.Conv1d( - in_channels, filter_channels, kernel_size, padding=kernel_size // 2 - ) - self.norm_1 = modules.LayerNorm(filter_channels) - self.conv_2 = nn.Conv1d( - filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 - ) - self.norm_2 = modules.LayerNorm(filter_channels) - self.proj = nn.Conv1d(filter_channels, 1, 1) - self.cond = nn.Conv1d(gin_channels, in_channels, 1) - - def forward(self, x, x_mask, g=None): - x = torch.detach(x) - g = torch.detach(g) - x = x + self.cond(g) - x = self.conv_1(x * x_mask) - x = torch.relu(x) - x = self.norm_1(x) - x = self.drop(x) - x = self.conv_2(x * x_mask) - x = torch.relu(x) - x = self.norm_2(x) - x = self.drop(x) - x = self.proj(x * x_mask) - return x * x_mask - - -class TextEncoder(nn.Module): - def __init__( - self, - n_vocab, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - ): - super().__init__() - self.n_vocab = n_vocab - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - - self.emb = nn.Embedding(n_vocab, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) - - self.encoder = attentions.Encoder( - hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout - ) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths): - x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] - x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( - x.dtype - ) - - x = self.encoder(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - - m, logs = torch.split(stats, self.out_channels, dim=1) - return x, m, logs, x_mask - - -class ResidualCouplingBlock(nn.Module): - def __init__( - self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=256, - ): - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.flows = nn.ModuleList() - for i in range(n_flows): - self.flows.append( - modules.ResidualCouplingLayer( - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=gin_channels, - mean_only=True, - ) - ) - self.flows.append(modules.Flip()) - - def forward(self, x, x_mask, g=None, reverse=False): - if not reverse: - for flow in self.flows: - x, _ = flow(x, x_mask, g=g, reverse=reverse) - else: - for flow in reversed(self.flows): - x = flow(x, x_mask, g=g, reverse=reverse) - return x - - def remove_weight_norm(self): - for i, l in enumerate(self.flows): - if i % 2 == 0: - 
l.remove_weight_norm() - - -class PosteriorEncoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - - self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN( - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=gin_channels, - ) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths, g=None): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( - x.dtype - ) - x = self.pre(x) * x_mask - x = self.enc(x, x_mask, g=g) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - return z, m, logs, x_mask - - -class Generator(torch.nn.Module): - def __init__( - self, - initial_channel, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels, - ): - super(Generator, self).__init__() - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_rates) - self.conv_pre = Conv1d( - initial_channel, upsample_initial_channel, 7, 1, padding=3 - ) - resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - upsample_initial_channel // (2**i), - upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate( - zip(resblock_kernel_sizes, resblock_dilation_sizes) - ): - self.resblocks.append(resblock(ch, k, d)) - - self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) - - self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - - def forward(self, x, g=None): - x = self.conv_pre(x) - x = x + self.cond(g) - - for i in range(self.num_upsamples): - x = F.leaky_relu(x, modules.LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - -class ConvNeXtLayer(nn.Module): - def __init__(self, channels, h_channels, scale): - super().__init__() - self.dw_conv = nn.Conv1d( - channels, - channels, - kernel_size=3, - padding=1, - groups=channels, - ) - self.norm = modules.LayerNorm(channels) - self.pw_conv1 = nn.Conv1d(channels, h_channels, 1) - self.pw_conv2 = nn.Conv1d(h_channels, channels, 1) - self.scale = nn.Parameter( - torch.full(size=(1, channels, 1), fill_value=scale), requires_grad=True - ) - - def forward(self, x): - res = x - x = self.dw_conv(x) - x = self.norm(x) - x = self.pw_conv1(x) - x = F.gelu(x) - x = self.pw_conv2(x) - x = self.scale * x - x = res + x - 
return x - -class VocosGenerator(nn.Module): - def __init__( - self, - in_channels, - channels, - h_channels, - out_channels, - num_layers, - istft_config, - gin_channels, - is_onnx=False - ): - super().__init__() - - self.pad = nn.ReflectionPad1d([1, 0]) - self.in_conv = nn.Conv1d(in_channels, channels, kernel_size=1, padding=0) - self.cond = Conv1d(gin_channels, channels, 1) - self.norm_pre = modules.LayerNorm(channels) - scale = 1 / num_layers - self.layers = nn.ModuleList( - [ConvNeXtLayer(channels, h_channels, scale) for _ in range(num_layers)] - ) - self.norm_post = modules.LayerNorm(channels) - self.out_conv = nn.Conv1d(channels, out_channels, kernel_size=1) - self.is_onnx = is_onnx - - if self.is_onnx: - self.stft = OnnxSTFT(filter_length=istft_config['n_fft'], - hop_length=istft_config['hop_length'], - win_length=istft_config['win_length']) - else: - self.istft = InverseSpectrogram(**istft_config) - - def forward(self, x, g=None): - x = self.pad(x) - x = self.in_conv(x) + self.cond(g) - x = self.norm_pre(x) - for layer in self.layers: - x = layer(x) - x = self.norm_post(x) - x = self.out_conv(x) - mag, phase = x.chunk(2, dim=1) - mag = mag.exp().clamp_max(max=1e2) - if self.is_onnx: - o = self.stft.inverse(mag, phase).to(x.device) - else: - s = mag * (phase.cos() + 1j * phase.sin()) - o = self.istft(s).unsqueeze(1) - return o - - def remove_weight_norm(self): - pass - -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f( - Conv2d( - 1, - 32, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 32, - 128, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 128, - 512, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 512, - 1024, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 1024, - 1024, - (kernel_size, 1), - 1, - padding=(get_padding(kernel_size, 1), 0), - ) - ), - ] - ) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ] - ) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - - for l in self.convs: - x = 
l(x)
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class MultiPeriodDiscriminator(torch.nn.Module):
-    def __init__(self, use_spectral_norm=False):
-        super(MultiPeriodDiscriminator, self).__init__()
-        periods = [2, 3, 5, 7, 11]
-
-        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
-        discs = discs + [
-            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
-        ]
-        self.discriminators = nn.ModuleList(discs)
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            y_d_gs.append(y_d_g)
-            fmap_rs.append(fmap_r)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class SynthesizerTrn(nn.Module):
-    """
-    Synthesizer for Training
-    """
-
-    def __init__(
-        self,
-        n_vocab,
-        spec_channels,
-        segment_size,
-        inter_channels=192,
-        hidden_channels=192,
-        filter_channels=768,
-        n_heads=2,
-        n_layers=6,
-        kernel_size=3,
-        p_dropout=0.1,
-        resblock="1",
-        resblock_kernel_sizes=[3, 7, 11],
-        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        upsample_rates=[8, 8, 2, 2],
-        upsample_initial_channel=512,
-        upsample_kernel_sizes=[16, 16, 4, 4],
-        n_speakers=1,
-        gin_channels=256,
-        use_sdp=True,
-        vocoder_type="hifigan",
-        vocos_channels=512,
-        vocos_h_channels=1536,
-        vocos_out_channels=1026,
-        vocos_num_layers=8,
-        vocos_istft_config={
-            "n_fft": 1024,
-            "hop_length": 256,
-            "win_length": 1024,
-            "center": True,
-        },
-        is_onnx=False,
-        **kwargs
-    ):
-        super().__init__()
-        self.n_vocab = n_vocab
-        self.spec_channels = spec_channels
-        self.inter_channels = inter_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.resblock = resblock
-        self.resblock_kernel_sizes = resblock_kernel_sizes
-        self.resblock_dilation_sizes = resblock_dilation_sizes
-        self.upsample_rates = upsample_rates
-        self.upsample_initial_channel = upsample_initial_channel
-        self.upsample_kernel_sizes = upsample_kernel_sizes
-        self.segment_size = segment_size
-        self.n_speakers = n_speakers
-        self.gin_channels = gin_channels
-        self.use_sdp = use_sdp
-
-        self.enc_p = TextEncoder(
-            n_vocab,
-            inter_channels,
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers,
-            kernel_size,
-            p_dropout,
-        )
-        if vocoder_type == "vocos":
-            self.dec = VocosGenerator(
-                inter_channels,
-                vocos_channels,
-                vocos_h_channels,
-                vocos_out_channels,
-                vocos_num_layers,
-                vocos_istft_config,
-                gin_channels,
-                is_onnx,
-            )
-        else:
-            self.dec = Generator(
-                inter_channels,
-                resblock,
-                resblock_kernel_sizes,
-                resblock_dilation_sizes,
-                upsample_rates,
-                upsample_initial_channel,
-                upsample_kernel_sizes,
-                gin_channels=gin_channels,
-            )
-        self.enc_q = PosteriorEncoder(
-            spec_channels,
-            inter_channels,
-            hidden_channels,
-            5,
-            1,
-            16,
-            gin_channels=gin_channels,
-        )
-        self.flow = ResidualCouplingBlock(
-            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
-        )
-
-        if use_sdp:
-            self.dp = StochasticDurationPredictor(
-                hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
-            )
-        else:
-            self.dp = DurationPredictor(
-                hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
-            )
-
-        self.emb_g = nn.Embedding(n_speakers, gin_channels)
-
-    def forward(self, x, x_lengths, y, y_lengths, sid=None):
-        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
-        g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
-
-        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
-        z_p = self.flow(z, y_mask, g=g)
-
-        with torch.no_grad():
-            # negative cross-entropy
-            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
-            neg_cent1 = torch.sum(
-                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
-            )  # [b, 1, t_s]
-            neg_cent2 = torch.matmul(
-                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
-            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
-            neg_cent3 = torch.matmul(
-                z_p.transpose(1, 2), (m_p * s_p_sq_r)
-            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
-            neg_cent4 = torch.sum(
-                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
-            )  # [b, 1, t_s]
-            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
-
-            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
-            attn = (
-                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
-                .unsqueeze(1)
-                .detach()
-            )
-
-        w = attn.sum(2)
-        if self.use_sdp:
-            l_length = self.dp(x, x_mask, w, g=g)
-            l_length = l_length / torch.sum(x_mask)
-        else:
-            logw_ = torch.log(w + 1e-6) * x_mask
-            logw = self.dp(x, x_mask, g=g)
-            l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
-                x_mask
-            )  # for averaging
-
-        # expand prior
-        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
-        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
-
-        z_slice, ids_slice = commons.rand_slice_segments(
-            z, y_lengths, self.segment_size
-        )
-        o = self.dec(z_slice, g=g)
-        return (
-            o,
-            l_length,
-            attn,
-            ids_slice,
-            x_mask,
-            y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-        )
-
-    def infer(
-        self,
-        x,
-        x_lengths,
-        sid=None,
-        noise_scale=1,
-        length_scale=1,
-        noise_scale_w=1.0,
-        max_len=None,
-    ):
-        t1 = time.time()
-        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
-        g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
-        t2 = time.time()
-        if self.use_sdp:
-            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
-        else:
-            logw = self.dp(x, x_mask, g=g)
-        t3 = time.time()
-        w = torch.exp(logw) * x_mask * length_scale
-        w_ceil = torch.ceil(w)
-        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
-            x_mask.dtype
-        )
-        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
-        attn = commons.generate_path(w_ceil, attn_mask)
-
-        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
-            1, 2
-        )  # [b, t', t], [b, t, d] -> [b, d, t']
-        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
-            1, 2
-        )  # [b, t', t], [b, t, d] -> [b, d, t']
-
-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
-        t4 = time.time()
-        z = self.flow(z_p, y_mask, g=g, reverse=True)
-        t5 = time.time()
-        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
-        t6 = time.time()
-        print(
-            "TextEncoder: {}s DurationPredictor: {}s Flow: {}s Decoder: {}s".format(
-                round(t2 - t1, 3),
-                round(t3 - t2, 3),
-                round(t5 - t4, 3),
-                round(t6 - t5, 3),
-            )
-        )
-        return o, attn, y_mask, (z, z_p, m_p, logs_p)
-
-    def export_forward(self, x, x_lengths, scales, sid):
-        # shape of scales: Bx3, make triton happy
-        audio, *_ = self.infer(
-            x,
-            x_lengths,
-            sid,
-            noise_scale=scales[0][0],
-            length_scale=scales[0][1],
-            noise_scale_w=scales[0][2],
-        )
-        return audio
-
-    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
-        g_src = self.emb_g(sid_src).unsqueeze(-1)
-        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
-        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
-        z_p = self.flow(z, y_mask, g=g_src)
-        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
-        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
-        return o_hat, y_mask, (z, z_p, z_hat)
diff --git a/wetts/vits/modules.py b/wetts/vits/modules.py
deleted file mode 100644
index 8007c0e..0000000
--- a/wetts/vits/modules.py
+++ /dev/null
@@ -1,511 +0,0 @@
-import math
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.nn import Conv1d
-from torch.nn.utils import remove_weight_norm
-from torch.nn.utils.parametrizations import weight_norm
-
-import commons
-from commons import init_weights, get_padding
-from transforms import piecewise_rational_quadratic_transform
-
-LRELU_SLOPE = 0.1
-
-
-class LayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.gamma = nn.Parameter(torch.ones(channels))
-        self.beta = nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x):
-        x = x.transpose(1, -1)
-        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-        return x.transpose(1, -1)
-
-
-class ConvReluNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        hidden_channels,
-        out_channels,
-        kernel_size,
-        n_layers,
-        p_dropout,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.hidden_channels = hidden_channels
-        self.out_channels = out_channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.p_dropout = p_dropout
-        assert n_layers > 1, "Number of layers should be larger than 0."
-
-        self.conv_layers = nn.ModuleList()
-        self.norm_layers = nn.ModuleList()
-        self.conv_layers.append(
-            nn.Conv1d(
-                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
-            )
-        )
-        self.norm_layers.append(LayerNorm(hidden_channels))
-        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
-        for _ in range(n_layers - 1):
-            self.conv_layers.append(
-                nn.Conv1d(
-                    hidden_channels,
-                    hidden_channels,
-                    kernel_size,
-                    padding=kernel_size // 2,
-                )
-            )
-            self.norm_layers.append(LayerNorm(hidden_channels))
-        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
-        self.proj.weight.data.zero_()
-        self.proj.bias.data.zero_()
-
-    def forward(self, x, x_mask):
-        x_org = x
-        for i in range(self.n_layers):
-            x = self.conv_layers[i](x * x_mask)
-            x = self.norm_layers[i](x)
-            x = self.relu_drop(x)
-        x = x_org + self.proj(x)
-        return x * x_mask
-
-
-class DDSConv(nn.Module):
-    """
-    Dialted and Depth-Separable Convolution
-    """
-
-    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
-        super().__init__()
-        self.channels = channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.p_dropout = p_dropout
-
-        self.drop = nn.Dropout(p_dropout)
-        self.convs_sep = nn.ModuleList()
-        self.convs_1x1 = nn.ModuleList()
-        self.norms_1 = nn.ModuleList()
-        self.norms_2 = nn.ModuleList()
-        for i in range(n_layers):
-            dilation = kernel_size**i
-            padding = (kernel_size * dilation - dilation) // 2
-            self.convs_sep.append(
-                nn.Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    groups=channels,
-                    dilation=dilation,
-                    padding=padding,
-                )
-            )
-            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
-            self.norms_1.append(LayerNorm(channels))
-            self.norms_2.append(LayerNorm(channels))
-
-    def forward(self, x, x_mask, g=None):
-        if g is not None:
-            x = x + g
-        for i in range(self.n_layers):
-            y = self.convs_sep[i](x * x_mask)
-            y = self.norms_1[i](y)
-            y = F.gelu(y)
-            y = self.convs_1x1[i](y)
-            y = self.norms_2[i](y)
-            y = F.gelu(y)
-            y = self.drop(y)
-            x = x + y
-        return x * x_mask
-
-
-class WN(torch.nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        gin_channels,
-        p_dropout=0,
-    ):
-        super(WN, self).__init__()
-        assert kernel_size % 2 == 1
-        self.hidden_channels = hidden_channels
-        self.kernel_size = (kernel_size,)
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-        self.p_dropout = p_dropout
-
-        self.in_layers = torch.nn.ModuleList()
-        self.res_skip_layers = torch.nn.ModuleList()
-        self.drop = nn.Dropout(p_dropout)
-
-        cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
-        self.cond_layer = weight_norm(cond_layer, name="weight")
-
-        for i in range(n_layers):
-            dilation = dilation_rate**i
-            padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = torch.nn.Conv1d(
-                hidden_channels,
-                2 * hidden_channels,
-                kernel_size,
-                dilation=dilation,
-                padding=padding,
-            )
-            in_layer = weight_norm(in_layer, name="weight")
-            self.in_layers.append(in_layer)
-
-            # last one is not necessary
-            if i < n_layers - 1:
-                res_skip_channels = 2 * hidden_channels
-            else:
-                res_skip_channels = hidden_channels
-
-            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-            res_skip_layer = weight_norm(res_skip_layer, name="weight")
-            self.res_skip_layers.append(res_skip_layer)
-
-    def forward(self, x, x_mask, g=None, **kwargs):
-        output = torch.zeros_like(x)
-        n_channels_tensor = torch.IntTensor([self.hidden_channels])
-
-        g = self.cond_layer(g)
-
-        for i in range(self.n_layers):
-            x_in = self.in_layers[i](x)
-            cond_offset = i * 2 * self.hidden_channels
-            g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-
-            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
-            acts = self.drop(acts)
-
-            res_skip_acts = self.res_skip_layers[i](acts)
-            if i < self.n_layers - 1:
-                res_acts = res_skip_acts[:, : self.hidden_channels, :]
-                x = (x + res_acts) * x_mask
-                output = output + res_skip_acts[:, self.hidden_channels :, :]
-            else:
-                output = output + res_skip_acts
-        return output * x_mask
-
-    def remove_weight_norm(self):
-        torch.nn.utils.remove_weight_norm(self.cond_layer)
-        for l in self.in_layers:
-            torch.nn.utils.remove_weight_norm(l)
-        for l in self.res_skip_layers:
-            torch.nn.utils.remove_weight_norm(l)
-
-
-class ResBlock1(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c1(xt)
-            xt = F.leaky_relu(xt, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c2(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
-
-
-class ResBlock2(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c in self.convs:
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
-
-
-class Log(nn.Module):
-    def forward(self, x, x_mask, reverse=False, **kwargs):
-        if not reverse:
-            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
-            logdet = torch.sum(-y, [1, 2])
-            return y, logdet
-        else:
-            x = torch.exp(x) * x_mask
-            return x
-
-
-class Flip(nn.Module):
-    def forward(self, x, *args, reverse=False, **kwargs):
-        x = torch.flip(x, [1])
-        if not reverse:
-            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
-            return x, logdet
-        else:
-            return x
-
-
-class ElementwiseAffine(nn.Module):
-    def __init__(self, channels):
-        super().__init__()
-        self.channels = channels
-        self.m = nn.Parameter(torch.zeros(channels, 1))
-        self.logs = nn.Parameter(torch.zeros(channels, 1))
-
-    def forward(self, x, x_mask, reverse=False, **kwargs):
-        if not reverse:
-            y = self.m + torch.exp(self.logs) * x
-            y = y * x_mask
-            logdet = torch.sum(self.logs * x_mask, [1, 2])
-            return y, logdet
-        else:
-            x = (x - self.m) * torch.exp(-self.logs) * x_mask
-            return x
-
-
-class ResidualCouplingLayer(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        p_dropout=0,
-        gin_channels=256,
-        mean_only=False,
-    ):
-        assert channels % 2 == 0, "channels should be divisible by 2"
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.half_channels = channels // 2
-        self.mean_only = mean_only
-
-        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-        self.enc = WN(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            p_dropout=p_dropout,
-            gin_channels=gin_channels,
-        )
-        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-        self.post.weight.data.zero_()
-        self.post.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0) * x_mask
-        h = self.enc(h, x_mask, g=g)
-        stats = self.post(h) * x_mask
-        if not self.mean_only:
-            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-        else:
-            m = stats
-            logs = torch.zeros_like(m)
-
-        if not reverse:
-            x1 = m + x1 * torch.exp(logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            logdet = torch.sum(logs, [1, 2])
-            return x, logdet
-        else:
-            x1 = (x1 - m) * torch.exp(-logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            return x
-
-    def remove_weight_norm(self):
-        self.enc.remove_weight_norm()
-
-
-class ConvFlow(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        filter_channels,
-        kernel_size,
-        n_layers,
-        num_bins=10,
-        tail_bound=5.0,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.filter_channels = filter_channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.num_bins = num_bins
-        self.tail_bound = tail_bound
-        self.half_channels = in_channels // 2
-
-        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
-        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
-        self.proj = nn.Conv1d(
-            filter_channels, self.half_channels * (num_bins * 3 - 1), 1
-        )
-        self.proj.weight.data.zero_()
-        self.proj.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0)
-        h = self.convs(h, x_mask, g=g)
-        h = self.proj(h) * x_mask
-
-        b, c, t = x0.shape
-        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]
-
-        unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
-        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
-            self.filter_channels
-        )
-        unnormalized_derivatives = h[..., 2 * self.num_bins :]
-
-        x1, logabsdet = piecewise_rational_quadratic_transform(
-            x1,
-            unnormalized_widths,
-            unnormalized_heights,
-            unnormalized_derivatives,
-            inverse=reverse,
-            tails="linear",
-            tail_bound=self.tail_bound,
-        )
-
-        x = torch.cat([x0, x1], 1) * x_mask
-        logdet = torch.sum(logabsdet * x_mask, [1, 2])
-        if not reverse:
-            return x, logdet
-        else:
-            return x
diff --git a/wetts/vits/train.py b/wetts/vits/train.py
index 50a43af..59c4cf5 100644
--- a/wetts/vits/train.py
+++ b/wetts/vits/train.py
@@ -8,26 +8,23 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.cuda.amp import autocast, GradScaler
 
-import commons
-import utils
 from data_utils import (
     TextAudioSpeakerLoader,
     TextAudioSpeakerCollate,
     DistributedBucketSampler,
 )
-from models import (
-    SynthesizerTrn,
-    MultiPeriodDiscriminator,
-)
+from model.discriminators import MultiPeriodDiscriminator
+from model.models import SynthesizerTrn
 from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
-from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+from utils import commons, task
+from utils.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
 
 torch.backends.cudnn.benchmark = False
 global_step = 0
 
 
 def main():
-    hps = utils.get_hparams()
+    hps = task.get_hparams()
     torch.manual_seed(hps.train.seed)
     global global_step
     world_size = int(os.environ.get('WORLD_SIZE', 1))
@@ -36,7 +33,7 @@
     torch.torch.cuda.set_device(local_rank)
     dist.init_process_group("nccl")
     if rank == 0:
-        logger = utils.get_logger(hps.model_dir)
+        logger = task.get_logger(hps.model_dir)
         logger.info(hps)
         writer = SummaryWriter(log_dir=hps.model_dir)
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
@@ -94,11 +91,11 @@
     net_d = DDP(net_d, device_ids=[rank])
 
     try:
-        _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
+        _, _, _, epoch_str = task.load_checkpoint(
+            task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
         )
-        _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
+        _, _, _, epoch_str = task.load_checkpoint(
+            task.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
         )
         global_step = (epoch_str - 1) * len(train_loader)
     except Exception as e:
@@ -281,20 +278,20 @@
                     {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                 )
                 image_dict = {
-                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
+                    "slice/mel_org": task.plot_spectrogram_to_numpy(
                         y_mel[0].data.cpu().numpy()
                     ),
-                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
+                    "slice/mel_gen": task.plot_spectrogram_to_numpy(
                         y_hat_mel[0].data.cpu().numpy()
                     ),
-                    "all/mel": utils.plot_spectrogram_to_numpy(
+                    "all/mel": task.plot_spectrogram_to_numpy(
                         mel[0].data.cpu().numpy()
                     ),
-                    "all/attn": utils.plot_alignment_to_numpy(
+                    "all/attn": task.plot_alignment_to_numpy(
                         attn[0, 0].data.cpu().numpy()
                     ),
                 }
-                utils.summarize(
+                task.summarize(
                     writer=writer,
                     global_step=global_step,
                     images=image_dict,
@@ -303,14 +300,14 @@
 
             if global_step % hps.train.eval_interval == 0:
                 evaluate(hps, net_g, eval_loader, writer_eval)
-                utils.save_checkpoint(
+                task.save_checkpoint(
                     net_g,
                     optim_g,
                     hps.train.learning_rate,
                     epoch,
                     os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
                 )
-                utils.save_checkpoint(
+                task.save_checkpoint(
                     net_d,
                     optim_d,
                     hps.train.learning_rate,
@@ -369,16 +366,16 @@
         hps.data.win_length,
     )
    image_dict = {
-        "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
+        "gen/mel": task.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
     }
     audio_dict = {"gen/audio": y_hat[0, :, : y_hat_lengths[0]]}
     if global_step == 0:
         image_dict.update(
-            {"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())}
+            {"gt/mel": task.plot_spectrogram_to_numpy(mel[0].cpu().numpy())}
         )
         audio_dict.update({"gt/audio": y[0, :, : y_lengths[0]]})
 
-    utils.summarize(
+    task.summarize(
         writer=writer_eval,
         global_step=global_step,
         images=image_dict,
diff --git a/wetts/vits/commons.py b/wetts/vits/utils/commons.py
similarity index 100%
rename from wetts/vits/commons.py
rename to wetts/vits/utils/commons.py
diff --git a/wetts/vits/mel_processing.py b/wetts/vits/utils/mel_processing.py
similarity index 100%
rename from wetts/vits/mel_processing.py
rename to wetts/vits/utils/mel_processing.py
diff --git a/wetts/vits/stft.py b/wetts/vits/utils/stft.py
similarity index 100%
rename from wetts/vits/stft.py
rename to wetts/vits/utils/stft.py
diff --git a/wetts/vits/utils.py b/wetts/vits/utils/task.py
similarity index 100%
rename from wetts/vits/utils.py
rename to wetts/vits/utils/task.py
diff --git a/wetts/vits/transforms.py b/wetts/vits/utils/transforms.py
similarity index 100%
rename from wetts/vits/transforms.py
rename to wetts/vits/utils/transforms.py