diff --git a/README.md b/README.md index fccbae6e..c288ca61 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,10 @@ In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. Moreover, Amphion is dedicated to advancing audio generation in real-world applications, such as building **large-scale datasets** for speech synthesis. ## 🚀 News +- **2024/09/01**: [Amphion](https://arxiv.org/abs/2312.09911) and [Emilia](https://arxiv.org/abs/2407.05361) got accepted by IEEE SLT 2024! 🤗 +- **2024/08/28**: Welcome to join Amphion's [Discord channel](https://discord.com/invite/ZxxREr3Y) to stay connected and engage with our community! +- **2024/08/20**: [SingVisio](https://arxiv.org/abs/2402.12660) got accepted by Computers & Graphics, [available here](https://www.sciencedirect.com/science/article/pii/S0097849324001936)! 🎉 +- **2024/08/27**: *The Emilia dataset is now publicly available!* Discover the most extensive and diverse speech generation dataset with 101k hours of in-the-wild speech data now at [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset) or [![OpenDataLab](https://img.shields.io/badge/OpenDataLab-Dataset-blue)](https://opendatalab.com/Amphion/Emilia)! 👑👑👑 - **2024/07/01**: Amphion now releases **Emilia**, the first open-source multilingual in-the-wild dataset for speech generation with over 101k hours of speech data, and the **Emilia-Pipe**, the first open-source preprocessing pipeline designed to transform in-the-wild speech data into high-quality training data with annotations for speech generation! [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2407.05361) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia) [![demo](https://img.shields.io/badge/WebPage-Demo-red)](https://emilia-dataset.github.io/Emilia-Demo-Page/) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](preprocessors/Emilia/README.md) - **2024/06/17**: Amphion has a new release for its **VALL-E** model! It uses Llama as its underlying architecture and has better model performance, faster training speed, and more readable codes compared to our first version. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/tts/VALLE_V2/README.md) - **2024/03/12**: Amphion now support **NaturalSpeech3 FACodec** and release pretrained checkpoints. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2403.03100) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/naturalspeech3_facodec) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/codec/ns3_codec/README.md) diff --git a/bins/vc/Noro/train.py b/bins/vc/Noro/train.py new file mode 100644 index 00000000..8418c5cd --- /dev/null +++ b/bins/vc/Noro/train.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse + +import torch +from models.vc.Noro.noro_trainer import NoroTrainer +from utils.util import load_config + +def build_trainer(args, cfg): + supported_trainer = { + "VC": NoroTrainer, + } + trainer_class = supported_trainer[cfg.model_type] + trainer = trainer_class(args, cfg) + return trainer + + +def cuda_relevant(deterministic=False): + torch.cuda.empty_cache() + # TF32 on Ampere and above + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.enabled = True + torch.backends.cudnn.allow_tf32 = True + # Deterministic + torch.backends.cudnn.deterministic = deterministic + torch.backends.cudnn.benchmark = not deterministic + torch.use_deterministic_algorithms(deterministic) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + default="config.json", + help="json files for configurations.", + required=True, + ) + parser.add_argument( + "--exp_name", + type=str, + default="exp_name", + help="A specific name to note the experiment", + required=True, + ) + parser.add_argument( + "--resume", action="store_true", help="The model name to restore" + ) + parser.add_argument( + "--log_level", default="warning", help="logging level (debug, info, warning)" + ) + parser.add_argument( + "--resume_type", + type=str, + default="resume", + help="Resume training or finetuning.", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + default=None, + help="Checkpoint for resume training or finetuning.", + ) + NoroTrainer.add_arguments(parser) + args = parser.parse_args() + cfg = load_config(args.config) + print("experiment name: ", args.exp_name) + # # CUDA settings + cuda_relevant() + # Build trainer + print(f"Building {cfg.model_type} trainer") + trainer = build_trainer(args, cfg) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + print(f"Start training {cfg.model_type} model") + trainer.train_loop() + + +if __name__ == "__main__": + main() diff --git a/config/noro.json b/config/noro.json new file mode 100644 index 00000000..3c9d5c12 --- /dev/null +++ b/config/noro.json @@ -0,0 +1,76 @@ +{ + "base_config": "config/base.json", + "model_type": "VC", + "dataset": ["mls"], + "model": { + "reference_encoder": { + "encoder_layer": 6, + "encoder_hidden": 512, + "encoder_head": 8, + "conv_filter_size": 2048, + "conv_kernel_size": 9, + "encoder_dropout": 0.2, + "use_skip_connection": false, + "use_new_ffn": true, + "ref_in_dim": 80, + "ref_out_dim": 512, + "use_query_emb": true, + "num_query_emb": 32 + }, + "diffusion": { + "beta_min": 0.05, + "beta_max": 20, + "sigma": 1.0, + "noise_factor": 1.0, + "ode_solve_method": "euler", + "diff_model_type": "WaveNet", + "diff_wavenet":{ + "input_size": 80, + "hidden_size": 512, + "out_size": 80, + "num_layers": 47, + "cross_attn_per_layer": 3, + "dilation_cycle": 2, + "attn_head": 8, + "drop_out": 0.2 + } + }, + "prior_encoder": { + "encoder_layer": 6, + "encoder_hidden": 512, + "encoder_head": 8, + "conv_filter_size": 2048, + "conv_kernel_size": 9, + "encoder_dropout": 0.2, + "use_skip_connection": false, + "use_new_ffn": true, + "vocab_size": 256, + "cond_dim": 512, + "duration_predictor": { + "input_size": 512, + "filter_size": 512, + "kernel_size": 3, + "conv_layers": 30, + "cross_attn_per_layer": 3, + "attn_head": 8, + "drop_out": 0.2 + }, + "pitch_predictor": { + "input_size": 512, + "filter_size": 512, + "kernel_size": 5, + "conv_layers": 30, + "cross_attn_per_layer": 3, + "attn_head": 8, + "drop_out": 0.5 + }, + "pitch_min": 50, + "pitch_max": 1100, + "pitch_bins_num": 512 + 
}, + "vc_feature": { + "content_feature_dim": 768, + "hidden_dim": 512 + } + } +} \ No newline at end of file diff --git a/config/tts.json b/config/tts.json index 882726db..0e804bda 100644 --- a/config/tts.json +++ b/config/tts.json @@ -19,7 +19,7 @@ "add_blank": true }, "model": { - "text_token_num": 512, + "text_token_num": 512 } } diff --git a/egs/vc/Noro/README.md b/egs/vc/Noro/README.md new file mode 100644 index 00000000..55c1340f --- /dev/null +++ b/egs/vc/Noro/README.md @@ -0,0 +1,122 @@ +# Noro: A Noise-Robust One-shot Voice Conversion System + +
+[Figure: Noro architecture overview (imgs/vc/NoroVC.png)]
+
+This is the official implementation of the paper: NORO: A Noise-Robust One-Shot Voice Conversion System with Hidden Speaker Representation Capabilities.
+
+- The semantic extractor is from [HuBERT](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
+- The vocoder uses the [BigVGAN](https://github.com/NVIDIA/BigVGAN) architecture.
+
+## Project Overview
+Noro is a noise-robust one-shot voice conversion (VC) system designed to convert the timbre of speech from a source speaker to that of a target speaker using only a single reference speech sample, while preserving the semantic content of the original speech. Noro introduces components tailored for VC with noisy reference speech, including a dual-branch reference encoding module and a noise-agnostic contrastive speaker loss.
+
+## Features
+- **Noise-Robust Voice Conversion**: Uses a dual-branch reference encoding module and a noise-agnostic contrastive speaker loss to maintain high-quality voice conversion in noisy environments.
+- **One-shot Voice Conversion**: Achieves timbre conversion using only one reference speech sample.
+- **Speaker Representation Learning**: Explores the potential of the reference encoder as a self-supervised speaker encoder.
+
+## Installation Requirements
+
+Set up your environment as described in the Amphion README (a conda environment is required; Linux is recommended).
+
+### Prepare the HuBERT Model
+
+The HuBERT checkpoint and k-means model can be downloaded [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
+Set the downloaded model paths in `egs/vc/Noro/exp_config_base.json`.
+
+## Usage
+
+### Download pretrained weights
+Download our pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1NPzSIuSKO8o87g5ImNzpw_BgbhsZaxNg?usp=drive_link).
+
+### Inference
+1. Configure inference parameters:
+   Set the pretrained checkpoint path, source audio path, and reference audio path in `egs/vc/Noro/noro_inference.sh` (currently around line 35):
+```
+   checkpoint_path="path/to/checkpoint/model.safetensors"
+   output_dir="path/to/output/directory"
+   source_path="path/to/source/audio.wav"
+   reference_path="path/to/reference/audio.wav"
+```
+2. Start inference:
+   ```bash
+   bash path/to/Amphion/egs/vc/Noro/noro_inference.sh
+   ```
+3. The reconstructed mel spectrogram is saved to the output directory.
+   Use a [BigVGAN](https://github.com/NVIDIA/BigVGAN) vocoder to synthesize the waveform from it.
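For reference, the sketch below shows one way to turn the saved `recon_mel.npy` back into a waveform. It assumes you have already loaded a 16 kHz, 80-band BigVGAN generator as a PyTorch module (see the BigVGAN repository for how to load a checkpoint); the helper name `mel_to_wav` and the `soundfile` dependency are illustrative and not part of this recipe.

```python
# Hypothetical post-processing sketch: saved mel spectrogram (.npy) -> waveform.
# `vocoder` is assumed to be an already-loaded BigVGAN-style generator, i.e. an
# nn.Module mapping (batch, n_mels, frames) mels to waveforms at 16 kHz / 80 mels.
import numpy as np
import soundfile as sf
import torch


def mel_to_wav(mel_path: str, vocoder: torch.nn.Module, out_path: str, sr: int = 16000) -> None:
    mel = torch.from_numpy(np.load(mel_path)).float()  # saved by noro_inference.py as (1, 80, frames)
    with torch.no_grad():
        wav = vocoder(mel)                             # generator output, e.g. (1, 1, samples)
    sf.write(out_path, wav.squeeze().cpu().numpy(), sr)


# Example:
# mel_to_wav("path/to/output/directory/recon_mel.npy", vocoder, "converted.wav")
```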
+``` + "directory_list": [ + "path/to/your/training_data_directory1", + "path/to/your/training_data_directory2", + "path/to/your/training_data_directory3" + ], +``` + +If you want to train for the noisy noro model, you also need to set the direction path for the noisy data at "egs/vc/Noro/exp_config_noisy.json". +``` + "noise_dir": "path/to/your/noise/train/directory", + "test_noise_dir": "path/to/your/noise/test/directory" +``` + +You can change other experiment settings in the config flies such as the learning rate, optimizer and the dataset. + + **Set smaller batch_size if you are out of memory😢😢** + +I used max_tokens = 3200000 to successfully run on a single card, if you'r out of memory, try smaller. + +```json + "max_tokens": 3200000 +``` +### Resume from existing checkpoint +Our framework supports resuming from existing checkpoint. +If this is a new experiment, use the following command: +``` +CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \ +"${work_dir}/bins/vc/train.py" \ + --config $exp_config \ + --exp_name $exp_name \ + --log_level debug +``` +To resume training or fine-tune from a checkpoint, use the following command: +Ensure the options `--resume`, `--resume_type resume`, and `--checkpoint_path` are set. + +### Run the command to Train model +Start clean training: + ```bash + bash path/to/Amphion/egs/vc/noro_train_clean.sh + ``` + + +Start noisy training: + ```bash + bash path/to/Amphion/egs/vc/noro_train_noisy.sh + ``` + + + diff --git a/egs/vc/Noro/exp_config_base.json b/egs/vc/Noro/exp_config_base.json new file mode 100644 index 00000000..27832d4d --- /dev/null +++ b/egs/vc/Noro/exp_config_base.json @@ -0,0 +1,61 @@ +{ + "base_config": "config/noro.json", + "model_type": "VC", + "dataset": [ + "mls" + ], + "sample_rate": 16000, + "n_fft": 1024, + "n_mel": 80, + "hop_size": 200, + "win_size": 800, + "fmin": 0, + "fmax": 8000, + "preprocess": { + "kmeans_model_path": "path/to/kmeans_model", + "hubert_model_path": "path/to/hubert_model", + "sample_rate": 16000, + "hop_size": 200, + "f0_min": 50, + "f0_max": 500, + "frame_period": 12.5 + }, + "model": { + "reference_encoder": { + "encoder_layer": 6, + "encoder_hidden": 512, + "encoder_head": 8, + "conv_filter_size": 2048, + "conv_kernel_size": 9, + "encoder_dropout": 0.2, + "use_skip_connection": false, + "use_new_ffn": true, + "ref_in_dim": 80, + "ref_out_dim": 512, + "use_query_emb": true, + "num_query_emb": 32 + }, + "diffusion": { + "beta_min": 0.05, + "beta_max": 20, + "sigma": 1.0, + "noise_factor": 1.0, + "ode_solve_method": "euler", + "diff_model_type": "WaveNet", + "diff_wavenet":{ + "input_size": 80, + "hidden_size": 512, + "out_size": 80, + "num_layers": 47, + "cross_attn_per_layer": 3, + "dilation_cycle": 2, + "attn_head": 8, + "drop_out": 0.2 + } + }, + "vc_feature": { + "content_feature_dim": 768, + "hidden_dim": 512 + } + } +} \ No newline at end of file diff --git a/egs/vc/Noro/exp_config_clean.json b/egs/vc/Noro/exp_config_clean.json new file mode 100644 index 00000000..e0dbd367 --- /dev/null +++ b/egs/vc/Noro/exp_config_clean.json @@ -0,0 +1,38 @@ +{ + "base_config": "egs/vc/exp_config_base.json", + "dataset": [ + "mls" + ], + // Specify the output root path to save model checkpoints and logs + "log_dir": "path/to/your/checkpoint/directory", + "train": { + // New trainer and Accelerator + "gradient_accumulation_step": 1, + "tracker": ["tensorboard"], + "max_epoch": 10, + "save_checkpoint_stride": [1000], + "keep_last": [20], + "run_eval": [true], + "dataloader": 
{ + "num_worker": 64, + "pin_memory": true + }, + "adam": { + "lr": 5e-5 + }, + "use_dynamic_batchsize": true, + "max_tokens": 3200000, + "max_sentences": 64, + "lr_warmup_steps": 5000, + "lr_scheduler": "cosine", + "num_train_steps": 800000 + }, + "trans_exp": { + "directory_list": [ + "path/to/your/training_data_directory1", + "path/to/your/training_data_directory2", + "path/to/your/training_data_directory3" + ], + "use_ref_noise": false + } +} \ No newline at end of file diff --git a/egs/vc/Noro/exp_config_noisy.json b/egs/vc/Noro/exp_config_noisy.json new file mode 100644 index 00000000..7e5d7e75 --- /dev/null +++ b/egs/vc/Noro/exp_config_noisy.json @@ -0,0 +1,40 @@ +{ + "base_config": "egs/vc/exp_config_base.json", + "dataset": [ + "mls" + ], + // Specify the output root path to save model checkpoints and logs + "log_dir": "path/to/your/checkpoint/directory", + "train": { + // New trainer and Accelerator + "gradient_accumulation_step": 1, + "tracker": ["tensorboard"], + "max_epoch": 10, + "save_checkpoint_stride": [1000], + "keep_last": [20], + "run_eval": [true], + "dataloader": { + "num_worker": 64, + "pin_memory": true + }, + "adam": { + "lr": 5e-5 + }, + "use_dynamic_batchsize": true, + "max_tokens": 3200000, + "max_sentences": 64, + "lr_warmup_steps": 5000, + "lr_scheduler": "cosine", + "num_train_steps": 800000 + }, + "trans_exp": { + "directory_list": [ + "path/to/your/training_data_directory1", + "path/to/your/training_data_directory2", + "path/to/your/training_data_directory3" + ], + "use_ref_noise": true, + "noise_dir": "path/to/your/noise/train/directory", + "test_noise_dir": "path/to/your/noise/test/directory" + } + } \ No newline at end of file diff --git a/egs/vc/Noro/noro_inference.sh b/egs/vc/Noro/noro_inference.sh new file mode 100644 index 00000000..5f10a9cd --- /dev/null +++ b/egs/vc/Noro/noro_inference.sh @@ -0,0 +1,54 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
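A note on the noisy configuration above (`egs/vc/Noro/exp_config_noisy.json`): when `use_ref_noise` is enabled, the data loader mixes each reference utterance with a clip drawn from `noise_dir` at a random SNR (see `snr_mixer` and `add_noise` in `models/vc/Noro/noro_dataset.py` later in this diff). The sketch below illustrates generic SNR-based mixing only; the function name is made up, and the recipe's own implementation additionally normalizes both signals to roughly -25 dBFS and differs in minor details.

```python
# Generic sketch of mixing a noise clip into clean speech at a target SNR (in dB).
import numpy as np


def mix_at_snr(clean: np.ndarray, noise: np.ndarray, snr_db: float, eps: float = 1e-10) -> np.ndarray:
    # Match the noise length to the clean signal (tile or truncate).
    if len(noise) < len(clean):
        noise = np.tile(noise, int(np.ceil(len(clean) / len(noise))))
    noise = noise[: len(clean)]

    rms_clean = np.sqrt((clean ** 2).mean()) + eps
    rms_noise = np.sqrt((noise ** 2).mean()) + eps
    # Scale the noise so that 20*log10(rms_clean / rms_scaled_noise) == snr_db.
    scale = rms_clean / (10 ** (snr_db / 20)) / rms_noise
    return clean + scale * noise
```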
+ +# Set the PYTHONPATH to the current directory +export PYTHONPATH="./" + +######## Build Experiment Environment ########### +# Get the current directory of the script +exp_dir=$(cd `dirname $0`; pwd) +# Get the parent directory of the experiment directory +work_dir=$(dirname $(dirname $exp_dir)) + +# Export environment variables for the working directory and Python path +export WORK_DIR=$work_dir +export PYTHONPATH=$work_dir +export PYTHONIOENCODING=UTF-8 + +# Build the monotonic alignment module +cd $work_dir/modules/monotonic_align +mkdir -p monotonic_align +python setup.py build_ext --inplace +cd $work_dir + +if [ -z "$exp_config" ]; then + exp_config="${exp_dir}/exp_config_base.json" +fi + +echo "Experimental Configuration File: $exp_config" + +cuda_id=0 + +# Set paths (modify these paths to your own) +checkpoint_path="path/to/checkpoint/model.safetensors" +output_dir="path/to/output/directory" +source_path="path/to/source/audio.wav" +reference_path="path/to/reference/audio.wav" + +echo "CUDA ID: $cuda_id" +echo "Checkpoint Path: $checkpoint_path" +echo "Output Directory: $output_dir" +echo "Source Audio Path: $source_path" +echo "Reference Audio Path: $reference_path" + +# Run the voice conversion inference script +python "${work_dir}/models/vc/noro_inference.py" \ + --config $exp_config \ + --checkpoint_path $checkpoint_path \ + --output_dir $output_dir \ + --cuda_id ${cuda_id} \ + --source_path $source_path \ + --ref_path $reference_path + diff --git a/egs/vc/Noro/noro_train_clean.sh b/egs/vc/Noro/noro_train_clean.sh new file mode 100644 index 00000000..e04c16a8 --- /dev/null +++ b/egs/vc/Noro/noro_train_clean.sh @@ -0,0 +1,55 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +######## Build Experiment Environment ########### +exp_dir=$(cd `dirname $0`; pwd) +work_dir=$(dirname $(dirname $exp_dir)) + +export WORK_DIR=$work_dir +export PYTHONPATH=$work_dir +export PYTHONIOENCODING=UTF-8 + +cd $work_dir/modules/monotonic_align +mkdir -p monotonic_align +python setup.py build_ext --inplace +cd $work_dir + +if [ -z "$exp_config" ]; then + exp_config="${exp_dir}/exp_config_clean.json" +fi +echo "Experimental Configuration File: $exp_config" + +# Set experiment name +exp_name="experiment_name" + +# Set CUDA ID +if [ -z "$gpu" ]; then + gpu="0,1,2,3" +fi + +######## Train Model ########### +echo "Experimental Name: $exp_name" + +# Specify the checkpoint folder (modify this path to your own) +checkpoint_path="path/to/checkpoint/noro_checkpoint" + + +# If this is a new experiment, use the following command: +# CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \ +# "${work_dir}/bins/vc/train.py" \ +# --config $exp_config \ +# --exp_name $exp_name \ +# --log_level debug + +# To resume training or fine-tune from a checkpoint, use the following command: +# Ensure the options --resume, --resume_type resume, and --checkpoint_path are set +CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \ +"${work_dir}/bins/vc/Noro/train.py" \ + --config $exp_config \ + --exp_name $exp_name \ + --log_level debug \ + --resume \ + --resume_type resume \ + --checkpoint_path $checkpoint_path diff --git a/egs/vc/Noro/noro_train_noisy.sh b/egs/vc/Noro/noro_train_noisy.sh new file mode 100644 index 00000000..0c910bf6 --- /dev/null +++ b/egs/vc/Noro/noro_train_noisy.sh @@ -0,0 +1,55 @@ +# Copyright (c) 2023 Amphion. 
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+######## Build Experiment Environment ###########
+exp_dir=$(cd `dirname $0`; pwd)
+work_dir=$(dirname $(dirname $exp_dir))
+
+export WORK_DIR=$work_dir
+export PYTHONPATH=$work_dir
+export PYTHONIOENCODING=UTF-8
+
+cd $work_dir/modules/monotonic_align
+mkdir -p monotonic_align
+python setup.py build_ext --inplace
+cd $work_dir
+
+if [ -z "$exp_config" ]; then
+    exp_config="${exp_dir}/exp_config_noisy.json"
+fi
+echo "Experimental Configuration File: $exp_config"
+
+# Set experiment name
+exp_name="experiment_name"
+
+# Set CUDA ID
+if [ -z "$gpu" ]; then
+    gpu="0,1,2,3"
+fi
+
+######## Train Model ###########
+echo "Experimental Name: $exp_name"
+
+# Specify the checkpoint folder (modify this path to your own)
+checkpoint_path="path/to/checkpoint/noro_checkpoint"
+
+
+# If this is a new experiment, use the following command:
+# CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
+# "${work_dir}/bins/vc/Noro/train.py" \
+#    --config $exp_config \
+#    --exp_name $exp_name \
+#    --log_level debug
+
+# To resume training or fine-tune from a checkpoint, use the following command:
+# Ensure the options --resume, --resume_type resume, and --checkpoint_path are set
+CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
+"${work_dir}/bins/vc/Noro/train.py" \
+    --config $exp_config \
+    --exp_name $exp_name \
+    --log_level debug \
+    --resume \
+    --resume_type resume \
+    --checkpoint_path $checkpoint_path
\ No newline at end of file
diff --git a/egs/vc/README.md b/egs/vc/README.md
new file mode 100644
index 00000000..13c10477
--- /dev/null
+++ b/egs/vc/README.md
@@ -0,0 +1,20 @@
+# Amphion Voice Conversion (VC) Recipe
+
+## Quick Start
+
+We provide a **[beginner recipe](Noro)** to demonstrate how to train a cutting-edge VC model. Specifically, it is an official implementation of the paper "NORO: A Noise-Robust One-Shot Voice Conversion System with Hidden Speaker Representation Capabilities".
+
+## Supported Model Architectures
+
+Until now, Amphion has supported a noise-robust VC model with the following architecture:
+
+[Figure: Noro model architecture (imgs/vc/NoroVC.png)]
+ +It has the following features: +1. Noise-Robust Voice Conversion: Utilizes a dual-branch reference encoding module and noise-agnostic contrastive speaker loss to maintain high-quality voice conversion in noisy environments. +2. One-shot Voice Conversion: Achieves timbre conversion using only one reference speech sample. +3. Speaker Representation Learning: Explores the potential of the reference encoder as a self-supervised speaker encoder. diff --git a/imgs/vc/NoroVC.png b/imgs/vc/NoroVC.png new file mode 100644 index 00000000..3ad47750 Binary files /dev/null and b/imgs/vc/NoroVC.png differ diff --git a/models/tts/naturalspeech2/ns2_trainer.py b/models/tts/naturalspeech2/ns2_trainer.py index 63c4353e..8ede54e8 100644 --- a/models/tts/naturalspeech2/ns2_trainer.py +++ b/models/tts/naturalspeech2/ns2_trainer.py @@ -433,26 +433,18 @@ def _train_step(self, batch): total_loss += dur_loss train_losses["dur_loss"] = dur_loss - x0 = self.model.module.code_to_latent(code) - if self.cfg.model.diffusion.diffusion_type == "diffusion": - # diff loss x0 - diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask) - total_loss += diff_loss_x0 - train_losses["diff_loss_x0"] = diff_loss_x0 - - # diff loss noise - diff_loss_noise = diff_loss( - diff_out["noise_pred"], diff_out["noise"], mask=mask - ) - total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda - train_losses["diff_loss_noise"] = diff_loss_noise - - elif self.cfg.model.diffusion.diffusion_type == "flow": - # diff flow matching loss - flow_gt = diff_out["noise"] - x0 - diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask) - total_loss += diff_loss_flow - train_losses["diff_loss_flow"] = diff_loss_flow + x0 = self.model.module.code_to_latent(code) + # diff loss x0 + diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask) + total_loss += diff_loss_x0 + train_losses["diff_loss_x0"] = diff_loss_x0 + + # diff loss noise + diff_loss_noise = diff_loss( + diff_out["noise_pred"], diff_out["noise"], mask=mask + ) + total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda + train_losses["diff_loss_noise"] = diff_loss_noise # diff loss ce @@ -534,26 +526,17 @@ def _valid_step(self, batch): valid_losses["dur_loss"] = dur_loss x0 = self.model.module.code_to_latent(code) - if self.cfg.model.diffusion.diffusion_type == "diffusion": - # diff loss x0 - diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask) - total_loss += diff_loss_x0 - valid_losses["diff_loss_x0"] = diff_loss_x0 - - # diff loss noise - diff_loss_noise = diff_loss( - diff_out["noise_pred"], diff_out["noise"], mask=mask - ) - total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda - valid_losses["diff_loss_noise"] = diff_loss_noise - - elif self.cfg.model.diffusion.diffusion_type == "flow": - # diff flow matching loss - flow_gt = diff_out["noise"] - x0 - diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask) - total_loss += diff_loss_flow - valid_losses["diff_loss_flow"] = diff_loss_flow - + # diff loss x0 + diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask) + total_loss += diff_loss_x0 + valid_losses["diff_loss_x0"] = diff_loss_x0 + + # diff loss noise + diff_loss_noise = diff_loss( + diff_out["noise_pred"], diff_out["noise"], mask=mask + ) + total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda + valid_losses["diff_loss_noise"] = diff_loss_noise # diff loss ce # (nq, B, T); (nq, B, T, 1024) diff --git a/models/vc/Noro/noro_base_trainer.py b/models/vc/Noro/noro_base_trainer.py new 
file mode 100644 index 00000000..bc343ed1 --- /dev/null +++ b/models/vc/Noro/noro_base_trainer.py @@ -0,0 +1,280 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import torch +import time +from pathlib import Path +import torch +import accelerate +from accelerate.logging import get_logger +from models.base.new_trainer import BaseTrainer + + +class Noro_base_Trainer(BaseTrainer): + r"""The base trainer for all TTS models. It inherits from BaseTrainer and implements + ``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this + class, and implement ``_build_model``, ``_forward_step``. + """ + + def __init__(self, args=None, cfg=None): + self.args = args + self.cfg = cfg + + cfg.exp_name = args.exp_name + + # init with accelerate + self._init_accelerator() + self.accelerator.wait_for_everyone() + + with self.accelerator.main_process_first(): + self.logger = get_logger(args.exp_name, log_level="INFO") + + # Log some info + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New training process started." + "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + self.logger.debug(f"Using {args.log_level.upper()} logging level.") + self.logger.info(f"Experiment name: {args.exp_name}") + self.logger.info(f"Experiment directory: {self.exp_dir}") + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # init counts + self.batch_count: int = 0 + self.step: int = 0 + self.epoch: int = 0 + self.max_epoch = ( + self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf") + ) + self.logger.info( + "Max epoch: {}".format( + self.max_epoch if self.max_epoch < float("inf") else "Unlimited" + ) + ) + + # Check values + if self.accelerator.is_main_process: + self.__check_basic_configs() + # Set runtime configs + self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride + self.checkpoints_path = [ + [] for _ in range(len(self.save_checkpoint_stride)) + ] + self.keep_last = [ + i if i > 0 else float("inf") for i in self.cfg.train.keep_last + ] + self.run_eval = self.cfg.train.run_eval + + # set random seed + with self.accelerator.main_process_first(): + # start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + # self.logger.debug( + # f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + # ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # setup data_loader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.train_dataloader, self.valid_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # # save phone table to exp dir. 
Should be done before building model due to loading phone table in model + # if cfg.preprocess.use_phone and cfg.preprocess.phone_extractor != "lexicon": + # self._save_phone_symbols_file_to_exp_path() + + # setup model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.model = self._build_model() + end = time.monotonic_ns() + self.logger.debug(self.model) + self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms") + self.logger.info( + f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M" + ) + + # optimizer & scheduler + with self.accelerator.main_process_first(): + self.logger.info("Building optimizer and scheduler...") + start = time.monotonic_ns() + self.optimizer = self._build_optimizer() + self.scheduler = self._build_scheduler() + end = time.monotonic_ns() + self.logger.info( + f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms" + ) + + # create criterion + with self.accelerator.main_process_first(): + self.logger.info("Building criterion...") + start = time.monotonic_ns() + self.criterion = self._build_criterion() + end = time.monotonic_ns() + self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms") + + # Resume or Finetune + with self.accelerator.main_process_first(): + self._check_resume() + + # accelerate prepare + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + self._accelerator_prepare() + end = time.monotonic_ns() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms") + + # save config file path + self.config_save_path = os.path.join(self.exp_dir, "args.json") + self.device = self.accelerator.device + + if cfg.preprocess.use_spkid and cfg.train.multi_speaker_training: + self.speakers = self._build_speaker_lut() + self.utt2spk_dict = self._build_utt2spk_dict() + + # Only for TTS tasks + self.task_type = "TTS" + self.logger.info("Task type: {}".format(self.task_type)) + + def _check_resume(self): + # if args.resume: + if self.args.resume or ( + self.cfg.model_type == "VALLE" and self.args.train_stage == 2 + ): + if self.cfg.model_type == "VALLE" and self.args.train_stage == 2: + self.args.resume_type = "finetune" + + self.logger.info("Resuming from checkpoint...") + self.ckpt_path = self._load_model( + self.checkpoint_dir, self.args.checkpoint_path, self.args.resume_type + ) + self.checkpoints_path = json.load( + open(os.path.join(self.ckpt_path, "ckpts.json"), "r") + ) + + + def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"): + """Load model from checkpoint. If a folder is given, it will + load the latest checkpoint in checkpoint_dir. If a path is given + it will load the checkpoint specified by checkpoint_path. + **Only use this method after** ``accelerator.prepare()``. 
+ """ + if checkpoint_path is None or checkpoint_path == "": + ls = [str(i) for i in Path(checkpoint_dir).glob("*")] + # example path epoch-0000_step-0017000_loss-1.972191, 找step最大的 + checkpoint_path = max(ls, key=lambda x: int(x.split("_")[-2].split("-")[-1])) + + if self.accelerator.is_main_process: + self.logger.info("Load model from {}".format(checkpoint_path)) + print("Load model from {}".format(checkpoint_path)) + + if resume_type == "resume": + self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + if isinstance(self.model, dict): + for idx, sub_model in enumerate(self.model.keys()): + try: + if idx == 0: + ckpt_name = "pytorch_model.bin" + else: + ckpt_name = "pytorch_model_{}.bin".format(idx) + + self.model[sub_model].load_state_dict( + torch.load(os.path.join(checkpoint_path, ckpt_name)) + ) + except: + if idx == 0: + ckpt_name = "model.safetensors" + else: + ckpt_name = "model_{}.safetensors".format(idx) + + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model[sub_model]), + os.path.join(checkpoint_path, ckpt_name), + ) + + self.model[sub_model].cuda(self.accelerator.device) + else: + try: + self.model.load_state_dict( + torch.load(os.path.join(checkpoint_path, "pytorch_model.bin")) + ) + if self.accelerator.is_main_process: + self.logger.info("Loaded 'pytorch_model.bin' for resume") + except: + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "model.safetensors"), + ) + if self.accelerator.is_main_process: + self.logger.info("Loaded 'model.safetensors' for resume") + self.model.cuda(self.accelerator.device) + if self.accelerator.is_main_process: + self.logger.info("Load model weights SUCCESS!") + elif resume_type == "finetune": + if isinstance(self.model, dict): + for idx, sub_model in enumerate(self.model.keys()): + try: + if idx == 0: + ckpt_name = "pytorch_model.bin" + else: + ckpt_name = "pytorch_model_{}.bin".format(idx) + + self.model[sub_model].load_state_dict( + torch.load(os.path.join(checkpoint_path, ckpt_name)) + ) + except: + if idx == 0: + ckpt_name = "model.safetensors" + else: + ckpt_name = "model_{}.safetensors".format(idx) + + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model[sub_model]), + os.path.join(checkpoint_path, ckpt_name), + ) + + self.model[sub_model].cuda(self.accelerator.device) + else: + try: + self.model.load_state_dict( + torch.load(os.path.join(checkpoint_path, "pytorch_model.bin")) + ) + if self.accelerator.is_main_process: + self.logger.info("Loaded 'pytorch_model.bin' for finetune") + except: + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "model.safetensors"), + ) + if self.accelerator.is_main_process: + self.logger.info("Loaded 'model.safetensors' for finetune") + self.model.cuda(self.accelerator.device) + if self.accelerator.is_main_process: + self.logger.info("Load model weights for finetune SUCCESS!") + else: + raise ValueError("Unsupported resume type: {}".format(resume_type)) + return checkpoint_path + + def _check_basic_configs(self): + if self.cfg.train.gradient_accumulation_step <= 0: + self.logger.fatal("Invalid gradient_accumulation_step value!") + self.logger.error( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." 
+ ) + self.accelerator.end_training() + raise ValueError( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) diff --git a/models/vc/Noro/noro_dataset.py b/models/vc/Noro/noro_dataset.py new file mode 100644 index 00000000..8c858788 --- /dev/null +++ b/models/vc/Noro/noro_dataset.py @@ -0,0 +1,444 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import numpy as np +import librosa +import torch +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from multiprocessing import Pool, Lock +import random +import torchaudio + + +NUM_WORKERS = 64 +lock = Lock() # 创建一个全局锁 +SAMPLE_RATE = 16000 + +def get_metadata(file_path): + metadata = torchaudio.info(file_path) + return file_path, metadata.num_frames + +def get_speaker(file_path): + speaker_id = file_path.split(os.sep)[-3] + if 'mls' in file_path: + speaker = 'mls_' + speaker_id + else: + speaker = 'libri_' + speaker_id + return file_path, speaker + +def safe_write_to_file(data, file_path, mode='w'): + try: + with lock, open(file_path, mode, encoding='utf-8') as f: + json.dump(data, f) + f.flush() + os.fsync(f.fileno()) + except IOError as e: + print(f"Error writing to {file_path}: {e}") + + +class VCDataset(Dataset): + def __init__(self, args, TRAIN_MODE=True): + print(f"Initializing VCDataset") + if TRAIN_MODE: + directory_list = args.directory_list + else: + directory_list = args.test_directory_list + random.shuffle(directory_list) + + self.use_ref_noise = args.use_ref_noise + print(f"use_ref_noise: {self.use_ref_noise}") + + # number of workers + print(f"Using {NUM_WORKERS} workers") + self.directory_list = directory_list + print(f"Loading {len(directory_list)} directories: {directory_list}") + + # Load metadata cache + # metadata_cache: {file_path: num_frames} + self.metadata_cache_path = '/mnt/data2/hehaorui/ckpt/rp_metadata_cache.json' + print(f"Loading metadata_cache from {self.metadata_cache_path}") + if os.path.exists(self.metadata_cache_path): + with open(self.metadata_cache_path, 'r', encoding='utf-8') as f: + self.metadata_cache = json.load(f) + print(f"Loaded {len(self.metadata_cache)} metadata_cache") + else: + print(f"metadata_cache not found, creating new") + self.metadata_cache = {} + + # Load speaker cache + # speaker_cache: {file_path: speaker} + self.speaker_cache_path = '/mnt/data2/hehaorui/ckpt/rp_file2speaker.json' + print(f"Loading speaker_cache from {self.speaker_cache_path}") + if os.path.exists(self.speaker_cache_path): + with open(self.speaker_cache_path, 'r', encoding='utf-8') as f: + self.speaker_cache = json.load(f) + print(f"Loaded {len(self.speaker_cache)} speaker_cache") + else: + print(f"speaker_cache not found, creating new") + self.speaker_cache = {} + + self.files = [] + # Load all flac files + for directory in directory_list: + print(f"Loading {directory}") + files = self.get_flac_files(directory) + random.shuffle(files) + print(f"Loaded {len(files)} files") + self.files.extend(files) + del files + print(f"Now {len(self.files)} files") + self.meta_data_cache = self.process_files() + self.speaker_cache = self.process_speakers() + temp_cache_path = self.metadata_cache_path.replace('.json', f'_{directory.split("/")[-1]}.json') + if not os.path.exists(temp_cache_path): + safe_write_to_file(self.meta_data_cache, temp_cache_path) + print(f"Saved metadata cache 
to {temp_cache_path}") + temp_cache_path = self.speaker_cache_path.replace('.json', f'_{directory.split("/")[-1]}.json') + if not os.path.exists(temp_cache_path): + safe_write_to_file(self.speaker_cache, temp_cache_path) + print(f"Saved speaker cache to {temp_cache_path}") + + print(f"Loaded {len(self.files)} files") + random.shuffle(self.files) # Shuffle the files. + + self.filtered_files, self.all_num_frames, index2numframes, index2speakerid = self.filter_files() + print(f"Loaded {len(self.filtered_files)} files") + + self.index2numframes = index2numframes + self.index2speaker = index2speakerid + self.speaker2id = self.create_speaker2id() + self.num_frame_sorted = np.array(sorted(self.all_num_frames)) + self.num_frame_indices = np.array( + sorted( + range(len(self.all_num_frames)), key=lambda k: self.all_num_frames[k] + ) + ) + del self.meta_data_cache, self.speaker_cache + + if self.use_ref_noise: + if TRAIN_MODE: + self.noise_filenames = self.get_all_flac(args.noise_dir) + else: + self.noise_filenames = self.get_all_flac(args.test_noise_dir) + + def process_files(self): + print(f"Processing metadata...") + files_to_process = [file for file in self.files if file not in self.metadata_cache] + if files_to_process: + with Pool(processes=NUM_WORKERS) as pool: + results = list(tqdm(pool.imap_unordered(get_metadata, files_to_process), total=len(files_to_process))) + for file, num_frames in results: + self.metadata_cache[file] = num_frames + safe_write_to_file(self.metadata_cache, self.metadata_cache_path) + else: + print(f"Skipping processing metadata, loaded {len(self.metadata_cache)} files") + return self.metadata_cache + + def process_speakers(self): + print(f"Processing speakers...") + files_to_process = [file for file in self.files if file not in self.speaker_cache] + if files_to_process: + with Pool(processes=NUM_WORKERS) as pool: + results = list(tqdm(pool.imap_unordered(get_speaker, files_to_process), total=len(files_to_process))) + for file, speaker in results: + self.speaker_cache[file] = speaker + safe_write_to_file(self.speaker_cache, self.speaker_cache_path) + else: + print(f"Skipping processing speakers, loaded {len(self.speaker_cache)} files") + return self.speaker_cache + + def get_flac_files(self, directory): + flac_files = [] + for root, dirs, files in os.walk(directory): + for file in files: + # flac or wav + if file.endswith(".flac") or file.endswith(".wav"): + flac_files.append(os.path.join(root, file)) + return flac_files + + def get_all_flac(self, directory): + directories = [os.path.join(directory, d) for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))] + if not directories: + return self.get_flac_files(directory) + with Pool(processes=NUM_WORKERS) as pool: + results = [] + for result in tqdm(pool.imap_unordered(self.get_flac_files, directories), total=len(directories), desc="Processing"): + results.extend(result) + print(f"Found {len(results)} waveform files") + return results + + def get_num_frames(self, index): + return self.index2numframes[index] + + def filter_files(self): + # Filter files + metadata_cache = self.meta_data_cache + speaker_cache = self.speaker_cache + filtered_files = [] + all_num_frames = [] + index2numframes = {} + index2speaker = {} + for file in self.files: + num_frames = metadata_cache[file] + if SAMPLE_RATE * 3 <= num_frames <= SAMPLE_RATE * 30: + filtered_files.append(file) + all_num_frames.append(num_frames) + index2speaker[len(filtered_files) - 1] = speaker_cache[file] + index2numframes[len(filtered_files) - 1] = 
num_frames + return filtered_files, all_num_frames, index2numframes, index2speaker + + def create_speaker2id(self): + speaker2id = {} + unique_id = 0 + print(f"Creating speaker2id from {len(self.index2speaker)} utterences") + for _, speaker in tqdm(self.index2speaker.items()): + if speaker not in speaker2id: + speaker2id[speaker] = unique_id + unique_id += 1 + print(f"Created speaker2id with {len(speaker2id)} speakers") + return speaker2id + + def snr_mixer(self, clean, noise, snr): + # Normalizing to -25 dB FS + rmsclean = (clean**2).mean()**0.5 + epsilon = 1e-10 + rmsclean = max(rmsclean, epsilon) + scalarclean = 10 ** (-25 / 20) / rmsclean + clean = clean * scalarclean + + rmsnoise = (noise**2).mean()**0.5 + scalarnoise = 10 ** (-25 / 20) /rmsnoise + noise = noise * scalarnoise + rmsnoise = (noise**2).mean()**0.5 + + # Set the noise level for a given SNR + noisescalar = np.sqrt(rmsclean / (10**(snr/20)) / rmsnoise) + noisenewlevel = noise * noisescalar + noisyspeech = clean + noisenewlevel + noisyspeech_tensor = torch.tensor(noisyspeech, dtype=torch.float32) + return noisyspeech_tensor + + def add_noise(self, clean): + # self.noise_filenames: list of noise files + random_idx = np.random.randint(0, np.size(self.noise_filenames)) + noise, _ = librosa.load(self.noise_filenames[random_idx], sr=SAMPLE_RATE) + clean = clean.cpu().numpy() + if len(noise)>=len(clean): + noise = noise[0:len(clean)] + else: + while len(noise)<=len(clean): + random_idx = (random_idx + 1)%len(self.noise_filenames) + newnoise, fs = librosa.load(self.noise_filenames[random_idx], sr=SAMPLE_RATE) + noiseconcat = np.append(noise, np.zeros(int(fs * 0.2))) + noise = np.append(noiseconcat, newnoise) + noise = noise[0:len(clean)] + snr = random.uniform(0.0,20.0) + noisyspeech = self.snr_mixer(clean=clean, noise=noise, snr=snr) + del noise + return noisyspeech + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + file_path = self.filtered_files[idx] + speech, _ = librosa.load(file_path, sr=SAMPLE_RATE) + if len(speech) > 30 * SAMPLE_RATE: + speech = speech[:30 * SAMPLE_RATE] + speech = torch.tensor(speech, dtype=torch.float32) + # inputs = self._get_reference_vc(speech, hop_length=320) + inputs = self._get_reference_vc(speech, hop_length=200) + speaker = self.index2speaker[idx] + speaker_id = self.speaker2id[speaker] + inputs["speaker_id"] = speaker_id + return inputs + + def _get_reference_vc(self, speech, hop_length): + pad_size = 1600 - speech.shape[0] % 1600 + speech = torch.nn.functional.pad(speech, (0, pad_size)) + + #hop_size + frame_nums = speech.shape[0] // hop_length + clip_frame_nums = np.random.randint(int(frame_nums * 0.25), int(frame_nums * 0.45)) + clip_frame_nums += (frame_nums - clip_frame_nums) % 8 + start_frames, end_frames = 0, clip_frame_nums + + ref_speech = speech[start_frames * hop_length : end_frames * hop_length] + new_speech = torch.cat((speech[:start_frames * hop_length], speech[end_frames * hop_length:]), 0) + + ref_mask = torch.ones(len(ref_speech) // hop_length) + mask = torch.ones(len(new_speech) // hop_length) + if not self.use_ref_noise: + # not use noise + return {"speech": new_speech, "ref_speech": ref_speech, "ref_mask": ref_mask, "mask": mask} + else: + # use reference noise + noisy_ref_speech = self.add_noise(ref_speech) + return {"speech": new_speech, "ref_speech": ref_speech, "noisy_ref_speech": noisy_ref_speech, "ref_mask": ref_mask, "mask": mask} + + +class BaseCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" 
+ + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + # mel: [b, T, n_mels] + # frame_pitch, frame_energy: [1, T] + # target_len: [1] + # spk_id: [b, 1] + # mask: [b, T, 1] + + for key in batch[0].keys(): + if key == "target_len": + packed_batch_features["target_len"] = torch.LongTensor( + [b["target_len"] for b in batch] + ) + masks = [ + torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + elif key == "phone_len": + packed_batch_features["phone_len"] = torch.LongTensor( + [b["phone_len"] for b in batch] + ) + masks = [ + torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["phn_mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + elif key == "audio_len": + packed_batch_features["audio_len"] = torch.LongTensor( + [b["audio_len"] for b in batch] + ) + masks = [ + torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch + ] + else: + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + return packed_batch_features + +class VCCollator(BaseCollator): + def __init__(self, cfg): + BaseCollator.__init__(self, cfg) + #self.use_noise = cfg.trans_exp.use_noise + + self.use_ref_noise = self.cfg.trans_exp.use_ref_noise + print(f"use_ref_noise: {self.use_ref_noise}") + + + def __call__(self, batch): + packed_batch_features = dict() + + # Function to handle tensor copying + def process_tensor(data, dtype=torch.float32): + if isinstance(data, torch.Tensor): + return data.clone().detach() + else: + return torch.tensor(data, dtype=dtype) + + # Process 'speech' data + speeches = [process_tensor(b['speech']) for b in batch] + packed_batch_features['speech'] = pad_sequence(speeches, batch_first=True, padding_value=0) + + # Process 'ref_speech' data + ref_speeches = [process_tensor(b['ref_speech']) for b in batch] + packed_batch_features['ref_speech'] = pad_sequence(ref_speeches, batch_first=True, padding_value=0) + + # Process 'mask' data + masks = [process_tensor(b['mask']) for b in batch] + packed_batch_features['mask'] = pad_sequence(masks, batch_first=True, padding_value=0) + + # Process 'ref_mask' data + ref_masks = [process_tensor(b['ref_mask']) for b in batch] + packed_batch_features['ref_mask'] = pad_sequence(ref_masks, batch_first=True, padding_value=0) + + # Process 'speaker_id' data + speaker_ids = [process_tensor(b['speaker_id'], dtype=torch.int64) for b in batch] + packed_batch_features['speaker_id'] = torch.stack(speaker_ids, dim=0) + if self.use_ref_noise: + # Process 'noisy_ref_speech' data + noisy_ref_speeches = [process_tensor(b['noisy_ref_speech']) for b in batch] + packed_batch_features['noisy_ref_speech'] = pad_sequence(noisy_ref_speeches, batch_first=True, padding_value=0) + return packed_batch_features + + + +def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + if len(batch) == 0: + return 0 + if len(batch) == max_sentences: + return 1 + if num_tokens > max_tokens: + return 1 + return 0 + + +def batch_by_size( + indices, + num_tokens_fn, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, +): + """ + Yield mini-batches of indices bucketed by size. Batches may contain + sequences of different lengths. 
+ + Args: + indices (List[int]): ordered list of dataset indices + num_tokens_fn (callable): function that returns the number of tokens at + a given index + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + """ + bsz_mult = required_batch_size_multiple + + sample_len = 0 + sample_lens = [] + batch = [] + batches = [] + for i in range(len(indices)): + idx = indices[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + assert ( + sample_len <= max_tokens + ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format( + idx, sample_len, max_tokens + ) + num_tokens = (len(batch) + 1) * sample_len + + if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + mod_len = max( + bsz_mult * (len(batch) // bsz_mult), + len(batch) % bsz_mult, + ) + batches.append(batch[:mod_len]) + batch = batch[mod_len:] + sample_lens = sample_lens[mod_len:] + sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 + batch.append(idx) + if len(batch) > 0: + batches.append(batch) + return batches diff --git a/models/vc/Noro/noro_inference.py b/models/vc/Noro/noro_inference.py new file mode 100644 index 00000000..c4eab1e2 --- /dev/null +++ b/models/vc/Noro/noro_inference.py @@ -0,0 +1,136 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import torch +import numpy as np +import librosa +from safetensors.torch import load_model +import os +from utils.util import load_config +from models.vc.Noro.noro_trainer import NoroTrainer +from models.vc.Noro.noro_model import Noro_VCmodel +from processors.content_extractor import HubertExtractor +from utils.mel import mel_spectrogram_torch +from utils.f0 import get_f0_features_using_dio, interpolate + +def build_trainer(args, cfg): + supported_trainer = { + "VC": NoroTrainer, + } + trainer_class = supported_trainer[cfg.model_type] + trainer = trainer_class(args, cfg) + return trainer + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + default="config.json", + help="JSON file for configurations.", + required=True, + ) + parser.add_argument( + "--checkpoint_path", + type=str, + help="Checkpoint for resume training or fine-tuning.", + required=True, + ) + parser.add_argument( + "--output_dir", + help="Output path", + required=True, + ) + parser.add_argument( + "--ref_path", + type=str, + help="Reference voice path", + ) + parser.add_argument( + "--source_path", + type=str, + help="Source voice path", + ) + parser.add_argument( + "--cuda_id", + type=int, + default=0, + help="CUDA id for training." 
+ ) + + parser.add_argument("--local_rank", default=-1, type=int) + args = parser.parse_args() + cfg = load_config(args.config) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + cuda_id = args.cuda_id + args.local_rank = torch.device(f"cuda:{cuda_id}") + print("Local rank:", args.local_rank) + + args.content_extractor = "mhubert" + + with torch.cuda.device(args.local_rank): + torch.cuda.empty_cache() + ckpt_path = args.checkpoint_path + + w2v = HubertExtractor(cfg) + w2v = w2v.to(device=args.local_rank) + w2v.eval() + + model = Noro_VCmodel(cfg=cfg.model) + print("Loading model") + + load_model(model, ckpt_path) + print("Model loaded") + model.cuda(args.local_rank) + model.eval() + + wav_path = args.source_path + ref_wav_path = args.ref_path + + wav, _ = librosa.load(wav_path, sr=16000) + wav = np.pad(wav, (0, 1600 - len(wav) % 1600)) + audio = torch.from_numpy(wav).to(args.local_rank) + audio = audio[None, :] + + ref_wav, _ = librosa.load(ref_wav_path, sr=16000) + ref_wav = np.pad(ref_wav, (0, 200 - len(ref_wav) % 200)) + ref_audio = torch.from_numpy(ref_wav).to(args.local_rank) + ref_audio = ref_audio[None, :] + + with torch.no_grad(): + ref_mel = mel_spectrogram_torch(ref_audio, cfg) + ref_mel = ref_mel.transpose(1, 2).to(device=args.local_rank) + ref_mask = torch.ones(ref_mel.shape[0], ref_mel.shape[1]).to(args.local_rank).bool() + + _, content_feature = w2v.extract_content_features(audio) + content_feature = content_feature.to(device=args.local_rank) + + wav = audio.cpu().numpy() + pitch_raw = get_f0_features_using_dio(wav, cfg) + pitch_raw, _ = interpolate(pitch_raw) + frame_num = len(wav) // cfg.preprocess.hop_size + pitch_raw = torch.from_numpy(pitch_raw[:frame_num]).float() + pitch = (pitch_raw - pitch_raw.mean(dim=1, keepdim=True)) / ( + pitch_raw.std(dim=1, keepdim=True) + 1e-6 + ) + + x0 = model.inference( + content_feature=content_feature, + pitch=pitch, + x_ref=ref_mel, + x_ref_mask=ref_mask, + inference_steps=200, + sigma=1.2, + ) # 150-300 0.95-1.5 + + recon_path = f"{args.output_dir}/recon_mel.npy" + np.save(recon_path, x0.transpose(1, 2).detach().cpu().numpy()) + print(f"Mel spectrogram saved to: {recon_path}") + +if __name__ == "__main__": + main() + diff --git a/models/vc/Noro/noro_loss.py b/models/vc/Noro/noro_loss.py new file mode 100644 index 00000000..257c47e8 --- /dev/null +++ b/models/vc/Noro/noro_loss.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def cross_entropy_loss(preds, targets, reduction='none'): + log_softmax = nn.LogSoftmax(dim=-1) + loss = (-targets * log_softmax(preds)).sum(1) + if reduction == "none": + return loss + elif reduction == "mean": + return loss.mean() + +class ConstractiveSpeakerLoss(nn.Module): + def __init__(self, temperature=1.): + super(ConstractiveSpeakerLoss, self).__init__() + self.temperature = temperature + + def forward(self, x, speaker_ids): + # x : B, H + # speaker_ids: B 3 4 3 + speaker_ids = speaker_ids.reshape(-1) + speaker_ids_expand = torch.zeros(len(speaker_ids),len(speaker_ids)).to(speaker_ids.device) + speaker_ids_expand = (speaker_ids.view(-1,1) == speaker_ids).float() + x_t = x.transpose(0,1) # B, C --> C,B + logits = (x @ x_t) / self.temperature # B, H * H, B --> B, B + targets = F.softmax(speaker_ids_expand / self.temperature, dim=-1) + loss = cross_entropy_loss(logits, targets, reduction='none') + return loss.mean() + +def diff_loss(pred, target, mask, loss_type="l1"): + # pred: (B, T, d) + # target: (B, T, d) + # mask: (B, T) + if loss_type == "l1": + loss = F.l1_loss(pred, target, reduction="none").float() * ( + mask.to(pred.dtype).unsqueeze(-1) + ) + elif loss_type == "l2": + loss = F.mse_loss(pred, target, reduction="none").float() * ( + mask.to(pred.dtype).unsqueeze(-1) + ) + else: + raise NotImplementedError() + loss = (torch.mean(loss, dim=-1)).sum() / (mask.to(pred.dtype).sum()) + return loss \ No newline at end of file diff --git a/models/vc/Noro/noro_model.py b/models/vc/Noro/noro_model.py new file mode 100644 index 00000000..0b0bbbf4 --- /dev/null +++ b/models/vc/Noro/noro_model.py @@ -0,0 +1,1347 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
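For orientation, here is a hedged usage sketch for the losses defined in `models/vc/Noro/noro_loss.py` above: `ConstractiveSpeakerLoss` takes a batch of speaker embeddings of shape (B, H) together with integer speaker ids of shape (B,), and `diff_loss` compares (B, T, d) predictions and targets under a (B, T) mask. The shapes follow the comments in that file; the random tensors are placeholders, and the import assumes the Amphion repository root is on `PYTHONPATH`.

```python
# Hypothetical usage of the losses defined in models/vc/Noro/noro_loss.py.
import torch

from models.vc.Noro.noro_loss import ConstractiveSpeakerLoss, diff_loss

B, H, T, D = 4, 512, 100, 80
speaker_emb = torch.randn(B, H)            # e.g. pooled reference-encoder output
speaker_ids = torch.tensor([3, 4, 3, 7])   # integer speaker labels, shape (B,)

spk_loss = ConstractiveSpeakerLoss(temperature=1.0)(speaker_emb, speaker_ids)

pred = torch.randn(B, T, D)                # e.g. x0_pred from the diffusion module
target = torch.randn(B, T, D)              # ground-truth mel frames
mask = torch.ones(B, T)                    # 1 = valid frame, 0 = padding
mel_loss = diff_loss(pred, target, mask, loss_type="l1")

print(spk_loss.item(), mel_loss.item())
```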
+ +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import math +import json5 +from librosa.filters import mel as librosa_mel_fn +from einops.layers.torch import Rearrange + +sr = 16000 +MAX_WAV_VALUE = 32768.0 + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} +init_mel_and_hann = False + + +def mel_spectrogram( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window, init_mel_and_hann + if not init_mel_and_hann: + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[str(fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + init_mel_and_hann = True + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + # complex tensor as default, then use view_as_real for future pytorch compatibility + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +class Diffusion(nn.Module): + def __init__(self, cfg, diff_model): + super().__init__() + + self.cfg = cfg + self.diff_estimator = diff_model + self.beta_min = cfg.beta_min + self.beta_max = cfg.beta_max + self.sigma = cfg.sigma + self.noise_factor = cfg.noise_factor + + def forward(self, x, condition_embedding, x_mask, reference_embedding, offset=1e-5): + diffusion_step = torch.rand( + x.shape[0], dtype=x.dtype, device=x.device, requires_grad=False + ) + diffusion_step = torch.clamp(diffusion_step, offset, 1.0 - offset) + xt, z = self.forward_diffusion(x0=x, diffusion_step=diffusion_step) + + cum_beta = self.get_cum_beta(diffusion_step.unsqueeze(-1).unsqueeze(-1)) + x0_pred = self.diff_estimator( + xt, condition_embedding, x_mask, reference_embedding, diffusion_step + ) + mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / (self.sigma**2)) + variance = (self.sigma**2) * (1.0 - torch.exp(-cum_beta / (self.sigma**2))) + noise_pred = (xt - mean_pred) / (torch.sqrt(variance) * self.noise_factor) + noise = z + diff_out = {"x0_pred": x0_pred, "noise_pred": noise_pred, "noise": noise} + return diff_out + + @torch.no_grad() + def get_cum_beta(self, time_step): + return self.beta_min * time_step + 0.5 * (self.beta_max - self.beta_min) * ( + time_step**2 + ) + + @torch.no_grad() + def get_beta_t(self, 
time_step):
+        return self.beta_min + (self.beta_max - self.beta_min) * time_step
+
+    @torch.no_grad()
+    def forward_diffusion(self, x0, diffusion_step):
+        # Sample x_t ~ q(x_t | x_0) with the closed-form Gaussian perturbation kernel
+        time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
+        cum_beta = self.get_cum_beta(time_step)
+        mean = x0 * torch.exp(-0.5 * cum_beta / (self.sigma**2))
+        variance = (self.sigma**2) * (1 - torch.exp(-cum_beta / (self.sigma**2)))
+        z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
+        xt = mean + z * torch.sqrt(variance) * self.noise_factor
+        return xt, z
+
+    @torch.no_grad()
+    def cal_dxt(
+        self, xt, condition_embedding, x_mask, reference_embedding, diffusion_step, h
+    ):
+        # One ODE-step increment: estimate the score from the predicted x0,
+        # then scale by beta_t and the step size h
+        time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
+        cum_beta = self.get_cum_beta(time_step=time_step)
+        beta_t = self.get_beta_t(time_step=time_step)
+        x0_pred = self.diff_estimator(
+            xt, condition_embedding, x_mask, reference_embedding, diffusion_step
+        )
+        mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / (self.sigma**2))
+        noise_pred = xt - mean_pred
+        variance = (self.sigma**2) * (1.0 - torch.exp(-cum_beta / (self.sigma**2)))
+        logp = -noise_pred / (variance + 1e-8)
+        dxt = -0.5 * h * beta_t * (logp + xt / (self.sigma**2))
+        return dxt
+
+    @torch.no_grad()
+    def reverse_diffusion(
+        self, z, condition_embedding, x_mask, reference_embedding, n_timesteps
+    ):
+        h = 1.0 / max(n_timesteps, 1)
+        xt = z
+        for i in range(n_timesteps):
+            t = (1.0 - (i + 0.5) * h) * torch.ones(
+                z.shape[0], dtype=z.dtype, device=z.device
+            )
+            dxt = self.cal_dxt(
+                xt,
+                condition_embedding,
+                x_mask,
+                reference_embedding,
+                diffusion_step=t,
+                h=h,
+            )
+            xt_ = xt - dxt
+            if self.cfg.ode_solve_method == "midpoint":
+                x_mid = 0.5 * (xt_ + xt)
+                dxt = self.cal_dxt(
+                    x_mid,
+                    condition_embedding,
+                    x_mask,
+                    reference_embedding,
+                    diffusion_step=t + 0.5 * h,
+                    h=h,
+                )
+                xt = xt - dxt
+            elif self.cfg.ode_solve_method == "euler":
+                xt = xt_
+        return xt
+
+    @torch.no_grad()
+    def reverse_diffusion_from_t(
+        self, z, condition_embedding, x_mask, reference_embedding, n_timesteps, t_start
+    ):
+        # Same as reverse_diffusion, but starts the ODE solve from time t_start instead of 1.0
+        h = t_start / max(n_timesteps, 1)
+        xt = z
+        for i in range(n_timesteps):
+            t = (t_start - (i + 0.5) * h) * torch.ones(
+                z.shape[0], dtype=z.dtype, device=z.device
+            )
+            dxt = self.cal_dxt(
+                xt,
+                condition_embedding,
+                x_mask,
+                reference_embedding,
+                diffusion_step=t,
+                h=h,
+            )
+            xt_ = xt - dxt
+            if self.cfg.ode_solve_method == "midpoint":
+                x_mid = 0.5 * (xt_ + xt)
+                dxt = self.cal_dxt(
+                    x_mid,
+                    condition_embedding,
+                    x_mask,
+                    reference_embedding,
+                    diffusion_step=t + 0.5 * h,
+                    h=h,
+                )
+                xt = xt - dxt
+            elif self.cfg.ode_solve_method == "euler":
+                xt = xt_
+        return xt
+
+
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class Linear2(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+        self.linear_1 = nn.Linear(dim, dim * 2)
+        self.linear_2 = nn.Linear(dim * 2, dim)
+        self.linear_1.weight.data.normal_(0.0, 0.02)
+        self.linear_2.weight.data.normal_(0.0, 0.02)
+
+    def forward(self, x):
+        x = self.linear_1(x)
+        x = F.silu(x)
+        x = self.linear_2(x)
+        return x
+
+
+class StyleAdaptiveLayerNorm(nn.Module):
+    def __init__(self, normalized_shape, eps=1e-5):
+        super().__init__()
+        self.in_dim =
normalized_shape + self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False) + self.style = nn.Linear(self.in_dim, self.in_dim * 2) + self.style.bias.data[: self.in_dim] = 1 + self.style.bias.data[self.in_dim :] = 0 + + def forward(self, x, condition): + # x: (B, T, d); condition: (B, T, d) + + style = self.style(torch.mean(condition, dim=1, keepdim=True)) + + gamma, beta = style.chunk(2, -1) + + out = self.norm(x) + + out = gamma * out + beta + return out + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout, max_len=5000): + super().__init__() + + self.dropout = dropout + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[: x.size(0)] + return F.dropout(x, self.dropout, training=self.training) + + +class TransformerFFNLayer(nn.Module): + def __init__( + self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout + ): + super().__init__() + + self.encoder_hidden = encoder_hidden + self.conv_filter_size = conv_filter_size + self.conv_kernel_size = conv_kernel_size + self.encoder_dropout = encoder_dropout + + self.ffn_1 = nn.Conv1d( + self.encoder_hidden, + self.conv_filter_size, + self.conv_kernel_size, + padding=self.conv_kernel_size // 2, + ) + self.ffn_1.weight.data.normal_(0.0, 0.02) + self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden) + self.ffn_2.weight.data.normal_(0.0, 0.02) + + def forward(self, x): + # x: (B, T, d) + x = self.ffn_1(x.permute(0, 2, 1)).permute( + 0, 2, 1 + ) # (B, T, d) -> (B, d, T) -> (B, T, d) + x = F.silu(x) + x = F.dropout(x, self.encoder_dropout, training=self.training) + x = self.ffn_2(x) + return x + + +class TransformerFFNLayerOld(nn.Module): + def __init__( + self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout + ): + super().__init__() + + self.encoder_hidden = encoder_hidden + self.conv_filter_size = conv_filter_size + self.conv_kernel_size = conv_kernel_size + self.encoder_dropout = encoder_dropout + + self.ffn_1 = nn.Linear(self.encoder_hidden, self.conv_filter_size) + self.ffn_1.weight.data.normal_(0.0, 0.02) + self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden) + self.ffn_2.weight.data.normal_(0.0, 0.02) + + def forward(self, x): + x = self.ffn_1(x) + x = F.silu(x) + x = F.dropout(x, self.encoder_dropout, training=self.training) + x = self.ffn_2(x) + return x + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + encoder_hidden, + encoder_head, + conv_filter_size, + conv_kernel_size, + encoder_dropout, + use_cln, + use_skip_connection, + use_new_ffn, + add_diff_step, + ): + super().__init__() + self.encoder_hidden = encoder_hidden + self.encoder_head = encoder_head + self.conv_filter_size = conv_filter_size + self.conv_kernel_size = conv_kernel_size + self.encoder_dropout = encoder_dropout + self.use_cln = use_cln + self.use_skip_connection = use_skip_connection + self.use_new_ffn = use_new_ffn + self.add_diff_step = add_diff_step + + if not self.use_cln: + self.ln_1 = nn.LayerNorm(self.encoder_hidden) + self.ln_2 = nn.LayerNorm(self.encoder_hidden) + else: + self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden) + self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden) + + self.self_attn = nn.MultiheadAttention( + 
self.encoder_hidden, self.encoder_head, batch_first=True + ) + + if self.use_new_ffn: + self.ffn = TransformerFFNLayer( + self.encoder_hidden, + self.conv_filter_size, + self.conv_kernel_size, + self.encoder_dropout, + ) + else: + self.ffn = TransformerFFNLayerOld( + self.encoder_hidden, + self.conv_filter_size, + self.conv_kernel_size, + self.encoder_dropout, + ) + + if self.use_skip_connection: + self.skip_linear = nn.Linear(self.encoder_hidden * 2, self.encoder_hidden) + self.skip_linear.weight.data.normal_(0.0, 0.02) + self.skip_layernorm = nn.LayerNorm(self.encoder_hidden) + + if self.add_diff_step: + self.diff_step_emb = SinusoidalPosEmb(dim=self.encoder_hidden) + # self.diff_step_projection = nn.linear(self.encoder_hidden, self.encoder_hidden) + # self.encoder_hidden.weight.data.normal_(0.0, 0.02) + self.diff_step_projection = Linear2(self.encoder_hidden) + + def forward( + self, x, key_padding_mask, conditon=None, skip_res=None, diffusion_step=None + ): + # x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d); skip_res: (B, T, d); diffusion_step: (B,) + + if self.use_skip_connection and skip_res != None: + x = torch.cat([x, skip_res], dim=-1) # (B, T, 2*d) + x = self.skip_linear(x) + x = self.skip_layernorm(x) + + if self.add_diff_step and diffusion_step != None: + diff_step_embedding = self.diff_step_emb(diffusion_step) + diff_step_embedding = self.diff_step_projection(diff_step_embedding) + x = x + diff_step_embedding.unsqueeze(1) + + residual = x + + # pre norm + if self.use_cln: + x = self.ln_1(x, conditon) + else: + x = self.ln_1(x) + + # self attention + if key_padding_mask != None: + key_padding_mask_input = ~(key_padding_mask.bool()) + else: + key_padding_mask_input = None + x, _ = self.self_attn( + query=x, key=x, value=x, key_padding_mask=key_padding_mask_input + ) + x = F.dropout(x, self.encoder_dropout, training=self.training) + + x = residual + x + + # pre norm + residual = x + if self.use_cln: + x = self.ln_2(x, conditon) + else: + x = self.ln_2(x) + + # ffn + x = self.ffn(x) + + x = residual + x + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + enc_emb_tokens=None, + encoder_layer=None, + encoder_hidden=None, + encoder_head=None, + conv_filter_size=None, + conv_kernel_size=None, + encoder_dropout=None, + use_cln=None, + use_skip_connection=None, + use_new_ffn=None, + add_diff_step=None, + cfg=None, + ): + super().__init__() + + self.encoder_layer = ( + encoder_layer if encoder_layer is not None else cfg.encoder_layer + ) + self.encoder_hidden = ( + encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden + ) + self.encoder_head = ( + encoder_head if encoder_head is not None else cfg.encoder_head + ) + self.conv_filter_size = ( + conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size + ) + self.conv_kernel_size = ( + conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size + ) + self.encoder_dropout = ( + encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout + ) + self.use_cln = use_cln if use_cln is not None else cfg.use_cln + self.use_skip_connection = ( + use_skip_connection + if use_skip_connection is not None + else cfg.use_skip_connection + ) + self.add_diff_step = ( + add_diff_step if add_diff_step is not None else cfg.add_diff_step + ) + self.use_new_ffn = use_new_ffn if use_new_ffn is not None else cfg.use_new_ffn + + if enc_emb_tokens != None: + self.use_enc_emb = True + self.enc_emb_tokens = enc_emb_tokens + else: + self.use_enc_emb = False + 
+ self.position_emb = PositionalEncoding( + self.encoder_hidden, self.encoder_dropout + ) + + self.layers = nn.ModuleList([]) + if self.use_skip_connection: + self.layers.extend( + [ + TransformerEncoderLayer( + self.encoder_hidden, + self.encoder_head, + self.conv_filter_size, + self.conv_kernel_size, + self.encoder_dropout, + self.use_cln, + use_skip_connection=False, + use_new_ffn=self.use_new_ffn, + add_diff_step=self.add_diff_step, + ) + for i in range( + (self.encoder_layer + 1) // 2 + ) # for example: 12 -> 6; 13 -> 7 + ] + ) + self.layers.extend( + [ + TransformerEncoderLayer( + self.encoder_hidden, + self.encoder_head, + self.conv_filter_size, + self.conv_kernel_size, + self.encoder_dropout, + self.use_cln, + use_skip_connection=True, + use_new_ffn=self.use_new_ffn, + add_diff_step=self.add_diff_step, + ) + for i in range( + self.encoder_layer - (self.encoder_layer + 1) // 2 + ) # 12 -> 6; 13 -> 6 + ] + ) + else: + self.layers.extend( + [ + TransformerEncoderLayer( + self.encoder_hidden, + self.encoder_head, + self.conv_filter_size, + self.conv_kernel_size, + self.encoder_dropout, + self.use_cln, + use_new_ffn=self.use_new_ffn, + add_diff_step=self.add_diff_step, + use_skip_connection=False, + ) + for i in range(self.encoder_layer) + ] + ) + + if self.use_cln: + self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden) + else: + self.last_ln = nn.LayerNorm(self.encoder_hidden) + + if self.add_diff_step: + self.diff_step_emb = SinusoidalPosEmb(dim=self.encoder_hidden) + # self.diff_step_projection = nn.linear(self.encoder_hidden, self.encoder_hidden) + # self.encoder_hidden.weight.data.normal_(0.0, 0.02) + self.diff_step_projection = Linear2(self.encoder_hidden) + + def forward(self, x, key_padding_mask, condition=None, diffusion_step=None): + if len(x.shape) == 2 and self.use_enc_emb: + x = self.enc_emb_tokens(x) + x = self.position_emb(x) + else: + x = self.position_emb(x) # (B, T, d) + + if self.add_diff_step and diffusion_step != None: + diff_step_embedding = self.diff_step_emb(diffusion_step) + diff_step_embedding = self.diff_step_projection(diff_step_embedding) + x = x + diff_step_embedding.unsqueeze(1) + + if self.use_skip_connection: + skip_res_list = [] + # down + for layer in self.layers[: self.encoder_layer // 2]: + x = layer(x, key_padding_mask, condition) + res = x + skip_res_list.append(res) + # middle + for layer in self.layers[ + self.encoder_layer // 2 : (self.encoder_layer + 1) // 2 + ]: + x = layer(x, key_padding_mask, condition) + # up + for layer in self.layers[(self.encoder_layer + 1) // 2 :]: + skip_res = skip_res_list.pop() + x = layer(x, key_padding_mask, condition, skip_res) + else: + for layer in self.layers: + x = layer(x, key_padding_mask, condition) + + if self.use_cln: + x = self.last_ln(x, condition) + else: + x = self.last_ln(x) + + return x + + +class DiffTransformer(nn.Module): + def __init__( + self, + encoder_layer=None, + encoder_hidden=None, + encoder_head=None, + conv_filter_size=None, + conv_kernel_size=None, + encoder_dropout=None, + use_cln=None, + use_skip_connection=None, + use_new_ffn=None, + add_diff_step=None, + cat_diff_step=None, + in_dim=None, + out_dim=None, + cond_dim=None, + cfg=None, + ): + super().__init__() + + self.encoder_layer = ( + encoder_layer if encoder_layer is not None else cfg.encoder_layer + ) + self.encoder_hidden = ( + encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden + ) + self.encoder_head = ( + encoder_head if encoder_head is not None else cfg.encoder_head + ) + self.conv_filter_size = ( + 
conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size + ) + self.conv_kernel_size = ( + conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size + ) + self.encoder_dropout = ( + encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout + ) + self.use_cln = use_cln if use_cln is not None else cfg.use_cln + self.use_skip_connection = ( + use_skip_connection + if use_skip_connection is not None + else cfg.use_skip_connection + ) + self.use_new_ffn = use_new_ffn if use_new_ffn is not None else cfg.use_new_ffn + self.add_diff_step = ( + add_diff_step if add_diff_step is not None else cfg.add_diff_step + ) + self.cat_diff_step = ( + cat_diff_step if cat_diff_step is not None else cfg.cat_diff_step + ) + self.in_dim = in_dim if in_dim is not None else cfg.in_dim + self.out_dim = out_dim if out_dim is not None else cfg.out_dim + self.cond_dim = cond_dim if cond_dim is not None else cfg.cond_dim + + if self.in_dim != self.encoder_hidden: + self.in_linear = nn.Linear(self.in_dim, self.encoder_hidden) + self.in_linear.weight.data.normal_(0.0, 0.02) + else: + self.in_dim = None + + if self.out_dim != self.encoder_hidden: + self.out_linear = nn.Linear(self.encoder_hidden, self.out_dim) + self.out_linear.weight.data.normal_(0.0, 0.02) + else: + self.out_dim = None + + assert not ((self.cat_diff_step == True) and (self.add_diff_step == True)) + + self.transformer_encoder = TransformerEncoder( + encoder_layer=self.encoder_layer, + encoder_hidden=self.encoder_hidden, + encoder_head=self.encoder_head, + conv_kernel_size=self.conv_kernel_size, + conv_filter_size=self.conv_filter_size, + encoder_dropout=self.encoder_dropout, + use_cln=self.use_cln, + use_skip_connection=self.use_skip_connection, + use_new_ffn=self.use_new_ffn, + add_diff_step=self.add_diff_step, + ) + + self.cond_project = nn.Linear(self.cond_dim, self.encoder_hidden) + self.cond_project.weight.data.normal_(0.0, 0.02) + self.cat_linear = nn.Linear(self.encoder_hidden * 2, self.encoder_hidden) + self.cat_linear.weight.data.normal_(0.0, 0.02) + + if self.cat_diff_step: + self.diff_step_emb = SinusoidalPosEmb(dim=self.encoder_hidden) + self.diff_step_projection = Linear2(self.encoder_hidden) + + def forward( + self, + x, + condition_embedding, + key_padding_mask=None, + reference_embedding=None, + diffusion_step=None, + ): + # x: shape is (B, T, d_x) + # key_padding_mask: shape is (B, T), mask is 0 + # condition_embedding: from condition adapter, shape is (B, T, d_c) + # reference_embedding: from reference encoder, shape is (B, N, d_r), or (B, 1, d_r), or (B, d_r) + + if self.in_linear != None: + x = self.in_linear(x) + condition_embedding = self.cond_project(condition_embedding) + + x = torch.cat([x, condition_embedding], dim=-1) + x = self.cat_linear(x) + + if self.cat_diff_step and diffusion_step != None: + diff_step_embedding = self.diff_step_emb(diffusion_step) + diff_step_embedding = self.diff_step_projection( + diff_step_embedding + ).unsqueeze( + 1 + ) # (B, 1, d) + x = torch.cat([diff_step_embedding, x], dim=1) + if key_padding_mask != None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.ones(key_padding_mask.shape[0], 1).to( + key_padding_mask.device + ), + ], + dim=1, + ) + + x = self.transformer_encoder( + x, + key_padding_mask=key_padding_mask, + condition=reference_embedding, + diffusion_step=diffusion_step, + ) + + if self.cat_diff_step and diffusion_step != None: + x = x[:, 1:, :] + + if self.out_linear != None: + x = self.out_linear(x) + + return x + + +class 
ReferenceEncoder(nn.Module):
+    def __init__(
+        self,
+        encoder_layer=None,
+        encoder_hidden=None,
+        encoder_head=None,
+        conv_filter_size=None,
+        conv_kernel_size=None,
+        encoder_dropout=None,
+        use_skip_connection=None,
+        use_new_ffn=None,
+        ref_in_dim=None,
+        ref_out_dim=None,
+        use_query_emb=None,
+        num_query_emb=None,
+        cfg=None,
+    ):
+        super().__init__()
+
+        self.encoder_layer = (
+            encoder_layer if encoder_layer is not None else cfg.encoder_layer
+        )
+        self.encoder_hidden = (
+            encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
+        )
+        self.encoder_head = (
+            encoder_head if encoder_head is not None else cfg.encoder_head
+        )
+        self.conv_filter_size = (
+            conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
+        )
+        self.conv_kernel_size = (
+            conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
+        )
+        self.encoder_dropout = (
+            encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
+        )
+        self.use_skip_connection = (
+            use_skip_connection
+            if use_skip_connection is not None
+            else cfg.use_skip_connection
+        )
+        self.use_new_ffn = use_new_ffn if use_new_ffn is not None else cfg.use_new_ffn
+        self.in_dim = ref_in_dim if ref_in_dim is not None else cfg.ref_in_dim
+        self.out_dim = ref_out_dim if ref_out_dim is not None else cfg.ref_out_dim
+        self.use_query_emb = (
+            use_query_emb if use_query_emb is not None else cfg.use_query_emb
+        )
+        self.num_query_emb = (
+            num_query_emb if num_query_emb is not None else cfg.num_query_emb
+        )
+
+        if self.in_dim != self.encoder_hidden:
+            self.in_linear = nn.Linear(self.in_dim, self.encoder_hidden)
+            self.in_linear.weight.data.normal_(0.0, 0.02)
+        else:
+            self.in_linear = None
+
+        if self.out_dim != self.encoder_hidden:
+            self.out_linear = nn.Linear(self.encoder_hidden, self.out_dim)
+            self.out_linear.weight.data.normal_(0.0, 0.02)
+        else:
+            self.out_linear = None
+
+        self.transformer_encoder = TransformerEncoder(
+            encoder_layer=self.encoder_layer,
+            encoder_hidden=self.encoder_hidden,
+            encoder_head=self.encoder_head,
+            conv_kernel_size=self.conv_kernel_size,
+            conv_filter_size=self.conv_filter_size,
+            encoder_dropout=self.encoder_dropout,
+            use_new_ffn=self.use_new_ffn,
+            use_cln=False,
+            use_skip_connection=False,
+            add_diff_step=False,
+        )
+
+        if self.use_query_emb:
+            # Learnable query embeddings (e.g. 32 x 512) that attend over the encoded
+            # reference to produce a fixed-size speaker representation.
+            self.query_embs = nn.Embedding(self.num_query_emb, self.encoder_hidden)
+            self.query_attn = nn.MultiheadAttention(
+                self.encoder_hidden, self.encoder_hidden // 64, batch_first=True
+            )
+
+    def forward(self, x_ref, key_padding_mask=None):
+        # x_ref: (B, T, d_ref)
+        # key_padding_mask: (B, T)
+        # Returns the speaker embedding x_spk:
+        #   if self.use_query_emb: shape is (B, N_query, d_out)
+        #   else: shape is (B, T, d_out)
+
+        if self.in_linear is not None:
+            x = self.in_linear(x_ref)
+        else:
+            x = x_ref
+
+        x = self.transformer_encoder(
+            x, key_padding_mask=key_padding_mask, condition=None, diffusion_step=None
+        )  # (B, T, encoder_hidden)
+
+        if self.use_query_emb:
+            spk_query_emb = self.query_embs(
+                torch.arange(self.num_query_emb).to(x.device)
+            ).repeat(x.shape[0], 1, 1)
+            # query: (B, N_query, d); key/value: (B, T, d)
+            spk_embs, _ = self.query_attn(
+                query=spk_query_emb,
+                key=x,
+                value=x,
+                key_padding_mask=(
+                    ~(key_padding_mask.bool()) if key_padding_mask is not None else None
+                ),
+            )  # (B, N_query, d_out)
+            if self.out_linear is not None:
+                spk_embs = self.out_linear(spk_embs)
+        else:
+            # Without query embeddings, return the encoded reference sequence itself.
+            spk_embs = self.out_linear(x) if self.out_linear is not None else x
+
+        return spk_embs, x
+
+
+def pad(input_ele, mel_max_length=None):
+    if mel_max_length:
+ max_len = mel_max_length + else: + max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) + + out_list = list() + for i, batch in enumerate(input_ele): + if len(batch.shape) == 1: + one_batch_padded = F.pad( + batch, (0, max_len - batch.size(0)), "constant", 0.0 + ) + elif len(batch.shape) == 2: + one_batch_padded = F.pad( + batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 + ) + out_list.append(one_batch_padded) + out_padded = torch.stack(out_list) + return out_padded + + +class FiLM(nn.Module): + def __init__(self, in_dim, cond_dim): + super().__init__() + + self.gain = Linear(cond_dim, in_dim) + self.bias = Linear(cond_dim, in_dim) + + nn.init.xavier_uniform_(self.gain.weight) + nn.init.constant_(self.gain.bias, 1) + + nn.init.xavier_uniform_(self.bias.weight) + nn.init.constant_(self.bias.bias, 0) + + def forward(self, x, condition): + gain = self.gain(condition) + bias = self.bias(condition) + if gain.dim() == 2: + gain = gain.unsqueeze(-1) + if bias.dim() == 2: + bias = bias.unsqueeze(-1) + return x * gain + bias + + +class Mish(nn.Module): + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +def Conv1d(*args, **kwargs): + layer = nn.Conv1d(*args, **kwargs) + layer.weight.data.normal_(0.0, 0.02) + return layer + + +def Linear(*args, **kwargs): + layer = nn.Linear(*args, **kwargs) + layer.weight.data.normal_(0.0, 0.02) + return layer + + +class ResidualBlock(nn.Module): + def __init__(self, hidden_dim, attn_head, dilation, drop_out, has_cattn=False): + super().__init__() + + self.hidden_dim = hidden_dim + self.dilation = dilation + self.has_cattn = has_cattn + self.attn_head = attn_head + self.drop_out = drop_out + + self.dilated_conv = Conv1d( + hidden_dim, 2 * hidden_dim, 3, padding=dilation, dilation=dilation + ) + self.diffusion_proj = Linear(hidden_dim, hidden_dim) + + self.cond_proj = Conv1d(hidden_dim, hidden_dim * 2, 1) + self.out_proj = Conv1d(hidden_dim, hidden_dim * 2, 1) + + if self.has_cattn: + self.attn = nn.MultiheadAttention( + hidden_dim, attn_head, 0.1, batch_first=True + ) + self.film = FiLM(hidden_dim * 2, hidden_dim) + + self.ln = nn.LayerNorm(hidden_dim) + + self.dropout = nn.Dropout(self.drop_out) + + def forward(self, x, x_mask, cond, diffusion_step, spk_query_emb): + diffusion_step = self.diffusion_proj(diffusion_step).unsqueeze(-1) # (B, d, 1) + cond = self.cond_proj(cond) # (B, 2*d, T) + + y = x + diffusion_step + if x_mask != None: + y = y * x_mask.to(y.dtype)[:, None, :] # (B, 2*d, T) + + if self.has_cattn: + y_ = y.transpose(1, 2) + y_ = self.ln(y_) + + y_, _ = self.attn(y_, spk_query_emb, spk_query_emb) # (B, T, d) + + y = self.dilated_conv(y) + cond # (B, 2*d, T) + + if self.has_cattn: + y = self.film(y.transpose(1, 2), y_) # (B, T, 2*d) + y = y.transpose(1, 2) # (B, 2*d, T) + + gate, filter_ = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter_) + + y = self.out_proj(y) + + residual, skip = torch.chunk(y, 2, dim=1) + + if x_mask != None: + residual = residual * x_mask.to(y.dtype)[:, None, :] + skip = skip * x_mask.to(y.dtype)[:, None, :] + + return (x + residual) / math.sqrt(2.0), skip + + +class DiffWaveNet(nn.Module): + def __init__( + self, + cfg=None, + ): + super().__init__() + + self.cfg = cfg + self.in_dim = cfg.input_size + self.hidden_dim = cfg.hidden_size + self.out_dim = cfg.out_size + self.num_layers = cfg.num_layers + self.cross_attn_per_layer = cfg.cross_attn_per_layer + self.dilation_cycle = cfg.dilation_cycle + self.attn_head = cfg.attn_head + self.drop_out = cfg.drop_out + + 
self.in_proj = Conv1d(self.in_dim, self.hidden_dim, 1) + self.diffusion_embedding = SinusoidalPosEmb(self.hidden_dim) + + self.mlp = nn.Sequential( + Linear(self.hidden_dim, self.hidden_dim * 4), + Mish(), + Linear(self.hidden_dim * 4, self.hidden_dim), + ) + + self.cond_ln = nn.LayerNorm(self.hidden_dim) + + self.layers = nn.ModuleList( + [ + ResidualBlock( + self.hidden_dim, + self.attn_head, + 2 ** (i % self.dilation_cycle), + self.drop_out, + has_cattn=(i % self.cross_attn_per_layer == 0), + ) + for i in range(self.num_layers) + ] + ) + + self.skip_proj = Conv1d(self.hidden_dim, self.hidden_dim, 1) + self.out_proj = Conv1d(self.hidden_dim, self.out_dim, 1) + + nn.init.zeros_(self.out_proj.weight) + + def forward( + self, + x, + condition_embedding, + key_padding_mask=None, + reference_embedding=None, + diffusion_step=None, + ): + x = x.transpose(1, 2) # (B, T, d) -> (B, d, T) + x_mask = key_padding_mask + cond = condition_embedding + spk_query_emb = reference_embedding + diffusion_step = diffusion_step + + cond = self.cond_ln(cond) + cond_input = cond.transpose(1, 2) + + x_input = self.in_proj(x) + + x_input = F.relu(x_input) + + diffusion_step = self.diffusion_embedding(diffusion_step).to(x.dtype) + diffusion_step = self.mlp(diffusion_step) + + skip = [] + for _, layer in enumerate(self.layers): + x_input, skip_connection = layer( + x_input, x_mask, cond_input, diffusion_step, spk_query_emb + ) + skip.append(skip_connection) + + x_input = torch.sum(torch.stack(skip), dim=0) / math.sqrt(self.num_layers) + + x_out = self.skip_proj(x_input) + + x_out = F.relu(x_out) + + x_out = self.out_proj(x_out) # (B, 80, T) + + x_out = x_out.transpose(1, 2) + + return x_out + + +def override_config(base_config, new_config): + """Update new configurations in the original dict with the new dict + + Args: + base_config (dict): original dict to be overridden + new_config (dict): dict with new configurations + + Returns: + dict: updated configuration dict + """ + for k, v in new_config.items(): + if type(v) == dict: + if k not in base_config.keys(): + base_config[k] = {} + base_config[k] = override_config(base_config[k], v) + else: + base_config[k] = v + return base_config + + +def get_lowercase_keys_config(cfg): + """Change all keys in cfg to lower case + + Args: + cfg (dict): dictionary that stores configurations + + Returns: + dict: dictionary that stores configurations + """ + updated_cfg = dict() + for k, v in cfg.items(): + if type(v) == dict: + v = get_lowercase_keys_config(v) + updated_cfg[k.lower()] = v + return updated_cfg + + + +def save_config(save_path, cfg): + """Save configurations into a json file + + Args: + save_path (str): path to save configurations + cfg (dict): dictionary that stores configurations + """ + with open(save_path, "w") as f: + json5.dump( + cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True + ) + + +class JsonHParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = JsonHParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + +class Noro_VCmodel(nn.Module): + def __init__(self, cfg, 
use_ref_noise = False): + super().__init__() + self.cfg = cfg + self.use_ref_noise = use_ref_noise + self.reference_encoder = ReferenceEncoder(cfg=cfg.reference_encoder) + if cfg.diffusion.diff_model_type == "WaveNet": + self.diffusion = Diffusion( + cfg=cfg.diffusion, + diff_model=DiffWaveNet(cfg=cfg.diffusion.diff_wavenet), + ) + else: + raise NotImplementedError() + pitch_dim = 1 + self.content_f0_enc = nn.Sequential( + nn.LayerNorm( + cfg.vc_feature.content_feature_dim + pitch_dim + ), # 768 (for mhubert) + 1 (for f0) + Rearrange("b t d -> b d t"), + nn.Conv1d( + cfg.vc_feature.content_feature_dim + pitch_dim, + cfg.vc_feature.hidden_dim, + kernel_size=3, + padding=1, + ), + Rearrange("b d t -> b t d"), + ) + + self.reset_parameters() + def forward( + self, + x=None, + content_feature=None, + pitch=None, + x_ref=None, + x_mask=None, + x_ref_mask=None, + noisy_x_ref=None + ): + noisy_reference_embedding = None + noisy_condition_embedding = None + + reference_embedding, encoded_x = self.reference_encoder( + x_ref=x_ref, key_padding_mask=x_ref_mask + ) + + # content_feature: B x T x D + # pitch: B x T x 1 + # B x t x D+1 + # 2B x T + condition_embedding = torch.cat([content_feature, pitch[:, :, None]], dim=-1) + condition_embedding = self.content_f0_enc(condition_embedding) + + # 2B x T x D + if self.use_ref_noise: + # noisy_reference + noisy_reference_embedding, _ = self.reference_encoder( + x_ref=noisy_x_ref, key_padding_mask=x_ref_mask + ) + combined_reference_embedding = (noisy_reference_embedding + reference_embedding) / 2 + else: + combined_reference_embedding = reference_embedding + + combined_condition_embedding = condition_embedding + + diff_out = self.diffusion( + x=x, + condition_embedding=combined_condition_embedding, + x_mask=x_mask, + reference_embedding=combined_reference_embedding, + ) + return diff_out, (reference_embedding, noisy_reference_embedding), (condition_embedding, noisy_condition_embedding) + + @torch.no_grad() + def inference( + self, + content_feature=None, + pitch=None, + x_ref=None, + x_ref_mask=None, + inference_steps=1000, + sigma=1.2, + ): + reference_embedding, _ = self.reference_encoder( + x_ref=x_ref, key_padding_mask=x_ref_mask + ) + + condition_embedding = torch.cat([content_feature, pitch[:, :, None]], dim=-1) + condition_embedding = self.content_f0_enc(condition_embedding) + + bsz, l, _ = condition_embedding.shape + if self.cfg.diffusion.diff_model_type == "WaveNet": + z = ( + torch.randn(bsz, l, self.cfg.diffusion.diff_wavenet.input_size).to( + condition_embedding.device + ) + / sigma + ) + + x0 = self.diffusion.reverse_diffusion( + z=z, + condition_embedding=condition_embedding, + x_mask=None, + reference_embedding=reference_embedding, + n_timesteps=inference_steps, + ) + + return x0 + + def reset_parameters(self): + def _reset_parameters(m): + if isinstance(m, nn.MultiheadAttention): + if m._qkv_same_embed_dim: + nn.init.normal_(m.in_proj_weight, std=0.02) + else: + nn.init.normal_(m.q_proj_weight, std=0.02) + nn.init.normal_(m.k_proj_weight, std=0.02) + nn.init.normal_(m.v_proj_weight, std=0.02) + + if m.in_proj_bias is not None: + nn.init.constant_(m.in_proj_bias, 0.0) + nn.init.constant_(m.out_proj.bias, 0.0) + if m.bias_k is not None: + nn.init.xavier_normal_(m.bias_k) + if m.bias_v is not None: + nn.init.xavier_normal_(m.bias_v) + + elif ( + isinstance(m, nn.Conv1d) + or isinstance(m, nn.ConvTranspose1d) + or isinstance(m, nn.Conv2d) + or isinstance(m, nn.ConvTranspose2d) + ): + m.weight.data.normal_(0.0, 0.02) + + elif isinstance(m, 
nn.Linear): + m.weight.data.normal_(mean=0.0, std=0.02) + if m.bias is not None: + m.bias.data.zero_() + + elif isinstance(m, nn.Embedding): + m.weight.data.normal_(mean=0.0, std=0.02) + if m.padding_idx is not None: + m.weight.data[m.padding_idx].zero_() + + self.apply(_reset_parameters) \ No newline at end of file diff --git a/models/vc/Noro/noro_trainer.py b/models/vc/Noro/noro_trainer.py new file mode 100644 index 00000000..79de602a --- /dev/null +++ b/models/vc/Noro/noro_trainer.py @@ -0,0 +1,458 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import shutil +import time +import json5 +import torch +import numpy as np +from tqdm import tqdm +from utils.util import ValueWindow +from torch.utils.data import DataLoader +from models.vc.Noro.noro_base_trainer import Noro_base_Trainer +from torch.nn import functional as F +from models.base.base_sampler import VariableSampler + +from diffusers import get_scheduler +import accelerate +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from models.vc.Noro.noro_model import Noro_VCmodel +from models.vc.Noro.noro_dataset import VCCollator, VCDataset, batch_by_size +from processors.content_extractor import HubertExtractor +from models.vc.Noro.noro_loss import diff_loss, ConstractiveSpeakerLoss +from utils.mel import mel_spectrogram_torch +from utils.f0 import get_f0_features_using_dio, interpolate +from torch.nn.utils.rnn import pad_sequence + + +class NoroTrainer(Noro_base_Trainer): + def __init__(self, args, cfg): + self.args = args + self.cfg = cfg + cfg.exp_name = args.exp_name + self.content_extractor = "mhubert" + + # Initialize accelerator and ensure all processes are ready + self._init_accelerator() + self.accelerator.wait_for_everyone() + + # Initialize logger on the main process + if self.accelerator.is_main_process: + self.logger = get_logger(args.exp_name, log_level="INFO") + + # Configure noise and speaker usage + self.use_ref_noise = self.cfg.trans_exp.use_ref_noise + + # Log configuration on the main process + if self.accelerator.is_main_process: + self.logger.info(f"use_ref_noise: {self.use_ref_noise}") + + # Initialize a time window for monitoring metrics + self.time_window = ValueWindow(50) + + # Log the start of training + if self.accelerator.is_main_process: + self.logger.info("=" * 56) + self.logger.info("||\t\tNew training process started.\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + self.logger.debug(f"Using {args.log_level.upper()} logging level.") + self.logger.info(f"Experiment name: {args.exp_name}") + self.logger.info(f"Experiment directory: {self.exp_dir}") + + # Initialize checkpoint directory + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # Initialize training counters + self.batch_count: int = 0 + self.step: int = 0 + self.epoch: int = 0 + self.max_epoch = self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf") + if self.accelerator.is_main_process: + self.logger.info(f"Max epoch: {self.max_epoch if self.max_epoch < float('inf') else 'Unlimited'}") + + # Check basic configuration + if self.accelerator.is_main_process: + self._check_basic_configs() + self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride + self.keep_last = 
[i if i > 0 else float("inf") for i in self.cfg.train.keep_last] + self.run_eval = self.cfg.train.run_eval + + # Set random seed + with self.accelerator.main_process_first(): + self._set_random_seed(self.cfg.train.random_seed) + + # Setup data loader + with self.accelerator.main_process_first(): + if self.accelerator.is_main_process: + self.logger.info("Building dataset...") + self.train_dataloader = self._build_dataloader() + self.speaker_num = len(self.train_dataloader.dataset.speaker2id) + if self.accelerator.is_main_process: + self.logger.info("Speaker num: {}".format(self.speaker_num)) + + # Build model + with self.accelerator.main_process_first(): + if self.accelerator.is_main_process: + self.logger.info("Building model...") + self.model, self.w2v = self._build_model() + + # Resume training if specified + with self.accelerator.main_process_first(): + if self.accelerator.is_main_process: + self.logger.info("Resume training: {}".format(args.resume)) + if args.resume: + if self.accelerator.is_main_process: + self.logger.info("Resuming from checkpoint...") + ckpt_path = self._load_model( + self.checkpoint_dir, + args.checkpoint_path, + resume_type=args.resume_type, + ) + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + if self.accelerator.is_main_process: + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # Initialize optimizer & scheduler + with self.accelerator.main_process_first(): + if self.accelerator.is_main_process: + self.logger.info("Building optimizer and scheduler...") + self.optimizer = self._build_optimizer() + self.scheduler = self._build_scheduler() + + # Prepare model, w2v, optimizer, and scheduler for accelerator + self.model = self._prepare_for_accelerator(self.model) + self.w2v = self._prepare_for_accelerator(self.w2v) + self.optimizer = self._prepare_for_accelerator(self.optimizer) + self.scheduler = self._prepare_for_accelerator(self.scheduler) + + # Build criterion + with self.accelerator.main_process_first(): + if self.accelerator.is_main_process: + self.logger.info("Building criterion...") + self.criterion = self._build_criterion() + + self.config_save_path = os.path.join(self.exp_dir, "args.json") + self.task_type = "VC" + self.contrastive_speaker_loss = ConstractiveSpeakerLoss() + + if self.accelerator.is_main_process: + self.logger.info("Task type: {}".format(self.task_type)) + + def _init_accelerator(self): + self.exp_dir = os.path.join( + os.path.abspath(self.cfg.log_dir), self.args.exp_name + ) + project_config = ProjectConfiguration( + project_dir=self.exp_dir, + logging_dir=os.path.join(self.exp_dir, "log"), + ) + self.accelerator = accelerate.Accelerator( + gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step, + log_with=self.cfg.train.tracker, + project_config=project_config, + ) + if self.accelerator.is_main_process: + os.makedirs(project_config.project_dir, exist_ok=True) + os.makedirs(project_config.logging_dir, exist_ok=True) + self.accelerator.wait_for_everyone() + with self.accelerator.main_process_first(): + self.accelerator.init_trackers(self.args.exp_name) + + def _build_model(self): + w2v = HubertExtractor(self.cfg) + model = Noro_VCmodel(cfg=self.cfg.model, use_ref_noise=self.use_ref_noise) + return model, w2v + + def _build_dataloader(self): + np.random.seed(int(time.time())) + if self.accelerator.is_main_process: + self.logger.info("Use Dynamic Batchsize...") + train_dataset = 
VCDataset(self.cfg.trans_exp) + train_collate = VCCollator(self.cfg) + batch_sampler = batch_by_size( + train_dataset.num_frame_indices, + train_dataset.get_num_frames, + max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes, + max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes, + required_batch_size_multiple=self.accelerator.num_processes, + ) + np.random.shuffle(batch_sampler) + batches = [ + x[self.accelerator.local_process_index :: self.accelerator.num_processes] + for x in batch_sampler + if len(x) % self.accelerator.num_processes == 0 + ] + train_loader = DataLoader( + train_dataset, + collate_fn=train_collate, + num_workers=self.cfg.train.dataloader.num_worker, + batch_sampler=VariableSampler( + batches, drop_last=False, use_random_sampler=True + ), + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + self.accelerator.wait_for_everyone() + return train_loader + + def _build_optimizer(self): + optimizer = torch.optim.AdamW( + filter(lambda p: p.requires_grad, self.model.parameters()), + **self.cfg.train.adam, + ) + return optimizer + + def _build_scheduler(self): + lr_scheduler = get_scheduler( + self.cfg.train.lr_scheduler, + optimizer=self.optimizer, + num_warmup_steps=self.cfg.train.lr_warmup_steps, + num_training_steps=self.cfg.train.num_train_steps + ) + return lr_scheduler + + def _build_criterion(self): + criterion = torch.nn.L1Loss(reduction="mean") + return criterion + + def _dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) + + def load_model(self, checkpoint): + self.step = checkpoint["step"] + self.epoch = checkpoint["epoch"] + self.model.load_state_dict(checkpoint["model"]) + self.optimizer.load_state_dict(checkpoint["optimizer"]) + self.scheduler.load_state_dict(checkpoint["scheduler"]) + + def _prepare_for_accelerator(self, component): + if isinstance(component, dict): + for key in component.keys(): + component[key] = self.accelerator.prepare(component[key]) + else: + component = self.accelerator.prepare(component) + return component + + def _train_step(self, batch): + total_loss = 0.0 + train_losses = {} + device = self.accelerator.device + + # Move all Tensor data to the specified device + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + + speech = batch["speech"] + ref_speech = batch["ref_speech"] + + with torch.set_grad_enabled(False): + # Extract features and spectrograms + mel = mel_spectrogram_torch(speech, self.cfg).transpose(1, 2) + ref_mel = mel_spectrogram_torch(ref_speech,self.cfg).transpose(1, 2) + mask = batch["mask"] + ref_mask = batch["ref_mask"] + + # Extract pitch and content features + audio = speech.cpu().numpy() + f0s = [] + for i in range(audio.shape[0]): + wav = audio[i] + f0 = get_f0_features_using_dio(wav, self.cfg.preprocess) + f0, _ = interpolate(f0) + frame_num = len(wav) // self.cfg.preprocess.hop_size + f0 = torch.from_numpy(f0[:frame_num]).to(speech.device) + f0s.append(f0) + + pitch = pad_sequence(f0s, batch_first=True, padding_value=0).float() + pitch = (pitch - pitch.mean(dim=1, keepdim=True)) / (pitch.std(dim=1, keepdim=True) + 1e-6) # Normalize pitch (B,T) + _, content_feature = self.w2v.extract_content_features(speech) # semantic (B, T, 768) + + if self.use_ref_noise: + noisy_ref_mel = mel_spectrogram_torch(batch["noisy_ref_speech"], self.cfg).transpose(1, 2) + + if self.use_ref_noise: + diff_out, 
(ref_emb, noisy_ref_emb), (cond_emb, _) = self.model( + x=mel, content_feature=content_feature, pitch=pitch, x_ref=ref_mel, + x_mask=mask, x_ref_mask=ref_mask, noisy_x_ref=noisy_ref_mel + ) + else: + diff_out, (ref_emb, _), (cond_emb, _) = self.model( + x=mel, content_feature=content_feature, pitch=pitch, x_ref=ref_mel, + x_mask=mask, x_ref_mask=ref_mask + ) + + if self.use_ref_noise: + # B x N_query x D + ref_emb = torch.mean(ref_emb, dim=1) # B x D + noisy_ref_emb = torch.mean(noisy_ref_emb, dim=1) # B x D + all_ref_emb = torch.cat([ref_emb, noisy_ref_emb], dim=0) # 2B x D + all_speaker_ids = torch.cat([batch["speaker_id"], batch["speaker_id"]], dim=0) # 2B + cs_loss = self.contrastive_speaker_loss(all_ref_emb, all_speaker_ids) * 0.25 + total_loss += cs_loss + train_losses["ref_loss"] = cs_loss + + diff_loss_x0 = diff_loss(diff_out["x0_pred"], mel, mask=mask) + total_loss += diff_loss_x0 + train_losses["diff_loss_x0"] = diff_loss_x0 + + diff_loss_noise = diff_loss(diff_out["noise_pred"], diff_out["noise"], mask=mask) + total_loss += diff_loss_noise + train_losses["diff_loss_noise"] = diff_loss_noise + train_losses["total_loss"] = total_loss + + self.optimizer.zero_grad() + self.accelerator.backward(total_loss) + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_(filter(lambda p: p.requires_grad, self.model.parameters()), 0.5) + self.optimizer.step() + self.scheduler.step() + + for item in train_losses: + train_losses[item] = train_losses[item].item() + + train_losses['learning_rate'] = f"{self.optimizer.param_groups[0]['lr']:.1e}" + train_losses["batch_size"] = batch["speaker_id"].shape[0] + + return (train_losses["total_loss"], train_losses, None) + + def _train_epoch(self): + r"""Training epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + if isinstance(self.model, dict): + for key in self.model.keys(): + self.model[key].train() + else: + self.model.train() + if isinstance(self.w2v, dict): + for key in self.w2v.keys(): + self.w2v[key].eval() + else: + self.w2v.eval() + + epoch_sum_loss: float = 0.0 # total loss + # Put the data to cuda device + device = self.accelerator.device + with device: + torch.cuda.empty_cache() + self.model = self.model.to(device) + self.w2v = self.w2v.to(device) + + for batch in tqdm( + self.train_dataloader, + desc=f"Training Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + speech = batch["speech"].cpu().numpy() + speech = speech[0] + self.batch_count += 1 + self.step += 1 + if len(speech) >= 16000 * 25: + continue + with self.accelerator.accumulate(self.model): + total_loss, train_losses, _ = self._train_step(batch) + + if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: + epoch_sum_loss += total_loss + self.current_loss = total_loss + if isinstance(train_losses, dict): + for key, loss in train_losses.items(): + self.accelerator.log( + {"Epoch/Train {} Loss".format(key): loss}, + step=self.step, + ) + if self.accelerator.is_main_process and self.batch_count % 10 == 0: + self.echo_log(train_losses, mode="Training") + + self.save_checkpoint() + self.accelerator.wait_for_everyone() + + return epoch_sum_loss, None + + def train_loop(self): + r"""Training loop. 
The public entry of training process."""
+        # Wait everyone to prepare before we move on
+        self.accelerator.wait_for_everyone()
+        # Dump config file
+        if self.accelerator.is_main_process:
+            self._dump_cfg(self.config_save_path)
+
+        # Wait to ensure good to go
+        self.accelerator.wait_for_everyone()
+        # Stop when reaching max epoch or self.cfg.train.num_train_steps
+        while self.epoch < self.max_epoch and self.step < self.cfg.train.num_train_steps:
+            if self.accelerator.is_main_process:
+                self.logger.info("\n")
+                self.logger.info("-" * 32)
+                self.logger.info("Epoch {}: ".format(self.epoch))
+                self.logger.info("Start training...")
+
+            train_total_loss, _ = self._train_epoch()
+
+            self.epoch += 1
+            self.accelerator.wait_for_everyone()
+            if isinstance(self.scheduler, dict):
+                for key in self.scheduler.keys():
+                    self.scheduler[key].step()
+            else:
+                self.scheduler.step()
+
+        # Finish training and save the final checkpoint
+        self.accelerator.wait_for_everyone()
+        if self.accelerator.is_main_process:
+            self.accelerator.save_state(
+                os.path.join(
+                    self.checkpoint_dir,
+                    "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+                        self.epoch, self.step, train_total_loss
+                    ),
+                )
+            )
+        self.accelerator.end_training()
+        if self.accelerator.is_main_process:
+            self.logger.info("Training finished...")
+
+    def save_checkpoint(self):
+        self.accelerator.wait_for_everyone()
+        # Main process only
+        if self.accelerator.is_main_process:
+            if self.batch_count % self.save_checkpoint_stride[0] == 0:
+                keep_last = self.keep_last[0]
+                # Read all folders in self.checkpoint_dir
+                all_ckpts = os.listdir(self.checkpoint_dir)
+                # Exclude non-folders
+                all_ckpts = [
+                    ckpt
+                    for ckpt in all_ckpts
+                    if os.path.isdir(os.path.join(self.checkpoint_dir, ckpt))
+                ]
+                if len(all_ckpts) > keep_last:
+                    # Keep only the last `keep_last` folders, sorted by step
+                    # ("epoch-{:04d}_step-{:07d}_loss-{:.6f}")
+                    all_ckpts = sorted(
+                        all_ckpts, key=lambda x: int(x.split("_")[1].split("-")[1])
+                    )
+                    for ckpt in all_ckpts[:-keep_last]:
+                        shutil.rmtree(os.path.join(self.checkpoint_dir, ckpt))
+                checkpoint_filename = "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+                    self.epoch, self.step, self.current_loss
+                )
+                path = os.path.join(self.checkpoint_dir, checkpoint_filename)
+                self.logger.info("Saving state to {}...".format(path))
+                self.accelerator.save_state(path)
+                self.logger.info("Finished saving state.")
+        self.accelerator.wait_for_everyone()
diff --git a/processors/audio_features_extractor.py b/processors/audio_features_extractor.py
index 8e38bd5e..0018fcf6 100644
--- a/processors/audio_features_extractor.py
+++ b/processors/audio_features_extractor.py
@@ -25,6 +25,7 @@
     WhisperExtractor,
     ContentvecExtractor,
     WenetExtractor,
+    HubertExtractor
 )
 
 
@@ -155,3 +156,20 @@ def get_wenet_features(self, wavs, target_frame_len, wav_lens=None):
+        wenet_feats = self.wenet_extractor.extract_content_features(wavs, lens=wav_lens)
+        wenet_feats = self.wenet_extractor.ReTrans(wenet_feats, target_frame_len)
+        return wenet_feats
+
+    def get_hubert_features(self, wavs):
+        """Get HuBERT features
+
+        Args:
+            wavs: Tensor whose shape is (B, T)
+
+        Returns:
+            clusters: Tensor whose shape is (B, T'), the discrete k-means cluster ids
+            hubert_feats: Tensor whose shape is (B, T', D), the quantized content features
+        """
+        if not hasattr(self, "hubert_extractor"):
+            self.hubert_extractor = HubertExtractor(self.cfg)
+
+        clusters, hubert_feats = self.hubert_extractor.extract_content_features(wavs)
+
+        return clusters, hubert_feats
diff --git a/processors/content_extractor.py b/processors/content_extractor.py
index 34b54917..c45e587a 100644
--- a/processors/content_extractor.py
+++ b/processors/content_extractor.py
@@
-14,6 +14,10 @@ from torch.utils.data import DataLoader from fairseq import checkpoint_utils from transformers import AutoModel, Wav2Vec2FeatureExtractor +import torch.nn as nn +import torch.nn.functional as F +import joblib +from einops import repeat from utils.io_optim import ( TorchaudioDataset, @@ -493,6 +497,60 @@ def extract_content_features(self, wavs): mert_features.append(feature) return mert_features + +class HubertExtractor(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.load_model() + + def load_model(self): + kmeans_model_path = self.cfg.preprocess.kmeans_model_path + hubert_model_path = self.cfg.preprocess.hubert_model_path + print("Load Hubert Model...") + checkpoint = torch.load(hubert_model_path) + load_model_input = {hubert_model_path: checkpoint} + model, *_ = checkpoint_utils.load_model_ensemble_and_task( + load_model_input + ) + self.model = model[0] + self.model.eval() + + # Load KMeans cluster centers + kmeans = joblib.load(kmeans_model_path) + self.kmeans = kmeans + + self.register_buffer( + "cluster_centers", torch.from_numpy(kmeans.cluster_centers_) + ) + + def extract_content_features(self, wav_input): + """ + Extract content features and quantize using KMeans clustering. + + Args: + audio_data: tensor (batch_size, T) + + Returns: + quantize: tensor (batch_size, T, 768) + """ + # Extract features using HuBERT + wav_input = F.pad(wav_input, (40, 40), "reflect") + embed = self.model( + wav_input, + features_only=True, + mask=False, + ) + + batched_cluster_centers = repeat( + self.cluster_centers, "c d -> b c d", b=embed.shape[0] + ) + + dists = -torch.cdist(embed, batched_cluster_centers, p=2) + clusters = dists.argmax(dim=-1) # (batch, seq_len) + quantize = F.embedding(clusters, self.cluster_centers) + + return clusters, quantize def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
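
For reference, below is a minimal usage sketch of the new `HubertExtractor` introduced in `processors/content_extractor.py`. It assumes a config object whose `preprocess.hubert_model_path` and `preprocess.kmeans_model_path` fields point to a fairseq HuBERT checkpoint and a fitted k-means model (the paths consumed by `load_model` above); the config file path, batch shape, and use of the `load_config` helper are illustrative placeholders, not part of this patch.

```python
# Sketch only: exercises HubertExtractor.extract_content_features as defined above.
# The config path and waveform tensor are placeholders; adapt them to your setup.
import torch

from processors.content_extractor import HubertExtractor
from utils.util import load_config

cfg = load_config("config/noro.json")  # must provide cfg.preprocess.{hubert,kmeans}_model_path
extractor = HubertExtractor(cfg)
extractor.eval()

wavs = torch.randn(2, 3 * 16000)  # (B, T) mono waveforms at 16 kHz
with torch.no_grad():
    clusters, features = extractor.extract_content_features(wavs)

print(clusters.shape)  # (B, T_frames): discrete k-means cluster ids
print(features.shape)  # (B, T_frames, D): cluster-center (quantized) content features
```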