[vits] Support wavlm discriminator
Shengqiang-Li authored and ShengqiangLi committed Mar 30, 2024
1 parent 4dd2794 commit 0ab0cb8
Showing 7 changed files with 275 additions and 15 deletions.
8 changes: 7 additions & 1 deletion examples/baker/configs/v1.json
@@ -43,6 +43,12 @@
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 256
"gin_channels": 256,
"use_wd": true,
"slm_model": "exp/slm/wavlm-base-plus",
"slm_sr": 16000,
"slm_hidden": 768,
"slm_nlayers": 13,
"slm_initial_channel": 64
}
}
8 changes: 7 additions & 1 deletion examples/baker/configs/vits2_v1.json
@@ -50,6 +50,12 @@
"n_layers_q": 3,
"use_sdp": true,
"use_spectral_norm": false,
"gin_channels": 256
"gin_channels": 256,
"use_wd": true,
"slm_model": "exp/slm/wavlm-base-plus",
"slm_sr": 16000,
"slm_hidden": 768,
"slm_nlayers": 13,
"slm_initial_channel": 64
}
}
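
Both configs add the same block. A minimal sketch of how these keys might be consumed at model-construction time, assuming the usual VITS config layout (the new keys under "model", a "data"."sampling_rate" field) and the package paths of the files changed below; the constructor signatures match the classes added in this commit, everything else is illustrative:

import json

from wetts.vits.losses import WavLMLoss
from wetts.vits.model.discriminators import WavLMDiscriminator

with open("examples/baker/configs/v1.json") as f:
    cfg = json.load(f)
model_cfg = cfg["model"]

if model_cfg.get("use_wd", False):
    wd = WavLMDiscriminator(
        slm_hidden=model_cfg["slm_hidden"],           # 768 = WavLM-base hidden size
        slm_layers=model_cfg["slm_nlayers"],          # 13 = embedding + 12 transformer layers
        initial_channel=model_cfg["slm_initial_channel"],
    )
    wavlm_loss = WavLMLoss(
        model=model_cfg["slm_model"],                 # local dir or Hugging Face model id
        wd=wd,
        model_sr=cfg["data"]["sampling_rate"],        # synthesis rate, resampled to slm_sr
        slm_sr=model_cfg["slm_sr"],                   # 16 kHz, WavLM's native rate
    )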
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,3 +10,4 @@ torch
torchvision
tqdm
transformers
huggingface_hub
95 changes: 95 additions & 0 deletions wetts/vits/losses.py
@@ -1,4 +1,6 @@
import torch
import torchaudio
from transformers import AutoModel


def feature_loss(fmap_r, fmap_g):
@@ -56,3 +58,96 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l


class WavLMLoss(torch.nn.Module):
    """WavLM-based losses: an L1 feature-matching loss (forward) and LSGAN
    generator/discriminator losses computed through a WavLM discriminator
    (`wd`)."""

    def __init__(self, model, wd, model_sr, slm_sr=16000):
        super(WavLMLoss, self).__init__()
        self.wavlm = AutoModel.from_pretrained(model)
        self.wd = wd
        self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
        # WavLM is a frozen feature extractor; only `wd` receives gradients.
        self.wavlm.eval()
        for param in self.wavlm.parameters():
            param.requires_grad = False

    def forward(self, wav, y_rec):
        # Target embeddings from the real waveform carry no gradient; the
        # reconstruction branch stays differentiable so the L1 feature loss
        # can train the generator.
        with torch.no_grad():
            wav_16 = self.resample(wav)
            wav_embeddings = self.wavlm(
                input_values=wav_16, output_hidden_states=True
            ).hidden_states
        y_rec_16 = self.resample(y_rec)
        y_rec_embeddings = self.wavlm(
            input_values=y_rec_16.squeeze(), output_hidden_states=True
        ).hidden_states

        # Per-layer L1 distance between real and reconstructed hidden states.
        floss = 0
        for er, eg in zip(wav_embeddings, y_rec_embeddings):
            floss += torch.mean(torch.abs(er - eg))

        return floss.mean()

def generator(self, y_rec):
y_rec_16 = self.resample(y_rec)
y_rec_embeddings = self.wavlm(
input_values=y_rec_16, output_hidden_states=True
).hidden_states
y_rec_embeddings = (
torch.stack(y_rec_embeddings, dim=1)
.transpose(-1, -2)
.flatten(start_dim=1, end_dim=2)
)
y_df_hat_g = self.wd(y_rec_embeddings)
loss_gen = torch.mean((1 - y_df_hat_g) ** 2)

return loss_gen

def discriminator(self, wav, y_rec):
with torch.no_grad():
wav_16 = self.resample(wav)
wav_embeddings = self.wavlm(
input_values=wav_16, output_hidden_states=True
).hidden_states
y_rec_16 = self.resample(y_rec)
y_rec_embeddings = self.wavlm(
input_values=y_rec_16, output_hidden_states=True
).hidden_states

y_embeddings = (
torch.stack(wav_embeddings, dim=1)
.transpose(-1, -2)
.flatten(start_dim=1, end_dim=2)
)
y_rec_embeddings = (
torch.stack(y_rec_embeddings, dim=1)
.transpose(-1, -2)
.flatten(start_dim=1, end_dim=2)
)

y_d_rs = self.wd(y_embeddings)
y_d_gs = self.wd(y_rec_embeddings)

y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs

r_loss = torch.mean((1 - y_df_hat_r) ** 2)
g_loss = torch.mean((y_df_hat_g) ** 2)

loss_disc_f = r_loss + g_loss

return loss_disc_f.mean()

def discriminator_forward(self, wav):
with torch.no_grad():
wav_16 = self.resample(wav)
wav_embeddings = self.wavlm(
input_values=wav_16, output_hidden_states=True
).hidden_states
y_embeddings = (
torch.stack(wav_embeddings, dim=1)
.transpose(-1, -2)
.flatten(start_dim=1, end_dim=2)
)

y_d_rs = self.wd(y_embeddings)

return y_d_rs
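
How these methods plug into training is not visible in this excerpt (the train.py diff did not render below), but the method split suggests a standard LSGAN loop: discriminator() on a detached reconstruction for the wd update, forward() plus generator() on the generator side. A hedged sketch, with optimizer settings, loss weights, and waveform shapes assumed:

import torch

# `wd` and `wavlm_loss` as constructed in the config sketch above;
# wav is (batch, samples), y_hat is (batch, 1, samples).
optim_wd = torch.optim.AdamW(wd.parameters(), lr=2e-4)

def discriminator_step(wav, y_hat):
    # LSGAN discriminator loss on real vs. generated WavLM embeddings;
    # detach so no gradient reaches the generator.
    loss_slm = wavlm_loss.discriminator(wav, y_hat.squeeze(1).detach())
    optim_wd.zero_grad()
    loss_slm.backward()
    optim_wd.step()
    return loss_slm

def generator_terms(wav, y_hat):
    # Per-layer L1 over WavLM hidden states plus the LSGAN generator term.
    loss_fm = wavlm_loss(wav, y_hat)               # forward()
    loss_gen = wavlm_loss.generator(y_hat.squeeze(1))
    return loss_fm + loss_gen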
49 changes: 49 additions & 0 deletions wetts/vits/model/discriminators.py
@@ -447,3 +447,52 @@ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
output_probs.append([output_prob])

return output_probs


class WavLMDiscriminator(nn.Module):
"""docstring for Discriminator."""

def __init__(
self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
):
super(WavLMDiscriminator, self).__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.pre = norm_f(
Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
)

self.convs = nn.ModuleList(
[
norm_f(
nn.Conv1d(
initial_channel, initial_channel * 2, kernel_size=5, padding=2
)
),
norm_f(
nn.Conv1d(
initial_channel * 2,
initial_channel * 4,
kernel_size=5,
padding=2,
)
),
norm_f(
nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
),
]
)

self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))

def forward(self, x):
x = self.pre(x)

        fmap = []  # intermediate activations; collected but not returned
        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
x = self.conv_post(x)
x = torch.flatten(x, 1, -1)

return x
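
A quick shape check (illustrative, not part of the commit): the discriminator's input channel count must equal slm_hidden * slm_layers = 768 * 13 = 9984, which is exactly what the stack/transpose/flatten sequence in WavLMLoss produces:

import torch

disc = WavLMDiscriminator(slm_hidden=768, slm_layers=13, initial_channel=64)
# 13 hidden-state tensors of shape (batch, frames, hidden), as returned
# by WavLM with output_hidden_states=True.
hidden_states = [torch.randn(2, 50, 768) for _ in range(13)]
x = (
    torch.stack(hidden_states, dim=1)   # (2, 13, 50, 768)
    .transpose(-1, -2)                  # (2, 13, 768, 50)
    .flatten(start_dim=1, end_dim=2)    # (2, 9984, 50)
)
scores = disc(x)
print(scores.shape)                     # torch.Size([2, 50]): one logit per frame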
(Diffs for the remaining two changed files are not rendered here.)
