Release

heatz123 committed Feb 1, 2023
1 parent 65fcc3d commit 8c0292d

Showing 55 changed files with 61,101 additions and 2 deletions.
103 changes: 101 additions & 2 deletions README.md
@@ -1,2 +1,101 @@
# naturalspeech
Implementation of NaturalSpeech(2022)
# NaturalSpeech: End-to-End Text to Speech Synthesis with Human-Level Quality

This is an implementation of Microsoft's [NaturalSpeech: End-to-End Text to Speech Synthesis with Human-Level Quality](https://arxiv.org/abs/2205.04421) in PyTorch.

Contributions and pull requests are highly appreciated!

23.02.01: Pretrained models or demo samples will be released soon.


### Overview

![figure1](resources/figure1.png)

NaturalSpeech is a VAE-based model that employs several techniques to improve the prior and simplify the posterior. It differs from VITS in several ways, including:
- **Phoneme pre-training**: NaturalSpeech pre-trains the phoneme encoder on a large text corpus via masked language modeling on phoneme sequences.
- **Differentiable durator**: The posterior operates at the frame level, while the prior operates at the phoneme level. NaturalSpeech bridges this length difference with a differentiable durator that expands the phoneme-level features in a soft and flexible way.
- **Bidirectional prior/posterior**: NaturalSpeech reduces the posterior and enhances the prior through a normalizing flow that maps in both directions, trained with forward and backward losses.
- **Memory-based VAE**: The prior is further enhanced through a memory bank queried with Q-K-V attention (a minimal sketch is shown below this list).
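
As a rough illustration of the memory-bank idea, the latent z can be used as the query against a learnable bank of vectors via Q-K-V attention. The sketch below is a minimal stand-in under assumed shapes (the bank size, hidden dimension, and head count mirror `configs/ljs.json`); it is not the repository's `VAEMemoryBank` implementation.

```
import torch
import torch.nn as nn

class MemoryBankSketch(nn.Module):
    # Illustrative Q-K-V memory bank: queries come from the latent z,
    # keys/values come from a learnable bank of vectors.
    def __init__(self, bank_size=1000, n_hidden_dims=192, n_attn_heads=2):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(bank_size, n_hidden_dims))
        self.attn = nn.MultiheadAttention(n_hidden_dims, n_attn_heads, batch_first=True)

    def forward(self, z):
        # z: (batch, n_hidden_dims, frames) as produced by the posterior encoder.
        q = z.transpose(1, 2)                                    # (batch, frames, dims)
        kv = self.memory.unsqueeze(0).expand(z.size(0), -1, -1)  # (batch, bank, dims)
        out, _ = self.attn(q, kv, kv)                            # attend over bank slots
        return out.transpose(1, 2)                               # (batch, dims, frames)

print(MemoryBankSketch()(torch.randn(2, 192, 12)).shape)  # torch.Size([2, 192, 12])
```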


### Notes
- This implementation does not include pre-training of phonemes on a large-scale text corpus from the news-crawl dataset.
- The multiplier for each loss can be adjusted in the configuration file (a sketch of how the multipliers combine is given at the end of the "How to train" section). Training with unweighted losses may not converge.
- The tuning stage for the last 2k epochs has been omitted.
- Due to the high VRAM usage of the soft-dtw loss, there is an option to use a non-softdtw loss for memory efficiency.
- For the soft-dtw loss, the warp factor has been set to 134.4 (0.07 * 192) to match the non-softdtw loss, instead of 0.07.
- To train the duration predictor in the warm-up stage, duration labels are required. The paper suggests that any external tool can be used to provide them; in this implementation, a pre-trained VITS model was used.
- To further improve memory efficiency during training, randomly sliced sequences are fed to the decoder, as in the VITS model (see the sketch below this list).
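
The random-slicing trick can be illustrated with a small sketch: only a short window of each latent sequence is decoded at every training step, which bounds the decoder's memory use. This is a simplified illustration assuming VITS-style tensors of shape (batch, channels, frames); it is not the repository's actual slicing code.

```
import torch

def rand_slice_segments_sketch(z, lengths, segment_size):
    # z: (batch, channels, frames); lengths: (batch,) valid frame counts.
    # Assumes every sequence holds at least `segment_size` valid frames.
    batch = z.size(0)
    max_start = lengths - segment_size
    starts = (torch.rand(batch, device=z.device) * (max_start + 1).float()).long()
    slices = torch.stack(
        [z[i, :, starts[i]:starts[i] + segment_size] for i in range(batch)]
    )
    return slices, starts  # starts can be reused to cut the matching ground-truth audio

# Example: 2 sequences with 100 allocated frames; keep 32-frame windows.
z = torch.randn(2, 192, 100)
lengths = torch.tensor([100, 80])
chunks, starts = rand_slice_segments_sketch(z, lengths, segment_size=32)
print(chunks.shape)  # torch.Size([2, 192, 32])
```

With `segment_size` 8192 samples and `hop_length` 256 (as in `configs/ljs.json`), the decoder only ever sees 8192 // 256 = 32 latent frames per step.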




### How to train

0. install the requirements:
```
# python >= 3.6
pip install -r requirements.txt
```
1. clone this repository
1. download `The LJ Speech Dataset`: [link](https://keithito.com/LJ-Speech-Dataset/)
1. create a symbolic link to the LJSpeech dataset:
```
ln -s /path/to/LJSpeech-1.1/wavs/ DUMMY1
```
1. text preprocessing (optional; required if you are using a custom dataset):
1. `apt-get install espeak`
2.
```
python preprocess.py --text_index 1 --filelists filelists/ljs_audio_text_train_filelist.txt filelists/ljs_audio_text_val_filelist.txt filelists/ljs_audio_text_test_filelist.txt
```
1. duration preprocessing (obtain duration labels using pretrained VITS):
1. `git clone https://github.com/jaywalnut310/vits.git; cd vits`
2. create a symbolic link to the LJSpeech dataset
```
ln -s /path/to/LJSpeech-1.1/wavs/ DUMMY1
```
3. download the pretrained VITS model from the official VITS GitHub repository: [github link](https://github.com/jaywalnut310/vits) / [pretrained models](https://drive.google.com/drive/folders/1ksarh-cJf3F5eKJjLVWY0X1j1qsQqiS2)
4. setup monotonic alignment search (for VITS inference):
```
cd monotonic_align; mkdir monotonic_align; python setup.py build_ext --inplace; cd ..
```
5. copy duration preprocessing script to VITS repo: `cp /path/to/naturalspeech/preprocess_durations.py .`
6.
```
python3 preprocess_durations.py --weights_path ./pretrained_ljs.pth --filelists filelists/ljs_audio_text_train_filelist.txt.cleaned filelists/ljs_audio_text_val_filelist.txt.cleaned filelists/ljs_audio_text_test_filelist.txt.cleaned
```
7. once the duration labels are created, copy them to the naturalspeech repo: `cp -r durations/ /path/to/naturalspeech`
1. train (warmup)
```
python3 train.py -c configs/ljs.json -m [run_name] --warmup
```
Note here that `ljs.json` is for low-resource training: it runs for 1500 epochs and does not use the soft-dtw loss. If you want to reproduce the setup described in the paper, use `ljs_reproduce.json`, which runs for 15000 epochs and uses the soft-dtw loss.
1. initialize and attach memory bank after warmup:
```
python3 attach_memory_bank.py -c configs/ljs.json --weights_path logs/[run_name]/G_xxx.pth
```
If you run out of memory, you can specify the `--num_samples` argument to use only a subset of the samples.
1. train (resume)
```
python3 train.py -c configs/ljs.json -m [run_name]
```
You can use tensorboard to monitor the training.
```
tensorboard --logdir /path/to/naturalspeech/logs
```
During each evaluation phase, a selection of samples from the test set is evaluated and saved in the `logs/[run_name]/eval` directory.
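
For orientation, the `c_*` multipliers in `configs/ljs.json` scale the individual training objectives before they are summed into the total loss. The snippet below is only an illustrative sketch under that assumption; the individual loss names and pairings are placeholders, not the variables used in the repository's `train.py`.

```
import torch

# Placeholder loss values purely for illustration; in training these come from
# the model outputs (the names are assumptions, not the repository's variables).
loss_mel, loss_kl_bwd, loss_kl_fwd = torch.tensor(0.5), torch.tensor(0.2), torch.tensor(0.3)
loss_dur, loss_e2e = torch.tensor(0.1), torch.tensor(0.4)

# Multipliers taken from the "train" section of configs/ljs.json.
c_mel, c_kl, c_kl_fwd, c_dur, c_e2e = 45, 1.0, 0.001, 5.0, 0.1

loss_g = (
    c_mel * loss_mel          # mel-spectrogram reconstruction
    + c_kl * loss_kl_bwd      # backward KL term (assumed pairing)
    + c_kl_fwd * loss_kl_fwd  # forward KL term (assumed pairing)
    + c_dur * loss_dur        # duration loss
    + c_e2e * loss_e2e        # end-to-end (waveform) loss
)
print(loss_g)
```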
## References
- [VITS implementation](https://github.com/jaywalnut310/vits) by @jaywalnut310 for the normalizing flows, phoneme encoder, and HiFi-GAN decoder implementations
- [Parallel Tacotron 2 implementation](https://github.com/keonlee9420/Parallel-Tacotron2) by @keonlee9420 for the learnable upsampling layer
- [soft-dtw implementation](https://github.com/Maghoumi/pytorch-softdtw-cuda) by @Maghoumi for the soft-dtw loss
195 changes: 195 additions & 0 deletions attach_memory_bank.py
@@ -0,0 +1,195 @@
import os
import argparse
from pathlib import Path

import numpy as np
import torch
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader

from text.symbols import symbols
from models.models import SynthesizerTrn
from models.models import VAEMemoryBank
from utils import utils

from utils.data_utils import (
    TextAudioLoaderWithDuration,
    TextAudioCollateWithDuration,
)

from sklearn.cluster import KMeans


def load_net_g(hps, weights_path):
    # Build the synthesizer and its optimizer, then restore both from the
    # warm-up checkpoint.
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        hps.models,
    ).cuda()

    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    def load_checkpoint(checkpoint_path, model, optimizer=None):
        assert os.path.isfile(checkpoint_path)
        checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
        iteration = checkpoint_dict["iteration"]
        learning_rate = checkpoint_dict["learning_rate"]

        if optimizer is not None:
            optimizer.load_state_dict(checkpoint_dict["optimizer"])
        saved_state_dict = checkpoint_dict["model"]

        state_dict = model.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():
            try:
                new_state_dict[k] = saved_state_dict[k]
            except KeyError:
                print("%s is not in the checkpoint" % k)
                new_state_dict[k] = v
        model.load_state_dict(new_state_dict)

        print(
            "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
        )
        return model, optimizer, learning_rate, iteration

    model, optimizer, learning_rate, iteration = load_checkpoint(
        weights_path, net_g, optim_g
    )

    return model, optimizer, learning_rate, iteration


def get_dataloader(hps):
    train_dataset = TextAudioLoaderWithDuration(hps.data.training_files, hps.data)
    collate_fn = TextAudioCollateWithDuration()
    train_loader = DataLoader(
        train_dataset,
        num_workers=1,
        shuffle=False,
        pin_memory=False,
        collate_fn=collate_fn,
        batch_size=1,
    )
    return train_loader


def get_zs(net_g, dataloader, num_samples=0):
    # Collect posterior latents z for every utterance; note that `hps` is the
    # module-level variable defined in the __main__ block below.
    net_g.eval()
    print(len(dataloader))
    zs = []
    with torch.no_grad():
        for batch_idx, (
            x,
            x_lengths,
            spec,
            spec_lengths,
            y,
            y_lengths,
            duration,
        ) in enumerate(dataloader):
            rank = 0
            x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
                rank, non_blocking=True
            )
            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
                rank, non_blocking=True
            )
            y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
                rank, non_blocking=True
            )
            duration = duration.cuda()
            with autocast(enabled=hps.train.fp16_run):
                (
                    y_hat,
                    l_length,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q, p_mask),
                    *_,
                ) = net_g(x, x_lengths, spec, spec_lengths, duration)

            zs.append(z.squeeze(0).cpu())
            if batch_idx % 100 == 99:
                print(batch_idx, zs[batch_idx].shape)

            if num_samples and batch_idx >= num_samples:
                break
    return zs


def k_means(zs):
    # Cluster all latent frames into 1000 centroids, matching bank_size in
    # configs/ljs.json; the centroids initialize the memory bank.
    X = torch.cat(zs, dim=1).transpose(0, 1).numpy()
    print(X.shape)
    kmeans = KMeans(n_clusters=1000, random_state=0, n_init="auto").fit(X)
    print(kmeans.cluster_centers_.shape)

    return kmeans.cluster_centers_


def save_memory_bank(bank):
    state_dict = bank.state_dict()
    torch.save(state_dict, "./bank_init.pth")


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    state_dict = model.state_dict()
    torch.save(
        {
            "model": state_dict,
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )
    print("Saving model to " + checkpoint_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="configs/ljs.json")
    parser.add_argument("--weights_path", type=str)
    parser.add_argument(
        "--num_samples",
        type=int,
        default=0,
        help="samples to use for k-means clustering; 0 to use all samples in the dataset",
    )
    args = parser.parse_args()

    hps = utils.get_hparams_from_file(args.config)
    net_g, optimizer, lr, iterations = load_net_g(hps, weights_path=args.weights_path)

    dataloader = get_dataloader(hps)
    zs = get_zs(net_g, dataloader, num_samples=args.num_samples)
    centers = k_means(zs)

    # Initialize the memory bank from the k-means centroids and attach it to
    # the generator before writing a new checkpoint.
    memory_bank = VAEMemoryBank(
        **hps.models.memory_bank,
        init_values=torch.from_numpy(centers).cuda().transpose(0, 1)
    )
    save_memory_bank(memory_bank)

    net_g.memory_bank = memory_bank
    optimizer.add_param_group(
        {
            "params": list(memory_bank.parameters()),
            "initial_lr": optimizer.param_groups[0]["initial_lr"],
        }
    )

    p = Path(args.weights_path)
    save_path = str(p.with_stem(p.stem + "_with_memory"))
    save_checkpoint(net_g, optimizer, lr, iterations, save_path)

    # test
    print(memory_bank(torch.randn((2, 192, 12))).shape)
95 changes: 95 additions & 0 deletions configs/ljs.json
@@ -0,0 +1,95 @@
{
  "train": {
    "log_interval": 500,
    "eval_interval": 5,
    "seed": 1234,
    "epochs": 1500,
    "learning_rate": 2e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 16,
    "fp16_run": true,
    "lr_decay": 0.999,
    "segment_size": 8192,
    "init_lr_ratio": 1,
    "warmup_epochs": 200,
    "c_mel": 45,
    "c_kl": 1.0,
    "c_kl_fwd": 0.001,
    "c_e2e": 0.1,
    "c_dur": 5.0,
    "use_sdtw": false,
    "use_gt_duration": true
  },
  "data": {
    "training_files": "filelists/ljs_audio_text_train_filelist.txt.cleaned",
    "validation_files": "filelists/ljs_audio_text_val_filelist.txt.cleaned",
    "text_cleaners": ["english_cleaners2"],
    "max_wav_value": 32768.0,
    "sampling_rate": 22050,
    "filter_length": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "n_mel_channels": 80,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 0,
    "cleaned_text": true
  },
  "models": {
    "phoneme_encoder": {
      "out_channels": 192,
      "hidden_channels": 192,
      "filter_channels": 768,
      "n_heads": 2,
      "n_layers": 6,
      "kernel_size": 3,
      "p_dropout": 0.1
    },
    "decoder": {
      "initial_channel": 192,
      "resblock": "1",
      "resblock_kernel_sizes": [3,7,11],
      "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
      "upsample_rates": [8,8,2,2],
      "upsample_initial_channel": 256,
      "upsample_kernel_sizes": [16,16,4,4],
      "gin_channels": 0
    },
    "posterior_encoder": {
      "out_channels": 192,
      "hidden_channels": 192,
      "kernel_size": 5,
      "dilation_rate": 1,
      "n_layers": 16
    },
    "flow": {
      "channels": 192,
      "hidden_channels": 192,
      "kernel_size": 5,
      "dilation_rate": 1,
      "n_layers": 4
    },
    "duration_predictor": {
      "in_channels": 192,
      "filter_channels": 256,
      "kernel_size": 3,
      "p_dropout": 0.5
    },
    "learnable_upsampling": {
      "d_predictor": 192,
      "kernel_size": 3,
      "dropout": 0.0,
      "conv_output_size": 8,
      "dim_w": 4,
      "dim_c": 2,
      "max_seq_len": 1000
    },
    "memory_bank": {
      "bank_size": 1000,
      "n_hidden_dims": 192,
      "n_attn_heads": 2
    }
  }
}
