# NaturalSpeech: End-to-End Text to Speech Synthesis with Human-Level Quality

This is an implementation of Microsoft's [NaturalSpeech: End-to-End Text to Speech Synthesis with Human-Level Quality](https://arxiv.org/abs/2205.04421) in PyTorch.

Contributions and pull requests are highly appreciated!

23.02.01: Pretrained models or demo samples will be released soon.

### Overview


NaturalSpeech is a VAE-based model that employs several techniques to improve the prior and simplify the posterior. It differs from VITS in several ways, including:
- **Phoneme pre-training**: NaturalSpeech uses a phoneme encoder pre-trained on a large text corpus via masked language modeling on phoneme sequences.
- **Differentiable durator**: The posterior operates at the frame level, while the prior operates at the phoneme level. NaturalSpeech bridges this length difference with a differentiable durator, so the expanded features remain soft and flexible.
- **Bidirectional prior/posterior**: NaturalSpeech reduces the posterior and enhances the prior through a normalizing flow that maps in both directions, trained with forward and backward losses.
- **Memory-based VAE**: The prior is further enhanced through a memory bank queried with Q-K-V attention (a minimal sketch of the idea follows this list).
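The repository's `VAEMemoryBank` module is defined elsewhere and is not part of this excerpt; the snippet below is only a minimal sketch of the memory-attention idea, assuming a learnable bank of vectors attended over with standard Q-K-V attention. The class name `MemoryBankSketch` and its internals are illustrative; the dimensions are taken from `configs/ljs.json` further down (bank_size 1000, n_hidden_dims 192, n_attn_heads 2).

```python
import torch
import torch.nn as nn


class MemoryBankSketch(nn.Module):
    """Illustrative Q-K-V memory bank: queries come from the latent z,
    keys/values come from a learnable bank of memory vectors."""

    def __init__(self, bank_size=1000, n_hidden_dims=192, n_attn_heads=2):
        super().__init__()
        self.bank = nn.Parameter(torch.randn(bank_size, n_hidden_dims))
        self.attn = nn.MultiheadAttention(n_hidden_dims, n_attn_heads, batch_first=True)

    def forward(self, z):                          # z: (batch, dims, time)
        q = z.transpose(1, 2)                      # (batch, time, dims)
        kv = self.bank.unsqueeze(0).expand(z.size(0), -1, -1)
        out, _ = self.attn(q, kv, kv)              # attend over the memory bank
        return out.transpose(1, 2)                 # back to (batch, dims, time)
```

In the actual repository the bank is initialized from k-means centroids of the posterior latents (see `attach_memory_bank.py` below), so the attention starts from a representative codebook rather than random vectors.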
### Notes
- This implementation does not include pre-training of the phoneme encoder on a large-scale text corpus (the news-crawl dataset).
- The multiplier for each loss can be adjusted in the configuration file. Using the losses without multipliers may not lead to convergence.
- The tuning stage for the last 2k epochs has been omitted.
- Because the soft-dtw loss uses a lot of VRAM, there is an option to use a non-soft-dtw loss for memory efficiency.
- For the soft-dtw loss, the warp factor has been set to 134.4 (0.07 * 192) to match the scale of the non-soft-dtw loss, instead of 0.07.
- To train the duration predictor in the warm-up stage, duration labels are required. The paper suggests using any external tool to provide the duration labels; in this implementation, a pre-trained VITS model is used.
- To further improve memory efficiency during training, randomly sliced sequences are fed to the decoder, as in VITS (a minimal sketch of the slicing follows this list).
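For reference, this is roughly what the random slicing looks like. The repository presumably reuses the slicing utilities from VITS; the helper below is an illustrative sketch, not the repo's code. With the `ljs.json` values below, each slice covers `segment_size / hop_length = 8192 / 256 = 32` latent frames, which matches the `hps.train.segment_size // hps.data.hop_length` value passed to `SynthesizerTrn` in `attach_memory_bank.py`.

```python
import torch


def rand_slice_segments_sketch(z, z_lengths, segment_size):
    """Cut a random window of `segment_size` frames from each padded sequence,
    so the decoder only sees a short excerpt per training step."""
    batch, channels, _ = z.shape
    # Latest valid start per sequence; sequences shorter than segment_size start
    # at 0 (this sketch assumes the padded length is at least segment_size).
    max_start = (z_lengths - segment_size).clamp(min=0)
    starts = (torch.rand(batch, device=z.device) * (max_start + 1).float()).long()
    segments = torch.stack(
        [z[i, :, int(s) : int(s) + segment_size] for i, s in enumerate(starts)], dim=0
    )
    return segments, starts  # starts are reused to slice the matching waveform targets
```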
### How to train

0. Install the requirements:
   ```
   # python >= 3.6
   pip install -r requirements.txt
   ```
1. Clone this repository.
1. Download [the LJ Speech Dataset](https://keithito.com/LJ-Speech-Dataset/).
1. Create a symbolic link to the LJSpeech dataset:
   ```
   ln -s /path/to/LJSpeech-1.1/wavs/ DUMMY1
   ```
1. Text preprocessing (optional; only needed if you are using a custom dataset; the expected filelist format is sketched after this list):
   1. `apt-get install espeak`
   2. Run:
      ```
      python preprocess.py --text_index 1 --filelists filelists/ljs_audio_text_train_filelist.txt filelists/ljs_audio_text_val_filelist.txt filelists/ljs_audio_text_test_filelist.txt
      ```
1. Duration preprocessing (obtain duration labels using pretrained VITS):
   1. `git clone https://github.com/jaywalnut310/vits.git; cd vits`
   2. Create a symbolic link to the LJSpeech dataset:
      ```
      ln -s /path/to/LJSpeech-1.1/wavs/ DUMMY1
      ```
   3. Download the pretrained VITS model (`pretrained_ljs.pth`) from the official VITS repository: [github link](https://github.com/jaywalnut310/vits) / [pretrained models](https://drive.google.com/drive/folders/1ksarh-cJf3F5eKJjLVWY0X1j1qsQqiS2)
   4. Set up monotonic alignment search (needed for VITS inference):
      ```
      cd monotonic_align; mkdir monotonic_align; python setup.py build_ext --inplace; cd ..
      ```
   5. Copy the duration preprocessing script into the VITS repo: `cp /path/to/naturalspeech/preprocess_durations.py .`
   6. Run:
      ```
      python3 preprocess_durations.py --weights_path ./pretrained_ljs.pth --filelists filelists/ljs_audio_text_train_filelist.txt.cleaned filelists/ljs_audio_text_val_filelist.txt.cleaned filelists/ljs_audio_text_test_filelist.txt.cleaned
      ```
   7. Once the duration labels are created, copy them to the naturalspeech repo: `cp -r durations/ path/to/naturalspeech`
1. Train (warm-up stage):
   ```
   python3 train.py -c configs/ljs.json -m [run_name] --warmup
   ```
   Note that ljs.json is a low-resource configuration: it runs for 1500 epochs and does not use the soft-dtw loss. If you want to reproduce the setup stated in the paper, use ljs_reproduce.json, which runs for 15000 epochs and uses the soft-dtw loss.
1. Initialize and attach the memory bank after warm-up:
   ```
   python3 attach_memory_bank.py -c configs/ljs.json --weights_path logs/[run_name]/G_xxx.pth
   ```
   If you lack memory, you can specify the `--num_samples` argument to use only a subset of samples.
1. Train (resume):
   ```
   python3 train.py -c configs/ljs.json -m [run_name]
   ```

You can use TensorBoard to monitor training:
```
tensorboard --logdir /path/to/naturalspeech/logs
```
During each evaluation phase, a selection of samples from the test set is evaluated and saved in the `logs/[run_name]/eval` directory.
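For reference, the filelists follow the usual VITS convention: one `wav_path|transcript` pair per line, with the `DUMMY1` symlink as the audio prefix, and the `.cleaned` variants holding the phonemized text produced by `preprocess.py`. The first line below shows a raw entry and the second the corresponding cleaned (phonemized) entry; both are illustrative examples rather than excerpts from the repository's filelists:

```
DUMMY1/LJ001-0001.wav|Printing, in the only sense with which we are at present concerned, ...
DUMMY1/LJ001-0001.wav|pɹˈɪntɪŋ, ɪn ðɪ ˈoʊnli sˈɛns wɪð wˌɪtʃ wiː ɑːɹ æt pɹˈɛzənt kənsˈɜːnd, ...
```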
## References
- [VITS implementation](https://github.com/jaywalnut310/vits) by @jaywalnut310, for the normalizing flows, phoneme encoder, and HiFi-GAN decoder implementations
- [Parallel Tacotron 2 implementation](https://github.com/keonlee9420/Parallel-Tacotron2) by @keonlee9420, for the learnable upsampling layer
- [Soft-DTW implementation](https://github.com/Maghoumi/pytorch-softdtw-cuda) by @Maghoumi, for the soft-dtw loss
attach_memory_bank.py (new file)
import os
import argparse
from pathlib import Path

import numpy as np
import torch
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader

from text.symbols import symbols
from models.models import SynthesizerTrn
from models.models import VAEMemoryBank
from utils import utils

from utils.data_utils import (
    TextAudioLoaderWithDuration,
    TextAudioCollateWithDuration,
)

from sklearn.cluster import KMeans

def load_net_g(hps, weights_path):
    """Build the generator and restore it (with its optimizer) from a warm-up checkpoint."""
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        hps.models,
    ).cuda()

    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    def load_checkpoint(checkpoint_path, model, optimizer=None):
        assert os.path.isfile(checkpoint_path)
        checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
        iteration = checkpoint_dict["iteration"]
        learning_rate = checkpoint_dict["learning_rate"]

        if optimizer is not None:
            optimizer.load_state_dict(checkpoint_dict["optimizer"])
        saved_state_dict = checkpoint_dict["model"]

        # Copy weights that exist in the checkpoint; keep current values for missing keys.
        state_dict = model.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():
            try:
                new_state_dict[k] = saved_state_dict[k]
            except KeyError:
                print("%s is not in the checkpoint" % k)
                new_state_dict[k] = v
        model.load_state_dict(new_state_dict)

        print(
            "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
        )
        return model, optimizer, learning_rate, iteration

    model, optimizer, learning_rate, iteration = load_checkpoint(
        weights_path, net_g, optim_g
    )

    return model, optimizer, learning_rate, iteration

def get_dataloader(hps):
    train_dataset = TextAudioLoaderWithDuration(hps.data.training_files, hps.data)
    collate_fn = TextAudioCollateWithDuration()
    train_loader = DataLoader(
        train_dataset,
        num_workers=1,
        shuffle=False,
        pin_memory=False,
        collate_fn=collate_fn,
        batch_size=1,
    )
    return train_loader

def get_zs(net_g, dataloader, num_samples=0):
    """Run the generator over the dataset and collect the frame-level latents z."""
    # NOTE: `hps` is read from module scope; it is set in the __main__ block below.
    net_g.eval()
    print(len(dataloader))
    zs = []
    with torch.no_grad():
        for batch_idx, (
            x,
            x_lengths,
            spec,
            spec_lengths,
            y,
            y_lengths,
            duration,
        ) in enumerate(dataloader):
            rank = 0
            x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
                rank, non_blocking=True
            )
            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
                rank, non_blocking=True
            )
            y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
                rank, non_blocking=True
            )
            duration = duration.cuda()
            with autocast(enabled=hps.train.fp16_run):
                (
                    y_hat,
                    l_length,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q, p_mask),
                    *_,
                ) = net_g(x, x_lengths, spec, spec_lengths, duration)

            zs.append(z.squeeze(0).cpu())
            if batch_idx % 100 == 99:
                print(batch_idx, zs[batch_idx].shape)

            if num_samples and batch_idx >= num_samples:
                break
    return zs

def k_means(zs):
    """Cluster all latent frames; the centroids become the initial memory bank entries."""
    X = torch.cat(zs, dim=1).transpose(0, 1).numpy()
    print(X.shape)
    # n_clusters matches memory_bank.bank_size in the config (1000).
    kmeans = KMeans(n_clusters=1000, random_state=0, n_init="auto").fit(X)
    print(kmeans.cluster_centers_.shape)

    return kmeans.cluster_centers_


def save_memory_bank(bank):
    state_dict = bank.state_dict()
    torch.save(state_dict, "./bank_init.pth")

def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    state_dict = model.state_dict()
    torch.save(
        {
            "model": state_dict,
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )
    print("Saving model to " + checkpoint_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="configs/ljs.json")
    parser.add_argument("--weights_path", type=str)
    parser.add_argument(
        "--num_samples",
        type=int,
        default=0,
        help="samples to use for k-means clustering; 0 to use all samples in the dataset",
    )
    args = parser.parse_args()

    hps = utils.get_hparams_from_file(args.config)
    net_g, optimizer, lr, iterations = load_net_g(hps, weights_path=args.weights_path)

    dataloader = get_dataloader(hps)
    zs = get_zs(net_g, dataloader, num_samples=args.num_samples)
    centers = k_means(zs)

    # Initialize the memory bank from the k-means centroids and save it separately.
    memory_bank = VAEMemoryBank(
        **hps.models.memory_bank,
        init_values=torch.from_numpy(centers).cuda().transpose(0, 1)
    )
    save_memory_bank(memory_bank)

    # Attach the bank to the generator and register its parameters with the optimizer.
    net_g.memory_bank = memory_bank
    optimizer.add_param_group(
        {
            "params": list(memory_bank.parameters()),
            "initial_lr": optimizer.param_groups[0]["initial_lr"],
        }
    )

    p = Path(args.weights_path)
    save_path = str(p.with_stem(p.stem + "_with_memory"))
    save_checkpoint(net_g, optimizer, lr, iterations, save_path)

    # test
    print(memory_bank(torch.randn((2, 192, 12))).shape)
configs/ljs.json (new file)
{
  "train": {
    "log_interval": 500,
    "eval_interval": 5,
    "seed": 1234,
    "epochs": 1500,
    "learning_rate": 2e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 16,
    "fp16_run": true,
    "lr_decay": 0.999,
    "segment_size": 8192,
    "init_lr_ratio": 1,
    "warmup_epochs": 200,
    "c_mel": 45,
    "c_kl": 1.0,
    "c_kl_fwd": 0.001,
    "c_e2e": 0.1,
    "c_dur": 5.0,
    "use_sdtw": false,
    "use_gt_duration": true
  },
  "data": {
    "training_files": "filelists/ljs_audio_text_train_filelist.txt.cleaned",
    "validation_files": "filelists/ljs_audio_text_val_filelist.txt.cleaned",
    "text_cleaners": ["english_cleaners2"],
    "max_wav_value": 32768.0,
    "sampling_rate": 22050,
    "filter_length": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "n_mel_channels": 80,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 0,
    "cleaned_text": true
  },
  "models": {
    "phoneme_encoder": {
      "out_channels": 192,
      "hidden_channels": 192,
      "filter_channels": 768,
      "n_heads": 2,
      "n_layers": 6,
      "kernel_size": 3,
      "p_dropout": 0.1
    },
    "decoder": {
      "initial_channel": 192,
      "resblock": "1",
      "resblock_kernel_sizes": [3, 7, 11],
      "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
      "upsample_rates": [8, 8, 2, 2],
      "upsample_initial_channel": 256,
      "upsample_kernel_sizes": [16, 16, 4, 4],
      "gin_channels": 0
    },
    "posterior_encoder": {
      "out_channels": 192,
      "hidden_channels": 192,
      "kernel_size": 5,
      "dilation_rate": 1,
      "n_layers": 16
    },
    "flow": {
      "channels": 192,
      "hidden_channels": 192,
      "kernel_size": 5,
      "dilation_rate": 1,
      "n_layers": 4
    },
    "duration_predictor": {
      "in_channels": 192,
      "filter_channels": 256,
      "kernel_size": 3,
      "p_dropout": 0.5
    },
    "learnable_upsampling": {
      "d_predictor": 192,
      "kernel_size": 3,
      "dropout": 0.0,
      "conv_output_size": 8,
      "dim_w": 4,
      "dim_c": 2,
      "max_seq_len": 1000
    },
    "memory_bank": {
      "bank_size": 1000,
      "n_hidden_dims": 192,
      "n_attn_heads": 2
    }
  }
}
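A couple of cross-file consistency checks this config has to satisfy. These are general HiFi-GAN/VITS conventions rather than anything the file states explicitly; the snippet below is only an illustrative sanity check:

```python
import json
from math import prod

with open("configs/ljs.json") as f:
    cfg = json.load(f)

# The decoder upsamples each latent frame back to raw audio, so the product of
# its upsample_rates must equal the STFT hop_length (8 * 8 * 2 * 2 == 256).
assert prod(cfg["models"]["decoder"]["upsample_rates"]) == cfg["data"]["hop_length"]

# The memory bank is initialized from 1000 k-means centroids in
# attach_memory_bank.py, so bank_size must match that cluster count.
assert cfg["models"]["memory_bank"]["bank_size"] == 1000
```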