diff --git a/README.md b/README.md
index fccbae6e..c288ca61 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,10 @@
In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. Moreover, Amphion is dedicated to advancing audio generation in real-world applications, such as building **large-scale datasets** for speech synthesis.
## 🚀 News
+- **2024/09/01**: [Amphion](https://arxiv.org/abs/2312.09911) and [Emilia](https://arxiv.org/abs/2407.05361) got accepted by IEEE SLT 2024! 🤗
+- **2024/08/28**: Welcome to join Amphion's [Discord channel](https://discord.com/invite/ZxxREr3Y) to stay connected and engage with our community!
+- **2024/08/27**: *The Emilia dataset is now publicly available!* Discover the most extensive and diverse speech generation dataset with 101k hours of in-the-wild speech data now at [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset) or [![OpenDataLab](https://img.shields.io/badge/OpenDataLab-Dataset-blue)](https://opendatalab.com/Amphion/Emilia)! 👑👑👑
+- **2024/08/20**: [SingVisio](https://arxiv.org/abs/2402.12660) got accepted by Computers & Graphics, [available here](https://www.sciencedirect.com/science/article/pii/S0097849324001936)! 🎉
- **2024/07/01**: Amphion now releases **Emilia**, the first open-source multilingual in-the-wild dataset for speech generation with over 101k hours of speech data, and the **Emilia-Pipe**, the first open-source preprocessing pipeline designed to transform in-the-wild speech data into high-quality training data with annotations for speech generation! [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2407.05361) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia) [![demo](https://img.shields.io/badge/WebPage-Demo-red)](https://emilia-dataset.github.io/Emilia-Demo-Page/) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](preprocessors/Emilia/README.md)
- **2024/06/17**: Amphion has a new release for its **VALL-E** model! It uses Llama as its underlying architecture and has better model performance, faster training speed, and more readable codes compared to our first version. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/tts/VALLE_V2/README.md)
- **2024/03/12**: Amphion now supports **NaturalSpeech3 FACodec** and releases pretrained checkpoints. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2403.03100) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/naturalspeech3_facodec) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/codec/ns3_codec/README.md)
diff --git a/bins/vc/Noro/train.py b/bins/vc/Noro/train.py
new file mode 100644
index 00000000..8418c5cd
--- /dev/null
+++ b/bins/vc/Noro/train.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+
+import torch
+from models.vc.Noro.noro_trainer import NoroTrainer
+from utils.util import load_config
+
+def build_trainer(args, cfg):
+ supported_trainer = {
+ "VC": NoroTrainer,
+ }
+ trainer_class = supported_trainer[cfg.model_type]
+ trainer = trainer_class(args, cfg)
+ return trainer
+
+
+def cuda_relevant(deterministic=False):
+ torch.cuda.empty_cache()
+ # TF32 on Ampere and above
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.enabled = True
+ torch.backends.cudnn.allow_tf32 = True
+ # Deterministic
+ torch.backends.cudnn.deterministic = deterministic
+ torch.backends.cudnn.benchmark = not deterministic
+ torch.use_deterministic_algorithms(deterministic)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--config",
+ default="config.json",
+ help="json files for configurations.",
+ required=True,
+ )
+ parser.add_argument(
+ "--exp_name",
+ type=str,
+ default="exp_name",
+ help="A specific name to note the experiment",
+ required=True,
+ )
+ parser.add_argument(
+ "--resume", action="store_true", help="The model name to restore"
+ )
+ parser.add_argument(
+ "--log_level", default="warning", help="logging level (debug, info, warning)"
+ )
+ parser.add_argument(
+ "--resume_type",
+ type=str,
+ default="resume",
+ help="Resume training or finetuning.",
+ )
+ parser.add_argument(
+ "--checkpoint_path",
+ type=str,
+ default=None,
+ help="Checkpoint for resume training or finetuning.",
+ )
+ NoroTrainer.add_arguments(parser)
+ args = parser.parse_args()
+ cfg = load_config(args.config)
+ print("experiment name: ", args.exp_name)
+ # CUDA settings
+ cuda_relevant()
+ # Build trainer
+ print(f"Building {cfg.model_type} trainer")
+ trainer = build_trainer(args, cfg)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ print(f"Start training {cfg.model_type} model")
+ trainer.train_loop()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/config/noro.json b/config/noro.json
new file mode 100644
index 00000000..3c9d5c12
--- /dev/null
+++ b/config/noro.json
@@ -0,0 +1,76 @@
+{
+ "base_config": "config/base.json",
+ "model_type": "VC",
+ "dataset": ["mls"],
+ "model": {
+ "reference_encoder": {
+ "encoder_layer": 6,
+ "encoder_hidden": 512,
+ "encoder_head": 8,
+ "conv_filter_size": 2048,
+ "conv_kernel_size": 9,
+ "encoder_dropout": 0.2,
+ "use_skip_connection": false,
+ "use_new_ffn": true,
+ "ref_in_dim": 80,
+ "ref_out_dim": 512,
+ "use_query_emb": true,
+ "num_query_emb": 32
+ },
+ "diffusion": {
+ "beta_min": 0.05,
+ "beta_max": 20,
+ "sigma": 1.0,
+ "noise_factor": 1.0,
+ "ode_solve_method": "euler",
+ "diff_model_type": "WaveNet",
+ "diff_wavenet":{
+ "input_size": 80,
+ "hidden_size": 512,
+ "out_size": 80,
+ "num_layers": 47,
+ "cross_attn_per_layer": 3,
+ "dilation_cycle": 2,
+ "attn_head": 8,
+ "drop_out": 0.2
+ }
+ },
+ "prior_encoder": {
+ "encoder_layer": 6,
+ "encoder_hidden": 512,
+ "encoder_head": 8,
+ "conv_filter_size": 2048,
+ "conv_kernel_size": 9,
+ "encoder_dropout": 0.2,
+ "use_skip_connection": false,
+ "use_new_ffn": true,
+ "vocab_size": 256,
+ "cond_dim": 512,
+ "duration_predictor": {
+ "input_size": 512,
+ "filter_size": 512,
+ "kernel_size": 3,
+ "conv_layers": 30,
+ "cross_attn_per_layer": 3,
+ "attn_head": 8,
+ "drop_out": 0.2
+ },
+ "pitch_predictor": {
+ "input_size": 512,
+ "filter_size": 512,
+ "kernel_size": 5,
+ "conv_layers": 30,
+ "cross_attn_per_layer": 3,
+ "attn_head": 8,
+ "drop_out": 0.5
+ },
+ "pitch_min": 50,
+ "pitch_max": 1100,
+ "pitch_bins_num": 512
+ },
+ "vc_feature": {
+ "content_feature_dim": 768,
+ "hidden_dim": 512
+ }
+ }
+}
\ No newline at end of file
diff --git a/config/tts.json b/config/tts.json
index 882726db..0e804bda 100644
--- a/config/tts.json
+++ b/config/tts.json
@@ -19,7 +19,7 @@
"add_blank": true
},
"model": {
- "text_token_num": 512,
+ "text_token_num": 512
}
}
diff --git a/egs/vc/Noro/README.md b/egs/vc/Noro/README.md
new file mode 100644
index 00000000..55c1340f
--- /dev/null
+++ b/egs/vc/Noro/README.md
@@ -0,0 +1,122 @@
+# Noro: A Noise-Robust One-shot Voice Conversion System
+
+
+
+
+
+
+
+This is the official implementation of the paper: NORO: A Noise-Robust One-Shot Voice Conversion System with Hidden Speaker Representation Capabilities.
+
+- The semantic extractor is from [Hubert](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
+- The vocoder is [BigVGAN](https://github.com/NVIDIA/BigVGAN) architecture.
+
+## Project Overview
+Noro is a noise-robust one-shot voice conversion (VC) system designed to convert the timbre of speech from a source speaker to a target speaker using only a single reference speech sample, while preserving the semantic content of the original speech. Noro introduces innovative components tailored for VC using noisy reference speeches, including a dual-branch reference encoding module and a noise-agnostic contrastive speaker loss.
+
+## Features
+- **Noise-Robust Voice Conversion**: Utilizes a dual-branch reference encoding module and noise-agnostic contrastive speaker loss to maintain high-quality voice conversion in noisy environments.
+- **One-shot Voice Conversion**: Achieves timbre conversion using only one reference speech sample.
+- **Speaker Representation Learning**: Explores the potential of the reference encoder as a self-supervised speaker encoder.
+
+## Installation Requirements
+
+Set up your environment as described in the main Amphion README (a conda environment is required, and we recommend running on Linux).
+
+### Prepare Hubert Model
+
+The HuBERT checkpoint and the k-means model can be downloaded [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
+Set the downloaded model paths in `egs/vc/Noro/exp_config_base.json`, as in the snippet below.
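+
+The relevant fields live under `preprocess` in that config; the values are placeholders to replace with your local paths:
+
+```json
+"preprocess": {
+ "kmeans_model_path": "path/to/kmeans_model",
+ "hubert_model_path": "path/to/hubert_model"
+}
+```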
+
+
+## Usage
+
+### Download pretrained weights
+You need to download our pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1NPzSIuSKO8o87g5ImNzpw_BgbhsZaxNg?usp=drive_link).
+
+### Inference
+1. Configure inference parameters:
+ Modify the pretrained checkpoint path, the source voice path, and the reference voice path in the `egs/vc/Noro/noro_inference.sh` file.
+ These variables are currently defined around line 35:
+```bash
+ checkpoint_path="path/to/checkpoint/model.safetensors"
+ output_dir="path/to/output/directory"
+ source_path="path/to/source/audio.wav"
+ reference_path="path/to/reference/audio.wav"
+```
+2. Start inference:
+ ```bash
+ bash path/to/Amphion/egs/vc/Noro/noro_inference.sh
+ ```
+
+3. The reconstructed mel spectrogram is saved to the output directory.
+ Then use the [BigVGAN](https://github.com/NVIDIA/BigVGAN) vocoder to synthesize the waveform from it, as sketched below.
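+
+ A minimal sketch for loading the saved mel spectrogram (the output path is a placeholder; loading the BigVGAN generator itself follows the NVIDIA/BigVGAN repository and is only indicated in the comments):
+
+ ```python
+ import numpy as np
+ import torch
+
+ # recon_mel.npy is saved by noro_inference.py; with the default config it has shape (1, n_mel=80, T).
+ mel = torch.from_numpy(np.load("path/to/output/directory/recon_mel.npy")).float()
+ print(mel.shape)  # e.g. torch.Size([1, 80, T])
+
+ # Feed `mel` to a BigVGAN generator trained for 16 kHz, 80-bin mel spectrograms
+ # (see the NVIDIA/BigVGAN repository for loading a pretrained checkpoint);
+ # the generator maps (B, 80, T) mels to waveforms.
+ ```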
+
+## Training from Scratch
+
+### Data Preparation
+
+We use the LibriLight dataset for training and evaluation. You can download it using the following commands:
+```bash
+ wget https://dl.fbaipublicfiles.com/librilight/data/large.tar
+ wget https://dl.fbaipublicfiles.com/librilight/data/medium.tar
+ wget https://dl.fbaipublicfiles.com/librilight/data/small.tar
+```
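+
+After downloading, the archives can be unpacked into a common root, for example (paths are placeholders):
+
+```bash
+mkdir -p path/to/librilight
+tar -xf small.tar -C path/to/librilight
+tar -xf medium.tar -C path/to/librilight
+tar -xf large.tar -C path/to/librilight
+```
+
+Each extracted subset (e.g. `path/to/librilight/small`) can then be used as an entry of `directory_list` in the training configuration described below.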
+
+### Training the Model with Clean Reference Voice
+
+Configure training parameters:
+The configuration file for training the clean Noro model is at "egs/vc/Noro/exp_config_clean.json", and the one for the noisy Noro model is at "egs/vc/Noro/exp_config_noisy.json".
+
+To train your model, you need to modify the dataset paths in the JSON configurations.
+The relevant field is currently around line 40: set each entry of `directory_list` to one of your training data root directories.
+```
+ "directory_list": [
+ "path/to/your/training_data_directory1",
+ "path/to/your/training_data_directory2",
+ "path/to/your/training_data_directory3"
+ ],
+```
+
+If you want to train the noisy Noro model, you also need to set the directory paths for the noise data in "egs/vc/Noro/exp_config_noisy.json".
+```
+ "noise_dir": "path/to/your/noise/train/directory",
+ "test_noise_dir": "path/to/your/noise/test/directory"
+```
+
+You can change other experiment settings in the config files, such as the learning rate, the optimizer, and the dataset.
+
+ **Set a smaller batch size if you run out of memory 😢😢**
+
+We used `max_tokens = 3200000` to run training successfully on a single GPU; if you run out of memory, try a smaller value:
+
+```json
+ "max_tokens": 3200000
+```
+### Resume from an Existing Checkpoint
+Our framework supports resuming training from an existing checkpoint.
+If this is a new experiment, use the following command:
+```bash
+CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
+"${work_dir}/bins/vc/Noro/train.py" \
+ --config $exp_config \
+ --exp_name $exp_name \
+ --log_level debug
+```
+To resume training or fine-tune from a checkpoint, ensure the options `--resume`, `--resume_type resume`, and `--checkpoint_path` are set, as in the command below (it mirrors the command used in `noro_train_clean.sh`):
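+
+```bash
+CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
+"${work_dir}/bins/vc/Noro/train.py" \
+ --config $exp_config \
+ --exp_name $exp_name \
+ --log_level debug \
+ --resume \
+ --resume_type resume \
+ --checkpoint_path $checkpoint_path
+```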
+
+### Run the Training Commands
+Start clean training:
+ ```bash
+ bash path/to/Amphion/egs/vc/Noro/noro_train_clean.sh
+ ```
+
+
+Start noisy training:
+ ```bash
+ bash path/to/Amphion/egs/vc/Noro/noro_train_noisy.sh
+ ```
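+
+Both training scripts fall back to default values for the `gpu` and `exp_config` variables when they are not set (see the `if [ -z ... ]` checks inside the scripts), so you can override them from the environment without editing the scripts. A hypothetical example (paths are placeholders):
+
+```bash
+# Train the clean model on GPU 0 only, with an explicitly chosen config.
+gpu="0" exp_config="egs/vc/Noro/exp_config_clean.json" \
+bash path/to/Amphion/egs/vc/Noro/noro_train_clean.sh
+```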
+
+
+
diff --git a/egs/vc/Noro/exp_config_base.json b/egs/vc/Noro/exp_config_base.json
new file mode 100644
index 00000000..27832d4d
--- /dev/null
+++ b/egs/vc/Noro/exp_config_base.json
@@ -0,0 +1,61 @@
+{
+ "base_config": "config/noro.json",
+ "model_type": "VC",
+ "dataset": [
+ "mls"
+ ],
+ "sample_rate": 16000,
+ "n_fft": 1024,
+ "n_mel": 80,
+ "hop_size": 200,
+ "win_size": 800,
+ "fmin": 0,
+ "fmax": 8000,
+ "preprocess": {
+ "kmeans_model_path": "path/to/kmeans_model",
+ "hubert_model_path": "path/to/hubert_model",
+ "sample_rate": 16000,
+ "hop_size": 200,
+ "f0_min": 50,
+ "f0_max": 500,
+ "frame_period": 12.5
+ },
+ "model": {
+ "reference_encoder": {
+ "encoder_layer": 6,
+ "encoder_hidden": 512,
+ "encoder_head": 8,
+ "conv_filter_size": 2048,
+ "conv_kernel_size": 9,
+ "encoder_dropout": 0.2,
+ "use_skip_connection": false,
+ "use_new_ffn": true,
+ "ref_in_dim": 80,
+ "ref_out_dim": 512,
+ "use_query_emb": true,
+ "num_query_emb": 32
+ },
+ "diffusion": {
+ "beta_min": 0.05,
+ "beta_max": 20,
+ "sigma": 1.0,
+ "noise_factor": 1.0,
+ "ode_solve_method": "euler",
+ "diff_model_type": "WaveNet",
+ "diff_wavenet":{
+ "input_size": 80,
+ "hidden_size": 512,
+ "out_size": 80,
+ "num_layers": 47,
+ "cross_attn_per_layer": 3,
+ "dilation_cycle": 2,
+ "attn_head": 8,
+ "drop_out": 0.2
+ }
+ },
+ "vc_feature": {
+ "content_feature_dim": 768,
+ "hidden_dim": 512
+ }
+ }
+}
\ No newline at end of file
diff --git a/egs/vc/Noro/exp_config_clean.json b/egs/vc/Noro/exp_config_clean.json
new file mode 100644
index 00000000..e0dbd367
--- /dev/null
+++ b/egs/vc/Noro/exp_config_clean.json
@@ -0,0 +1,38 @@
+{
+ "base_config": "egs/vc/exp_config_base.json",
+ "dataset": [
+ "mls"
+ ],
+ // Specify the output root path to save model checkpoints and logs
+ "log_dir": "path/to/your/checkpoint/directory",
+ "train": {
+ // New trainer and Accelerator
+ "gradient_accumulation_step": 1,
+ "tracker": ["tensorboard"],
+ "max_epoch": 10,
+ "save_checkpoint_stride": [1000],
+ "keep_last": [20],
+ "run_eval": [true],
+ "dataloader": {
+ "num_worker": 64,
+ "pin_memory": true
+ },
+ "adam": {
+ "lr": 5e-5
+ },
+ "use_dynamic_batchsize": true,
+ "max_tokens": 3200000,
+ "max_sentences": 64,
+ "lr_warmup_steps": 5000,
+ "lr_scheduler": "cosine",
+ "num_train_steps": 800000
+ },
+ "trans_exp": {
+ "directory_list": [
+ "path/to/your/training_data_directory1",
+ "path/to/your/training_data_directory2",
+ "path/to/your/training_data_directory3"
+ ],
+ "use_ref_noise": false
+ }
+}
\ No newline at end of file
diff --git a/egs/vc/Noro/exp_config_noisy.json b/egs/vc/Noro/exp_config_noisy.json
new file mode 100644
index 00000000..7e5d7e75
--- /dev/null
+++ b/egs/vc/Noro/exp_config_noisy.json
@@ -0,0 +1,40 @@
+{
+ "base_config": "egs/vc/exp_config_base.json",
+ "dataset": [
+ "mls"
+ ],
+ // Specify the output root path to save model checkpoints and logs
+ "log_dir": "path/to/your/checkpoint/directory",
+ "train": {
+ // New trainer and Accelerator
+ "gradient_accumulation_step": 1,
+ "tracker": ["tensorboard"],
+ "max_epoch": 10,
+ "save_checkpoint_stride": [1000],
+ "keep_last": [20],
+ "run_eval": [true],
+ "dataloader": {
+ "num_worker": 64,
+ "pin_memory": true
+ },
+ "adam": {
+ "lr": 5e-5
+ },
+ "use_dynamic_batchsize": true,
+ "max_tokens": 3200000,
+ "max_sentences": 64,
+ "lr_warmup_steps": 5000,
+ "lr_scheduler": "cosine",
+ "num_train_steps": 800000
+ },
+ "trans_exp": {
+ "directory_list": [
+ "path/to/your/training_data_directory1",
+ "path/to/your/training_data_directory2",
+ "path/to/your/training_data_directory3"
+ ],
+ "use_ref_noise": true,
+ "noise_dir": "path/to/your/noise/train/directory",
+ "test_noise_dir": "path/to/your/noise/test/directory"
+ }
+ }
\ No newline at end of file
diff --git a/egs/vc/Noro/noro_inference.sh b/egs/vc/Noro/noro_inference.sh
new file mode 100644
index 00000000..5f10a9cd
--- /dev/null
+++ b/egs/vc/Noro/noro_inference.sh
@@ -0,0 +1,54 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Set the PYTHONPATH to the current directory
+export PYTHONPATH="./"
+
+######## Build Experiment Environment ###########
+# Get the current directory of the script
+exp_dir=$(cd `dirname $0`; pwd)
+# Get the parent directory of the experiment directory
+work_dir=$(dirname $(dirname $exp_dir))
+
+# Export environment variables for the working directory and Python path
+export WORK_DIR=$work_dir
+export PYTHONPATH=$work_dir
+export PYTHONIOENCODING=UTF-8
+
+# Build the monotonic alignment module
+cd $work_dir/modules/monotonic_align
+mkdir -p monotonic_align
+python setup.py build_ext --inplace
+cd $work_dir
+
+if [ -z "$exp_config" ]; then
+ exp_config="${exp_dir}/exp_config_base.json"
+fi
+
+echo "Experimental Configuration File: $exp_config"
+
+cuda_id=0
+
+# Set paths (modify these paths to your own)
+checkpoint_path="path/to/checkpoint/model.safetensors"
+output_dir="path/to/output/directory"
+source_path="path/to/source/audio.wav"
+reference_path="path/to/reference/audio.wav"
+
+echo "CUDA ID: $cuda_id"
+echo "Checkpoint Path: $checkpoint_path"
+echo "Output Directory: $output_dir"
+echo "Source Audio Path: $source_path"
+echo "Reference Audio Path: $reference_path"
+
+# Run the voice conversion inference script
+python "${work_dir}/models/vc/noro_inference.py" \
+ --config $exp_config \
+ --checkpoint_path $checkpoint_path \
+ --output_dir $output_dir \
+ --cuda_id ${cuda_id} \
+ --source_path $source_path \
+ --ref_path $reference_path
+
diff --git a/egs/vc/Noro/noro_train_clean.sh b/egs/vc/Noro/noro_train_clean.sh
new file mode 100644
index 00000000..e04c16a8
--- /dev/null
+++ b/egs/vc/Noro/noro_train_clean.sh
@@ -0,0 +1,55 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+######## Build Experiment Environment ###########
+exp_dir=$(cd `dirname $0`; pwd)
+work_dir=$(dirname $(dirname $exp_dir))
+
+export WORK_DIR=$work_dir
+export PYTHONPATH=$work_dir
+export PYTHONIOENCODING=UTF-8
+
+cd $work_dir/modules/monotonic_align
+mkdir -p monotonic_align
+python setup.py build_ext --inplace
+cd $work_dir
+
+if [ -z "$exp_config" ]; then
+ exp_config="${exp_dir}/exp_config_clean.json"
+fi
+echo "Experimental Configuration File: $exp_config"
+
+# Set experiment name
+exp_name="experiment_name"
+
+# Set CUDA ID
+if [ -z "$gpu" ]; then
+ gpu="0,1,2,3"
+fi
+
+######## Train Model ###########
+echo "Experimental Name: $exp_name"
+
+# Specify the checkpoint folder (modify this path to your own)
+checkpoint_path="path/to/checkpoint/noro_checkpoint"
+
+
+# If this is a new experiment, use the following command:
+# CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
+# "${work_dir}/bins/vc/train.py" \
+# --config $exp_config \
+# --exp_name $exp_name \
+# --log_level debug
+
+# To resume training or fine-tune from a checkpoint, use the following command:
+# Ensure the options --resume, --resume_type resume, and --checkpoint_path are set
+CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
+"${work_dir}/bins/vc/Noro/train.py" \
+ --config $exp_config \
+ --exp_name $exp_name \
+ --log_level debug \
+ --resume \
+ --resume_type resume \
+ --checkpoint_path $checkpoint_path
diff --git a/egs/vc/Noro/noro_train_noisy.sh b/egs/vc/Noro/noro_train_noisy.sh
new file mode 100644
index 00000000..0c910bf6
--- /dev/null
+++ b/egs/vc/Noro/noro_train_noisy.sh
@@ -0,0 +1,55 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+######## Build Experiment Environment ###########
+exp_dir=$(cd `dirname $0`; pwd)
+work_dir=$(dirname $(dirname $exp_dir))
+
+export WORK_DIR=$work_dir
+export PYTHONPATH=$work_dir
+export PYTHONIOENCODING=UTF-8
+
+cd $work_dir/modules/monotonic_align
+mkdir -p monotonic_align
+python setup.py build_ext --inplace
+cd $work_dir
+
+if [ -z "$exp_config" ]; then
+ exp_config="${exp_dir}/exp_config_noisy.json"
+fi
+echo "Experimental Configuration File: $exp_config"
+
+# Set experiment name
+exp_name="experiment_name"
+
+# Set CUDA ID
+if [ -z "$gpu" ]; then
+ gpu="0,1,2,3"
+fi
+
+######## Train Model ###########
+echo "Experimental Name: $exp_name"
+
+# Specify the checkpoint folder (modify this path to your own)
+checkpoint_path="path/to/checkpoint/noro_checkpoint"
+
+
+# If this is a new experiment, use the following command:
+# CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
+# "${work_dir}/bins/vc/train.py" \
+# --config $exp_config \
+# --exp_name $exp_name \
+# --log_level debug
+
+# To resume training or fine-tune from a checkpoint, use the following command:
+# Ensure the options --resume, --resume_type resume, and --checkpoint_path are set
+CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
+"${work_dir}/bins/vc/train.py" \
+ --config $exp_config \
+ --exp_name $exp_name \
+ --log_level debug \
+ --resume \
+ --resume_type resume \
+ --checkpoint_path $checkpoint_path
\ No newline at end of file
diff --git a/egs/vc/README.md b/egs/vc/README.md
new file mode 100644
index 00000000..13c10477
--- /dev/null
+++ b/egs/vc/README.md
@@ -0,0 +1,20 @@
+# Amphion Voice Conversion (VC) Recipe
+
+## Quick Start
+
+We provide a **[beginner recipe](Noro)** to demonstrate how to train a cutting-edge VC model. Specifically, it is the official implementation of the paper "NORO: A Noise-Robust One-Shot Voice Conversion System with Hidden Speaker Representation Capabilities".
+
+## Supported Model Architectures
+
+So far, Amphion supports a noise-robust VC model with the following architecture:
+
+
+<div align="center">
+<img src="../../imgs/vc/NoroVC.png" width="80%">
+</div>
+
+
+
+
+It has the following features:
+1. Noise-Robust Voice Conversion: Utilizes a dual-branch reference encoding module and noise-agnostic contrastive speaker loss to maintain high-quality voice conversion in noisy environments.
+2. One-shot Voice Conversion: Achieves timbre conversion using only one reference speech sample.
+3. Speaker Representation Learning: Explores the potential of the reference encoder as a self-supervised speaker encoder.
diff --git a/imgs/vc/NoroVC.png b/imgs/vc/NoroVC.png
new file mode 100644
index 00000000..3ad47750
Binary files /dev/null and b/imgs/vc/NoroVC.png differ
diff --git a/models/tts/naturalspeech2/ns2_trainer.py b/models/tts/naturalspeech2/ns2_trainer.py
index 63c4353e..8ede54e8 100644
--- a/models/tts/naturalspeech2/ns2_trainer.py
+++ b/models/tts/naturalspeech2/ns2_trainer.py
@@ -433,26 +433,18 @@ def _train_step(self, batch):
total_loss += dur_loss
train_losses["dur_loss"] = dur_loss
- x0 = self.model.module.code_to_latent(code)
- if self.cfg.model.diffusion.diffusion_type == "diffusion":
- # diff loss x0
- diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
- total_loss += diff_loss_x0
- train_losses["diff_loss_x0"] = diff_loss_x0
-
- # diff loss noise
- diff_loss_noise = diff_loss(
- diff_out["noise_pred"], diff_out["noise"], mask=mask
- )
- total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
- train_losses["diff_loss_noise"] = diff_loss_noise
-
- elif self.cfg.model.diffusion.diffusion_type == "flow":
- # diff flow matching loss
- flow_gt = diff_out["noise"] - x0
- diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask)
- total_loss += diff_loss_flow
- train_losses["diff_loss_flow"] = diff_loss_flow
+ x0 = self.model.module.code_to_latent(code)
+ # diff loss x0
+ diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
+ total_loss += diff_loss_x0
+ train_losses["diff_loss_x0"] = diff_loss_x0
+
+ # diff loss noise
+ diff_loss_noise = diff_loss(
+ diff_out["noise_pred"], diff_out["noise"], mask=mask
+ )
+ total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
+ train_losses["diff_loss_noise"] = diff_loss_noise
# diff loss ce
@@ -534,26 +526,17 @@ def _valid_step(self, batch):
valid_losses["dur_loss"] = dur_loss
x0 = self.model.module.code_to_latent(code)
- if self.cfg.model.diffusion.diffusion_type == "diffusion":
- # diff loss x0
- diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
- total_loss += diff_loss_x0
- valid_losses["diff_loss_x0"] = diff_loss_x0
-
- # diff loss noise
- diff_loss_noise = diff_loss(
- diff_out["noise_pred"], diff_out["noise"], mask=mask
- )
- total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
- valid_losses["diff_loss_noise"] = diff_loss_noise
-
- elif self.cfg.model.diffusion.diffusion_type == "flow":
- # diff flow matching loss
- flow_gt = diff_out["noise"] - x0
- diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask)
- total_loss += diff_loss_flow
- valid_losses["diff_loss_flow"] = diff_loss_flow
-
+ # diff loss x0
+ diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
+ total_loss += diff_loss_x0
+ valid_losses["diff_loss_x0"] = diff_loss_x0
+
+ # diff loss noise
+ diff_loss_noise = diff_loss(
+ diff_out["noise_pred"], diff_out["noise"], mask=mask
+ )
+ total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
+ valid_losses["diff_loss_noise"] = diff_loss_noise
# diff loss ce
# (nq, B, T); (nq, B, T, 1024)
diff --git a/models/vc/Noro/noro_base_trainer.py b/models/vc/Noro/noro_base_trainer.py
new file mode 100644
index 00000000..bc343ed1
--- /dev/null
+++ b/models/vc/Noro/noro_base_trainer.py
@@ -0,0 +1,280 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import torch
+import time
+from pathlib import Path
+import accelerate
+from accelerate.logging import get_logger
+from models.base.new_trainer import BaseTrainer
+
+
+class Noro_base_Trainer(BaseTrainer):
+ r"""The base trainer for all TTS models. It inherits from BaseTrainer and implements
+ ``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this
+ class, and implement ``_build_model``, ``_forward_step``.
+ """
+
+ def __init__(self, args=None, cfg=None):
+ self.args = args
+ self.cfg = cfg
+
+ cfg.exp_name = args.exp_name
+
+ # init with accelerate
+ self._init_accelerator()
+ self.accelerator.wait_for_everyone()
+
+ with self.accelerator.main_process_first():
+ self.logger = get_logger(args.exp_name, log_level="INFO")
+
+ # Log some info
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+ self.logger.info(f"Experiment name: {args.exp_name}")
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # init counts
+ self.batch_count: int = 0
+ self.step: int = 0
+ self.epoch: int = 0
+ self.max_epoch = (
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+ )
+ self.logger.info(
+ "Max epoch: {}".format(
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
+ )
+ )
+
+ # Check values
+ if self.accelerator.is_main_process:
+ self._check_basic_configs()
+ # Set runtime configs
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+ self.checkpoints_path = [
+ [] for _ in range(len(self.save_checkpoint_stride))
+ ]
+ self.keep_last = [
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
+ ]
+ self.run_eval = self.cfg.train.run_eval
+
+ # set random seed
+ with self.accelerator.main_process_first():
+ # start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ # self.logger.debug(
+ # f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ # )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # setup data_loader
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+ # # save phone table to exp dir. Should be done before building model due to loading phone table in model
+ # if cfg.preprocess.use_phone and cfg.preprocess.phone_extractor != "lexicon":
+ # self._save_phone_symbols_file_to_exp_path()
+
+ # setup model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ self.logger.debug(self.model)
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
+ self.logger.info(
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
+ )
+
+ # optimizer & scheduler
+ with self.accelerator.main_process_first():
+ self.logger.info("Building optimizer and scheduler...")
+ start = time.monotonic_ns()
+ self.optimizer = self._build_optimizer()
+ self.scheduler = self._build_scheduler()
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # create criterion
+ with self.accelerator.main_process_first():
+ self.logger.info("Building criterion...")
+ start = time.monotonic_ns()
+ self.criterion = self._build_criterion()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
+
+ # Resume or Finetune
+ with self.accelerator.main_process_first():
+ self._check_resume()
+
+ # accelerate prepare
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ self._accelerator_prepare()
+ end = time.monotonic_ns()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
+
+ # save config file path
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
+ self.device = self.accelerator.device
+
+ if cfg.preprocess.use_spkid and cfg.train.multi_speaker_training:
+ self.speakers = self._build_speaker_lut()
+ self.utt2spk_dict = self._build_utt2spk_dict()
+
+ # Only for TTS tasks
+ self.task_type = "TTS"
+ self.logger.info("Task type: {}".format(self.task_type))
+
+ def _check_resume(self):
+ # if args.resume:
+ if self.args.resume or (
+ self.cfg.model_type == "VALLE" and self.args.train_stage == 2
+ ):
+ if self.cfg.model_type == "VALLE" and self.args.train_stage == 2:
+ self.args.resume_type = "finetune"
+
+ self.logger.info("Resuming from checkpoint...")
+ self.ckpt_path = self._load_model(
+ self.checkpoint_dir, self.args.checkpoint_path, self.args.resume_type
+ )
+ self.checkpoints_path = json.load(
+ open(os.path.join(self.ckpt_path, "ckpts.json"), "r")
+ )
+
+
+ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
+ """Load model from checkpoint. If a folder is given, it will
+ load the latest checkpoint in checkpoint_dir. If a path is given
+ it will load the checkpoint specified by checkpoint_path.
+ **Only use this method after** ``accelerator.prepare()``.
+ """
+ if checkpoint_path is None or checkpoint_path == "":
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+ # Example path: epoch-0000_step-0017000_loss-1.972191; pick the checkpoint with the largest step
+ checkpoint_path = max(ls, key=lambda x: int(x.split("_")[-2].split("-")[-1]))
+
+ if self.accelerator.is_main_process:
+ self.logger.info("Load model from {}".format(checkpoint_path))
+ print("Load model from {}".format(checkpoint_path))
+
+ if resume_type == "resume":
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1])
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1])
+ if isinstance(self.model, dict):
+ for idx, sub_model in enumerate(self.model.keys()):
+ try:
+ if idx == 0:
+ ckpt_name = "pytorch_model.bin"
+ else:
+ ckpt_name = "pytorch_model_{}.bin".format(idx)
+
+ self.model[sub_model].load_state_dict(
+ torch.load(os.path.join(checkpoint_path, ckpt_name))
+ )
+ except:
+ if idx == 0:
+ ckpt_name = "model.safetensors"
+ else:
+ ckpt_name = "model_{}.safetensors".format(idx)
+
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model[sub_model]),
+ os.path.join(checkpoint_path, ckpt_name),
+ )
+
+ self.model[sub_model].cuda(self.accelerator.device)
+ else:
+ try:
+ self.model.load_state_dict(
+ torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
+ )
+ if self.accelerator.is_main_process:
+ self.logger.info("Loaded 'pytorch_model.bin' for resume")
+ except:
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model),
+ os.path.join(checkpoint_path, "model.safetensors"),
+ )
+ if self.accelerator.is_main_process:
+ self.logger.info("Loaded 'model.safetensors' for resume")
+ self.model.cuda(self.accelerator.device)
+ if self.accelerator.is_main_process:
+ self.logger.info("Load model weights SUCCESS!")
+ elif resume_type == "finetune":
+ if isinstance(self.model, dict):
+ for idx, sub_model in enumerate(self.model.keys()):
+ try:
+ if idx == 0:
+ ckpt_name = "pytorch_model.bin"
+ else:
+ ckpt_name = "pytorch_model_{}.bin".format(idx)
+
+ self.model[sub_model].load_state_dict(
+ torch.load(os.path.join(checkpoint_path, ckpt_name))
+ )
+ except:
+ if idx == 0:
+ ckpt_name = "model.safetensors"
+ else:
+ ckpt_name = "model_{}.safetensors".format(idx)
+
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model[sub_model]),
+ os.path.join(checkpoint_path, ckpt_name),
+ )
+
+ self.model[sub_model].cuda(self.accelerator.device)
+ else:
+ try:
+ self.model.load_state_dict(
+ torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
+ )
+ if self.accelerator.is_main_process:
+ self.logger.info("Loaded 'pytorch_model.bin' for finetune")
+ except:
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model),
+ os.path.join(checkpoint_path, "model.safetensors"),
+ )
+ if self.accelerator.is_main_process:
+ self.logger.info("Loaded 'model.safetensors' for finetune")
+ self.model.cuda(self.accelerator.device)
+ if self.accelerator.is_main_process:
+ self.logger.info("Load model weights for finetune SUCCESS!")
+ else:
+ raise ValueError("Unsupported resume type: {}".format(resume_type))
+ return checkpoint_path
+
+ def _check_basic_configs(self):
+ if self.cfg.train.gradient_accumulation_step <= 0:
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
+ self.logger.error(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+ self.accelerator.end_training()
+ raise ValueError(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
diff --git a/models/vc/Noro/noro_dataset.py b/models/vc/Noro/noro_dataset.py
new file mode 100644
index 00000000..8c858788
--- /dev/null
+++ b/models/vc/Noro/noro_dataset.py
@@ -0,0 +1,444 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import random
+import numpy as np
+import librosa
+import torch
+import torchaudio
+from torch.utils.data import Dataset
+from torch.nn.utils.rnn import pad_sequence
+from tqdm import tqdm
+from multiprocessing import Pool, Lock
+from utils.data_utils import *
+
+
+NUM_WORKERS = 64
+lock = Lock()  # Create a global lock for safe cache-file writes
+SAMPLE_RATE = 16000
+
+def get_metadata(file_path):
+ metadata = torchaudio.info(file_path)
+ return file_path, metadata.num_frames
+
+def get_speaker(file_path):
+ speaker_id = file_path.split(os.sep)[-3]
+ if 'mls' in file_path:
+ speaker = 'mls_' + speaker_id
+ else:
+ speaker = 'libri_' + speaker_id
+ return file_path, speaker
+
+def safe_write_to_file(data, file_path, mode='w'):
+ try:
+ with lock, open(file_path, mode, encoding='utf-8') as f:
+ json.dump(data, f)
+ f.flush()
+ os.fsync(f.fileno())
+ except IOError as e:
+ print(f"Error writing to {file_path}: {e}")
+
+
+class VCDataset(Dataset):
+ def __init__(self, args, TRAIN_MODE=True):
+ print(f"Initializing VCDataset")
+ if TRAIN_MODE:
+ directory_list = args.directory_list
+ else:
+ directory_list = args.test_directory_list
+ random.shuffle(directory_list)
+
+ self.use_ref_noise = args.use_ref_noise
+ print(f"use_ref_noise: {self.use_ref_noise}")
+
+ # number of workers
+ print(f"Using {NUM_WORKERS} workers")
+ self.directory_list = directory_list
+ print(f"Loading {len(directory_list)} directories: {directory_list}")
+
+ # Load metadata cache
+ # metadata_cache: {file_path: num_frames}
+ self.metadata_cache_path = '/mnt/data2/hehaorui/ckpt/rp_metadata_cache.json'
+ print(f"Loading metadata_cache from {self.metadata_cache_path}")
+ if os.path.exists(self.metadata_cache_path):
+ with open(self.metadata_cache_path, 'r', encoding='utf-8') as f:
+ self.metadata_cache = json.load(f)
+ print(f"Loaded {len(self.metadata_cache)} metadata_cache")
+ else:
+ print(f"metadata_cache not found, creating new")
+ self.metadata_cache = {}
+
+ # Load speaker cache
+ # speaker_cache: {file_path: speaker}
+ self.speaker_cache_path = '/mnt/data2/hehaorui/ckpt/rp_file2speaker.json'
+ print(f"Loading speaker_cache from {self.speaker_cache_path}")
+ if os.path.exists(self.speaker_cache_path):
+ with open(self.speaker_cache_path, 'r', encoding='utf-8') as f:
+ self.speaker_cache = json.load(f)
+ print(f"Loaded {len(self.speaker_cache)} speaker_cache")
+ else:
+ print(f"speaker_cache not found, creating new")
+ self.speaker_cache = {}
+
+ self.files = []
+ # Load all flac files
+ for directory in directory_list:
+ print(f"Loading {directory}")
+ files = self.get_flac_files(directory)
+ random.shuffle(files)
+ print(f"Loaded {len(files)} files")
+ self.files.extend(files)
+ del files
+ print(f"Now {len(self.files)} files")
+ self.meta_data_cache = self.process_files()
+ self.speaker_cache = self.process_speakers()
+ temp_cache_path = self.metadata_cache_path.replace('.json', f'_{directory.split("/")[-1]}.json')
+ if not os.path.exists(temp_cache_path):
+ safe_write_to_file(self.meta_data_cache, temp_cache_path)
+ print(f"Saved metadata cache to {temp_cache_path}")
+ temp_cache_path = self.speaker_cache_path.replace('.json', f'_{directory.split("/")[-1]}.json')
+ if not os.path.exists(temp_cache_path):
+ safe_write_to_file(self.speaker_cache, temp_cache_path)
+ print(f"Saved speaker cache to {temp_cache_path}")
+
+ print(f"Loaded {len(self.files)} files")
+ random.shuffle(self.files) # Shuffle the files.
+
+ self.filtered_files, self.all_num_frames, index2numframes, index2speakerid = self.filter_files()
+ print(f"Loaded {len(self.filtered_files)} files")
+
+ self.index2numframes = index2numframes
+ self.index2speaker = index2speakerid
+ self.speaker2id = self.create_speaker2id()
+ self.num_frame_sorted = np.array(sorted(self.all_num_frames))
+ self.num_frame_indices = np.array(
+ sorted(
+ range(len(self.all_num_frames)), key=lambda k: self.all_num_frames[k]
+ )
+ )
+ del self.meta_data_cache, self.speaker_cache
+
+ if self.use_ref_noise:
+ if TRAIN_MODE:
+ self.noise_filenames = self.get_all_flac(args.noise_dir)
+ else:
+ self.noise_filenames = self.get_all_flac(args.test_noise_dir)
+
+ def process_files(self):
+ print(f"Processing metadata...")
+ files_to_process = [file for file in self.files if file not in self.metadata_cache]
+ if files_to_process:
+ with Pool(processes=NUM_WORKERS) as pool:
+ results = list(tqdm(pool.imap_unordered(get_metadata, files_to_process), total=len(files_to_process)))
+ for file, num_frames in results:
+ self.metadata_cache[file] = num_frames
+ safe_write_to_file(self.metadata_cache, self.metadata_cache_path)
+ else:
+ print(f"Skipping processing metadata, loaded {len(self.metadata_cache)} files")
+ return self.metadata_cache
+
+ def process_speakers(self):
+ print(f"Processing speakers...")
+ files_to_process = [file for file in self.files if file not in self.speaker_cache]
+ if files_to_process:
+ with Pool(processes=NUM_WORKERS) as pool:
+ results = list(tqdm(pool.imap_unordered(get_speaker, files_to_process), total=len(files_to_process)))
+ for file, speaker in results:
+ self.speaker_cache[file] = speaker
+ safe_write_to_file(self.speaker_cache, self.speaker_cache_path)
+ else:
+ print(f"Skipping processing speakers, loaded {len(self.speaker_cache)} files")
+ return self.speaker_cache
+
+ def get_flac_files(self, directory):
+ flac_files = []
+ for root, dirs, files in os.walk(directory):
+ for file in files:
+ # flac or wav
+ if file.endswith(".flac") or file.endswith(".wav"):
+ flac_files.append(os.path.join(root, file))
+ return flac_files
+
+ def get_all_flac(self, directory):
+ directories = [os.path.join(directory, d) for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
+ if not directories:
+ return self.get_flac_files(directory)
+ with Pool(processes=NUM_WORKERS) as pool:
+ results = []
+ for result in tqdm(pool.imap_unordered(self.get_flac_files, directories), total=len(directories), desc="Processing"):
+ results.extend(result)
+ print(f"Found {len(results)} waveform files")
+ return results
+
+ def get_num_frames(self, index):
+ return self.index2numframes[index]
+
+ def filter_files(self):
+ # Filter files
+ metadata_cache = self.meta_data_cache
+ speaker_cache = self.speaker_cache
+ filtered_files = []
+ all_num_frames = []
+ index2numframes = {}
+ index2speaker = {}
+ for file in self.files:
+ num_frames = metadata_cache[file]
+ if SAMPLE_RATE * 3 <= num_frames <= SAMPLE_RATE * 30:
+ filtered_files.append(file)
+ all_num_frames.append(num_frames)
+ index2speaker[len(filtered_files) - 1] = speaker_cache[file]
+ index2numframes[len(filtered_files) - 1] = num_frames
+ return filtered_files, all_num_frames, index2numframes, index2speaker
+
+ def create_speaker2id(self):
+ speaker2id = {}
+ unique_id = 0
+ print(f"Creating speaker2id from {len(self.index2speaker)} utterences")
+ for _, speaker in tqdm(self.index2speaker.items()):
+ if speaker not in speaker2id:
+ speaker2id[speaker] = unique_id
+ unique_id += 1
+ print(f"Created speaker2id with {len(speaker2id)} speakers")
+ return speaker2id
+
+ def snr_mixer(self, clean, noise, snr):
+ # Normalizing to -25 dB FS
+ rmsclean = (clean**2).mean()**0.5
+ epsilon = 1e-10
+ rmsclean = max(rmsclean, epsilon)
+ scalarclean = 10 ** (-25 / 20) / rmsclean
+ clean = clean * scalarclean
+
+ rmsnoise = (noise**2).mean()**0.5
+ scalarnoise = 10 ** (-25 / 20) /rmsnoise
+ noise = noise * scalarnoise
+ rmsnoise = (noise**2).mean()**0.5
+
+ # Set the noise level for a given SNR
+ noisescalar = np.sqrt(rmsclean / (10**(snr/20)) / rmsnoise)
+ noisenewlevel = noise * noisescalar
+ noisyspeech = clean + noisenewlevel
+ noisyspeech_tensor = torch.tensor(noisyspeech, dtype=torch.float32)
+ return noisyspeech_tensor
+
+ def add_noise(self, clean):
+ # self.noise_filenames: list of noise files
+ random_idx = np.random.randint(0, np.size(self.noise_filenames))
+ noise, _ = librosa.load(self.noise_filenames[random_idx], sr=SAMPLE_RATE)
+ clean = clean.cpu().numpy()
+ if len(noise)>=len(clean):
+ noise = noise[0:len(clean)]
+ else:
+ while len(noise)<=len(clean):
+ random_idx = (random_idx + 1)%len(self.noise_filenames)
+ newnoise, fs = librosa.load(self.noise_filenames[random_idx], sr=SAMPLE_RATE)
+ noiseconcat = np.append(noise, np.zeros(int(fs * 0.2)))
+ noise = np.append(noiseconcat, newnoise)
+ noise = noise[0:len(clean)]
+ snr = random.uniform(0.0,20.0)
+ noisyspeech = self.snr_mixer(clean=clean, noise=noise, snr=snr)
+ del noise
+ return noisyspeech
+
+ def __len__(self):
+ # __getitem__ indexes into the filtered file list, so report its length.
+ return len(self.filtered_files)
+
+ def __getitem__(self, idx):
+ file_path = self.filtered_files[idx]
+ speech, _ = librosa.load(file_path, sr=SAMPLE_RATE)
+ if len(speech) > 30 * SAMPLE_RATE:
+ speech = speech[:30 * SAMPLE_RATE]
+ speech = torch.tensor(speech, dtype=torch.float32)
+ # inputs = self._get_reference_vc(speech, hop_length=320)
+ inputs = self._get_reference_vc(speech, hop_length=200)
+ speaker = self.index2speaker[idx]
+ speaker_id = self.speaker2id[speaker]
+ inputs["speaker_id"] = speaker_id
+ return inputs
+
+ def _get_reference_vc(self, speech, hop_length):
+ pad_size = 1600 - speech.shape[0] % 1600
+ speech = torch.nn.functional.pad(speech, (0, pad_size))
+
+ #hop_size
+ frame_nums = speech.shape[0] // hop_length
+ clip_frame_nums = np.random.randint(int(frame_nums * 0.25), int(frame_nums * 0.45))
+ clip_frame_nums += (frame_nums - clip_frame_nums) % 8
+ start_frames, end_frames = 0, clip_frame_nums
+
+ ref_speech = speech[start_frames * hop_length : end_frames * hop_length]
+ new_speech = torch.cat((speech[:start_frames * hop_length], speech[end_frames * hop_length:]), 0)
+
+ ref_mask = torch.ones(len(ref_speech) // hop_length)
+ mask = torch.ones(len(new_speech) // hop_length)
+ if not self.use_ref_noise:
+ # not use noise
+ return {"speech": new_speech, "ref_speech": ref_speech, "ref_mask": ref_mask, "mask": mask}
+ else:
+ # use reference noise
+ noisy_ref_speech = self.add_noise(ref_speech)
+ return {"speech": new_speech, "ref_speech": ref_speech, "noisy_ref_speech": noisy_ref_speech, "ref_mask": ref_mask, "mask": mask}
+
+
+class BaseCollator(object):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # mel: [b, T, n_mels]
+ # frame_pitch, frame_energy: [1, T]
+ # target_len: [1]
+ # spk_id: [b, 1]
+ # mask: [b, T, 1]
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "phone_len":
+ packed_batch_features["phone_len"] = torch.LongTensor(
+ [b["phone_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["phn_mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "audio_len":
+ packed_batch_features["audio_len"] = torch.LongTensor(
+ [b["audio_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
+ ]
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+ return packed_batch_features
+
+class VCCollator(BaseCollator):
+ def __init__(self, cfg):
+ BaseCollator.__init__(self, cfg)
+ #self.use_noise = cfg.trans_exp.use_noise
+
+ self.use_ref_noise = self.cfg.trans_exp.use_ref_noise
+ print(f"use_ref_noise: {self.use_ref_noise}")
+
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # Function to handle tensor copying
+ def process_tensor(data, dtype=torch.float32):
+ if isinstance(data, torch.Tensor):
+ return data.clone().detach()
+ else:
+ return torch.tensor(data, dtype=dtype)
+
+ # Process 'speech' data
+ speeches = [process_tensor(b['speech']) for b in batch]
+ packed_batch_features['speech'] = pad_sequence(speeches, batch_first=True, padding_value=0)
+
+ # Process 'ref_speech' data
+ ref_speeches = [process_tensor(b['ref_speech']) for b in batch]
+ packed_batch_features['ref_speech'] = pad_sequence(ref_speeches, batch_first=True, padding_value=0)
+
+ # Process 'mask' data
+ masks = [process_tensor(b['mask']) for b in batch]
+ packed_batch_features['mask'] = pad_sequence(masks, batch_first=True, padding_value=0)
+
+ # Process 'ref_mask' data
+ ref_masks = [process_tensor(b['ref_mask']) for b in batch]
+ packed_batch_features['ref_mask'] = pad_sequence(ref_masks, batch_first=True, padding_value=0)
+
+ # Process 'speaker_id' data
+ speaker_ids = [process_tensor(b['speaker_id'], dtype=torch.int64) for b in batch]
+ packed_batch_features['speaker_id'] = torch.stack(speaker_ids, dim=0)
+ if self.use_ref_noise:
+ # Process 'noisy_ref_speech' data
+ noisy_ref_speeches = [process_tensor(b['noisy_ref_speech']) for b in batch]
+ packed_batch_features['noisy_ref_speech'] = pad_sequence(noisy_ref_speeches, batch_first=True, padding_value=0)
+ return packed_batch_features
+
+
+
+def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+ if len(batch) == 0:
+ return 0
+ if len(batch) == max_sentences:
+ return 1
+ if num_tokens > max_tokens:
+ return 1
+ return 0
+
+
+def batch_by_size(
+ indices,
+ num_tokens_fn,
+ max_tokens=None,
+ max_sentences=None,
+ required_batch_size_multiple=1,
+):
+ """
+ Yield mini-batches of indices bucketed by size. Batches may contain
+ sequences of different lengths.
+
+ Args:
+ indices (List[int]): ordered list of dataset indices
+ num_tokens_fn (callable): function that returns the number of tokens at
+ a given index
+ max_tokens (int, optional): max number of tokens in each batch
+ (default: None).
+ max_sentences (int, optional): max number of sentences in each
+ batch (default: None).
+ required_batch_size_multiple (int, optional): require batch size to
+ be a multiple of N (default: 1).
+ """
+ bsz_mult = required_batch_size_multiple
+
+ sample_len = 0
+ sample_lens = []
+ batch = []
+ batches = []
+ for i in range(len(indices)):
+ idx = indices[i]
+ num_tokens = num_tokens_fn(idx)
+ sample_lens.append(num_tokens)
+ sample_len = max(sample_len, num_tokens)
+
+ assert (
+ sample_len <= max_tokens
+ ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format(
+ idx, sample_len, max_tokens
+ )
+ num_tokens = (len(batch) + 1) * sample_len
+
+ if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+ mod_len = max(
+ bsz_mult * (len(batch) // bsz_mult),
+ len(batch) % bsz_mult,
+ )
+ batches.append(batch[:mod_len])
+ batch = batch[mod_len:]
+ sample_lens = sample_lens[mod_len:]
+ sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+ batch.append(idx)
+ if len(batch) > 0:
+ batches.append(batch)
+ return batches
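+
+
+# Illustrative usage (not executed anywhere in this file): bucket utterances of
+# known frame counts into batches capped by the `max_tokens` / `max_sentences`
+# values from egs/vc/Noro/exp_config_clean.json. The lengths below are made up
+# and only mirror the 3-30 s range kept by VCDataset.filter_files().
+#
+#   lengths = {i: random.randint(3 * 16000, 30 * 16000) for i in range(1000)}
+#   ordered = sorted(lengths, key=lengths.get)  # shortest first, like num_frame_indices
+#   batches = batch_by_size(
+#       ordered,
+#       num_tokens_fn=lambda i: lengths[i],
+#       max_tokens=3_200_000,
+#       max_sentences=64,
+#   )
+#   # `batches` is a list of index lists, each respecting both caps.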
diff --git a/models/vc/Noro/noro_inference.py b/models/vc/Noro/noro_inference.py
new file mode 100644
index 00000000..c4eab1e2
--- /dev/null
+++ b/models/vc/Noro/noro_inference.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import torch
+import numpy as np
+import librosa
+from safetensors.torch import load_model
+import os
+from utils.util import load_config
+from models.vc.Noro.noro_trainer import NoroTrainer
+from models.vc.Noro.noro_model import Noro_VCmodel
+from processors.content_extractor import HubertExtractor
+from utils.mel import mel_spectrogram_torch
+from utils.f0 import get_f0_features_using_dio, interpolate
+
+def build_trainer(args, cfg):
+ supported_trainer = {
+ "VC": NoroTrainer,
+ }
+ trainer_class = supported_trainer[cfg.model_type]
+ trainer = trainer_class(args, cfg)
+ return trainer
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--config",
+ default="config.json",
+ help="JSON file for configurations.",
+ required=True,
+ )
+ parser.add_argument(
+ "--checkpoint_path",
+ type=str,
+ help="Checkpoint for resume training or fine-tuning.",
+ required=True,
+ )
+ parser.add_argument(
+ "--output_dir",
+ help="Output path",
+ required=True,
+ )
+ parser.add_argument(
+ "--ref_path",
+ type=str,
+ help="Reference voice path",
+ )
+ parser.add_argument(
+ "--source_path",
+ type=str,
+ help="Source voice path",
+ )
+ parser.add_argument(
+ "--cuda_id",
+ type=int,
+ default=0,
+ help="CUDA id for training."
+ )
+
+ parser.add_argument("--local_rank", default=-1, type=int)
+ args = parser.parse_args()
+ cfg = load_config(args.config)
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+ cuda_id = args.cuda_id
+ args.local_rank = torch.device(f"cuda:{cuda_id}")
+ print("Local rank:", args.local_rank)
+
+ args.content_extractor = "mhubert"
+
+ with torch.cuda.device(args.local_rank):
+ torch.cuda.empty_cache()
+ ckpt_path = args.checkpoint_path
+
+ w2v = HubertExtractor(cfg)
+ w2v = w2v.to(device=args.local_rank)
+ w2v.eval()
+
+ model = Noro_VCmodel(cfg=cfg.model)
+ print("Loading model")
+
+ load_model(model, ckpt_path)
+ print("Model loaded")
+ model.cuda(args.local_rank)
+ model.eval()
+
+ wav_path = args.source_path
+ ref_wav_path = args.ref_path
+
+ wav, _ = librosa.load(wav_path, sr=16000)
+ wav = np.pad(wav, (0, 1600 - len(wav) % 1600))
+ audio = torch.from_numpy(wav).to(args.local_rank)
+ audio = audio[None, :]
+
+ ref_wav, _ = librosa.load(ref_wav_path, sr=16000)
+ ref_wav = np.pad(ref_wav, (0, 200 - len(ref_wav) % 200))
+ ref_audio = torch.from_numpy(ref_wav).to(args.local_rank)
+ ref_audio = ref_audio[None, :]
+
+ with torch.no_grad():
+ ref_mel = mel_spectrogram_torch(ref_audio, cfg)
+ ref_mel = ref_mel.transpose(1, 2).to(device=args.local_rank)
+ ref_mask = torch.ones(ref_mel.shape[0], ref_mel.shape[1]).to(args.local_rank).bool()
+
+ _, content_feature = w2v.extract_content_features(audio)
+ content_feature = content_feature.to(device=args.local_rank)
+
+ wav = audio.cpu().numpy()
+ pitch_raw = get_f0_features_using_dio(wav, cfg)
+ pitch_raw, _ = interpolate(pitch_raw)
+ frame_num = len(wav) // cfg.preprocess.hop_size
+ pitch_raw = torch.from_numpy(pitch_raw[:frame_num]).float()
+ pitch = (pitch_raw - pitch_raw.mean(dim=1, keepdim=True)) / (
+ pitch_raw.std(dim=1, keepdim=True) + 1e-6
+ )
+
+ x0 = model.inference(
+ content_feature=content_feature,
+ pitch=pitch,
+ x_ref=ref_mel,
+ x_ref_mask=ref_mask,
+ inference_steps=200,
+ sigma=1.2,
+ )  # typical ranges: inference_steps 150-300, sigma 0.95-1.5
+
+ recon_path = f"{args.output_dir}/recon_mel.npy"
+ np.save(recon_path, x0.transpose(1, 2).detach().cpu().numpy())
+ print(f"Mel spectrogram saved to: {recon_path}")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/models/vc/Noro/noro_loss.py b/models/vc/Noro/noro_loss.py
new file mode 100644
index 00000000..257c47e8
--- /dev/null
+++ b/models/vc/Noro/noro_loss.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def cross_entropy_loss(preds, targets, reduction='none'):
+ log_softmax = nn.LogSoftmax(dim=-1)
+ loss = (-targets * log_softmax(preds)).sum(1)
+ if reduction == "none":
+ return loss
+ elif reduction == "mean":
+ return loss.mean()
+
+class ConstractiveSpeakerLoss(nn.Module):
+ def __init__(self, temperature=1.):
+ super(ConstractiveSpeakerLoss, self).__init__()
+ self.temperature = temperature
+
+ def forward(self, x, speaker_ids):
+ # x : B, H
+ # speaker_ids: (B,), e.g. [3, 4, 3]; utterances from the same speaker share an id
+ speaker_ids = speaker_ids.reshape(-1)
+ # (B, B) mask with 1 where two utterances come from the same speaker
+ speaker_ids_expand = (speaker_ids.view(-1, 1) == speaker_ids).float()
+ x_t = x.transpose(0,1) # B, C --> C,B
+ logits = (x @ x_t) / self.temperature # B, H * H, B --> B, B
+ targets = F.softmax(speaker_ids_expand / self.temperature, dim=-1)
+ loss = cross_entropy_loss(logits, targets, reduction='none')
+ return loss.mean()
+
+def diff_loss(pred, target, mask, loss_type="l1"):
+ # pred: (B, T, d)
+ # target: (B, T, d)
+ # mask: (B, T)
+ if loss_type == "l1":
+ loss = F.l1_loss(pred, target, reduction="none").float() * (
+ mask.to(pred.dtype).unsqueeze(-1)
+ )
+ elif loss_type == "l2":
+ loss = F.mse_loss(pred, target, reduction="none").float() * (
+ mask.to(pred.dtype).unsqueeze(-1)
+ )
+ else:
+ raise NotImplementedError()
+ loss = (torch.mean(loss, dim=-1)).sum() / (mask.to(pred.dtype).sum())
+ return loss
\ No newline at end of file
diff --git a/models/vc/Noro/noro_model.py b/models/vc/Noro/noro_model.py
new file mode 100644
index 00000000..0b0bbbf4
--- /dev/null
+++ b/models/vc/Noro/noro_model.py
@@ -0,0 +1,1347 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+import math
+import json5
+from librosa.filters import mel as librosa_mel_fn
+from einops.layers.torch import Rearrange
+
+sr = 16000
+MAX_WAV_VALUE = 32768.0
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+
+mel_basis = {}
+hann_window = {}
+init_mel_and_hann = False
+
+
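+# Expects y of shape (B, T) with samples in [-1, 1]; returns a log-compressed
+# mel spectrogram of shape (B, num_mels, n_frames). The mel filterbank and Hann
+# window are built lazily on the first call and cached, keyed by fmax and device.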
+def mel_spectrogram(
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
+ if torch.min(y) < -1.0:
+ print("min value is ", torch.min(y))
+ if torch.max(y) > 1.0:
+ print("max value is ", torch.max(y))
+
+ global mel_basis, hann_window, init_mel_and_hann
+ if not init_mel_and_hann:
+ mel = librosa_mel_fn(
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+ )
+ mel_basis[str(fmax) + "_" + str(y.device)] = (
+ torch.from_numpy(mel).float().to(y.device)
+ )
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+ init_mel_and_hann = True
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+ mode="reflect",
+ )
+ y = y.squeeze(1)
+
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
+ spec = torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window[str(y.device)],
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ spec = torch.view_as_real(spec)
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
+
+
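+# Diffusion wrapper around the estimator network. Training samples a random time
+# step t, forms
+#   x_t = x_0 * exp(-0.5 * cum_beta(t) / sigma^2) + sqrt(var(t)) * noise_factor * z
+# with cum_beta(t) = beta_min * t + 0.5 * (beta_max - beta_min) * t^2 and
+# var(t) = sigma^2 * (1 - exp(-cum_beta(t) / sigma^2)), and asks the estimator to
+# predict x_0. Inference integrates the reverse process from t ~ 1 down to ~ 0
+# with an Euler or midpoint solver (cfg.ode_solve_method), starting from noise z.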
+class Diffusion(nn.Module):
+ def __init__(self, cfg, diff_model):
+ super().__init__()
+
+ self.cfg = cfg
+ self.diff_estimator = diff_model
+ self.beta_min = cfg.beta_min
+ self.beta_max = cfg.beta_max
+ self.sigma = cfg.sigma
+ self.noise_factor = cfg.noise_factor
+
+ def forward(self, x, condition_embedding, x_mask, reference_embedding, offset=1e-5):
+ diffusion_step = torch.rand(
+ x.shape[0], dtype=x.dtype, device=x.device, requires_grad=False
+ )
+ diffusion_step = torch.clamp(diffusion_step, offset, 1.0 - offset)
+ xt, z = self.forward_diffusion(x0=x, diffusion_step=diffusion_step)
+
+ cum_beta = self.get_cum_beta(diffusion_step.unsqueeze(-1).unsqueeze(-1))
+ x0_pred = self.diff_estimator(
+ xt, condition_embedding, x_mask, reference_embedding, diffusion_step
+ )
+ mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / (self.sigma**2))
+ variance = (self.sigma**2) * (1.0 - torch.exp(-cum_beta / (self.sigma**2)))
+ noise_pred = (xt - mean_pred) / (torch.sqrt(variance) * self.noise_factor)
+ noise = z
+ diff_out = {"x0_pred": x0_pred, "noise_pred": noise_pred, "noise": noise}
+ return diff_out
+
+ @torch.no_grad()
+ def get_cum_beta(self, time_step):
+ return self.beta_min * time_step + 0.5 * (self.beta_max - self.beta_min) * (
+ time_step**2
+ )
+
+ @torch.no_grad()
+ def get_beta_t(self, time_step):
+ return self.beta_min + (self.beta_max - self.beta_min) * time_step
+
+ @torch.no_grad()
+ def forward_diffusion(self, x0, diffusion_step):
+ time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
+ cum_beta = self.get_cum_beta(time_step)
+ mean = x0 * torch.exp(-0.5 * cum_beta / (self.sigma**2))
+ variance = (self.sigma**2) * (1 - torch.exp(-cum_beta / (self.sigma**2)))
+ z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
+ xt = mean + z * torch.sqrt(variance) * self.noise_factor
+ return xt, z
+
+ @torch.no_grad()
+ def cal_dxt(
+ self, xt, condition_embedding, x_mask, reference_embedding, diffusion_step, h
+ ):
+ time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
+ cum_beta = self.get_cum_beta(time_step=time_step)
+ beta_t = self.get_beta_t(time_step=time_step)
+ x0_pred = self.diff_estimator(
+ xt, condition_embedding, x_mask, reference_embedding, diffusion_step
+ )
+ mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / (self.sigma**2))
+ noise_pred = xt - mean_pred
+ variance = (self.sigma**2) * (1.0 - torch.exp(-cum_beta / (self.sigma**2)))
+ logp = -noise_pred / (variance + 1e-8)
+ dxt = -0.5 * h * beta_t * (logp + xt / (self.sigma**2))
+ return dxt
+
+ @torch.no_grad()
+ def reverse_diffusion(
+ self, z, condition_embedding, x_mask, reference_embedding, n_timesteps
+ ):
+ h = 1.0 / max(n_timesteps, 1)
+ xt = z
+ for i in range(n_timesteps):
+ t = (1.0 - (i + 0.5) * h) * torch.ones(
+ z.shape[0], dtype=z.dtype, device=z.device
+ )
+ dxt = self.cal_dxt(
+ xt,
+ condition_embedding,
+ x_mask,
+ reference_embedding,
+ diffusion_step=t,
+ h=h,
+ )
+ xt_ = xt - dxt
+ if self.cfg.ode_solve_method == "midpoint":
+ x_mid = 0.5 * (xt_ + xt)
+ dxt = self.cal_dxt(
+ x_mid,
+ condition_embedding,
+ x_mask,
+ reference_embedding,
+ diffusion_step=t + 0.5 * h,
+ h=h,
+ )
+ xt = xt - dxt
+ elif self.cfg.ode_solve_method == "euler":
+ xt = xt_
+ return xt
+
+ @torch.no_grad()
+ def reverse_diffusion_from_t(
+ self, z, condition_embedding, x_mask, reference_embedding, n_timesteps, t_start
+ ):
+ h = t_start / max(n_timesteps, 1)
+ xt = z
+ for i in range(n_timesteps):
+ t = (t_start - (i + 0.5) * h) * torch.ones(
+ z.shape[0], dtype=z.dtype, device=z.device
+ )
+ dxt = self.cal_dxt(
+ xt,
+ condition_embedding,
+ x_mask,
+ reference_embedding,
+ diffusion_step=t,
+ h=h,
+ )
+ xt_ = xt - dxt
+ if self.cfg.ode_solve_method == "midpoint":
+ x_mid = 0.5 * (xt_ + xt)
+ dxt = self.cal_dxt(
+ x_mid,
+ condition_embedding,
+ x_mask,
+ reference_embedding,
+ diffusion_step=t + 0.5 * h,
+ h=h,
+ )
+ xt = xt - dxt
+ elif self.cfg.ode_solve_method == "euler":
+ xt = xt_
+ return xt
+
+
+class SinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :]
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+
+class Linear2(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+ self.linear_1 = nn.Linear(dim, dim * 2)
+ self.linear_2 = nn.Linear(dim * 2, dim)
+ self.linear_1.weight.data.normal_(0.0, 0.02)
+ self.linear_2.weight.data.normal_(0.0, 0.02)
+
+ def forward(self, x):
+ x = self.linear_1(x)
+ x = F.silu(x)
+ x = self.linear_2(x)
+ return x
+
+
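+# Style-adaptive LayerNorm: a plain LayerNorm whose scale (gamma) and shift
+# (beta) are predicted from the time-averaged condition (e.g. the speaker
+# embedding), so normalization is conditioned on the reference style.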
+class StyleAdaptiveLayerNorm(nn.Module):
+ def __init__(self, normalized_shape, eps=1e-5):
+ super().__init__()
+ self.in_dim = normalized_shape
+ self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
+ self.style = nn.Linear(self.in_dim, self.in_dim * 2)
+ self.style.bias.data[: self.in_dim] = 1
+ self.style.bias.data[self.in_dim :] = 0
+
+ def forward(self, x, condition):
+ # x: (B, T, d); condition: (B, T, d)
+
+ style = self.style(torch.mean(condition, dim=1, keepdim=True))
+
+ gamma, beta = style.chunk(2, -1)
+
+ out = self.norm(x)
+
+ out = gamma * out + beta
+ return out
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, dropout, max_len=5000):
+ super().__init__()
+
+ self.dropout = dropout
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+ )
+ pe = torch.zeros(max_len, 1, d_model)
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ x = x + self.pe[: x.size(0)]
+ return F.dropout(x, self.dropout, training=self.training)
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(
+ self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
+ ):
+ super().__init__()
+
+ self.encoder_hidden = encoder_hidden
+ self.conv_filter_size = conv_filter_size
+ self.conv_kernel_size = conv_kernel_size
+ self.encoder_dropout = encoder_dropout
+
+ self.ffn_1 = nn.Conv1d(
+ self.encoder_hidden,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ padding=self.conv_kernel_size // 2,
+ )
+ self.ffn_1.weight.data.normal_(0.0, 0.02)
+ self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
+ self.ffn_2.weight.data.normal_(0.0, 0.02)
+
+ def forward(self, x):
+ # x: (B, T, d)
+ x = self.ffn_1(x.permute(0, 2, 1)).permute(
+ 0, 2, 1
+ ) # (B, T, d) -> (B, d, T) -> (B, T, d)
+ x = F.silu(x)
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+
+class TransformerFFNLayerOld(nn.Module):
+ def __init__(
+ self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
+ ):
+ super().__init__()
+
+ self.encoder_hidden = encoder_hidden
+ self.conv_filter_size = conv_filter_size
+ self.conv_kernel_size = conv_kernel_size
+ self.encoder_dropout = encoder_dropout
+
+ self.ffn_1 = nn.Linear(self.encoder_hidden, self.conv_filter_size)
+ self.ffn_1.weight.data.normal_(0.0, 0.02)
+ self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
+ self.ffn_2.weight.data.normal_(0.0, 0.02)
+
+ def forward(self, x):
+ x = self.ffn_1(x)
+ x = F.silu(x)
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ encoder_hidden,
+ encoder_head,
+ conv_filter_size,
+ conv_kernel_size,
+ encoder_dropout,
+ use_cln,
+ use_skip_connection,
+ use_new_ffn,
+ add_diff_step,
+ ):
+ super().__init__()
+ self.encoder_hidden = encoder_hidden
+ self.encoder_head = encoder_head
+ self.conv_filter_size = conv_filter_size
+ self.conv_kernel_size = conv_kernel_size
+ self.encoder_dropout = encoder_dropout
+ self.use_cln = use_cln
+ self.use_skip_connection = use_skip_connection
+ self.use_new_ffn = use_new_ffn
+ self.add_diff_step = add_diff_step
+
+ if not self.use_cln:
+ self.ln_1 = nn.LayerNorm(self.encoder_hidden)
+ self.ln_2 = nn.LayerNorm(self.encoder_hidden)
+ else:
+ self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
+ self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
+
+ self.self_attn = nn.MultiheadAttention(
+ self.encoder_hidden, self.encoder_head, batch_first=True
+ )
+
+ if self.use_new_ffn:
+ self.ffn = TransformerFFNLayer(
+ self.encoder_hidden,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ self.encoder_dropout,
+ )
+ else:
+ self.ffn = TransformerFFNLayerOld(
+ self.encoder_hidden,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ self.encoder_dropout,
+ )
+
+ if self.use_skip_connection:
+ self.skip_linear = nn.Linear(self.encoder_hidden * 2, self.encoder_hidden)
+ self.skip_linear.weight.data.normal_(0.0, 0.02)
+ self.skip_layernorm = nn.LayerNorm(self.encoder_hidden)
+
+ if self.add_diff_step:
+ self.diff_step_emb = SinusoidalPosEmb(dim=self.encoder_hidden)
+ # self.diff_step_projection = nn.linear(self.encoder_hidden, self.encoder_hidden)
+ # self.encoder_hidden.weight.data.normal_(0.0, 0.02)
+ self.diff_step_projection = Linear2(self.encoder_hidden)
+
+ def forward(
+ self, x, key_padding_mask, condition=None, skip_res=None, diffusion_step=None
+ ):
+ # x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d); skip_res: (B, T, d); diffusion_step: (B,)
+
+ if self.use_skip_connection and skip_res != None:
+ x = torch.cat([x, skip_res], dim=-1) # (B, T, 2*d)
+ x = self.skip_linear(x)
+ x = self.skip_layernorm(x)
+
+ if self.add_diff_step and diffusion_step != None:
+ diff_step_embedding = self.diff_step_emb(diffusion_step)
+ diff_step_embedding = self.diff_step_projection(diff_step_embedding)
+ x = x + diff_step_embedding.unsqueeze(1)
+
+ residual = x
+
+ # pre norm
+ if self.use_cln:
+ x = self.ln_1(x, condition)
+ else:
+ x = self.ln_1(x)
+
+ # self attention
+ if key_padding_mask != None:
+ key_padding_mask_input = ~(key_padding_mask.bool())
+ else:
+ key_padding_mask_input = None
+ x, _ = self.self_attn(
+ query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
+ )
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
+
+ x = residual + x
+
+ # pre norm
+ residual = x
+ if self.use_cln:
+ x = self.ln_2(x, condition)
+ else:
+ x = self.ln_2(x)
+
+ # ffn
+ x = self.ffn(x)
+
+ x = residual + x
+ return x
+
+
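+# Stack of TransformerEncoderLayer blocks. When use_skip_connection is set, the
+# stack is arranged U-Net style: outputs of the first half of the layers are
+# stored and fed as skip inputs to the second half (see forward below).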
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self,
+ enc_emb_tokens=None,
+ encoder_layer=None,
+ encoder_hidden=None,
+ encoder_head=None,
+ conv_filter_size=None,
+ conv_kernel_size=None,
+ encoder_dropout=None,
+ use_cln=None,
+ use_skip_connection=None,
+ use_new_ffn=None,
+ add_diff_step=None,
+ cfg=None,
+ ):
+ super().__init__()
+
+ self.encoder_layer = (
+ encoder_layer if encoder_layer is not None else cfg.encoder_layer
+ )
+ self.encoder_hidden = (
+ encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
+ )
+ self.encoder_head = (
+ encoder_head if encoder_head is not None else cfg.encoder_head
+ )
+ self.conv_filter_size = (
+ conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
+ )
+ self.conv_kernel_size = (
+ conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
+ )
+ self.encoder_dropout = (
+ encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
+ )
+ self.use_cln = use_cln if use_cln is not None else cfg.use_cln
+ self.use_skip_connection = (
+ use_skip_connection
+ if use_skip_connection is not None
+ else cfg.use_skip_connection
+ )
+ self.add_diff_step = (
+ add_diff_step if add_diff_step is not None else cfg.add_diff_step
+ )
+ self.use_new_ffn = use_new_ffn if use_new_ffn is not None else cfg.use_new_ffn
+
+ if enc_emb_tokens != None:
+ self.use_enc_emb = True
+ self.enc_emb_tokens = enc_emb_tokens
+ else:
+ self.use_enc_emb = False
+
+ self.position_emb = PositionalEncoding(
+ self.encoder_hidden, self.encoder_dropout
+ )
+
+ self.layers = nn.ModuleList([])
+ if self.use_skip_connection:
+ self.layers.extend(
+ [
+ TransformerEncoderLayer(
+ self.encoder_hidden,
+ self.encoder_head,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ self.encoder_dropout,
+ self.use_cln,
+ use_skip_connection=False,
+ use_new_ffn=self.use_new_ffn,
+ add_diff_step=self.add_diff_step,
+ )
+ for i in range(
+ (self.encoder_layer + 1) // 2
+ ) # for example: 12 -> 6; 13 -> 7
+ ]
+ )
+ self.layers.extend(
+ [
+ TransformerEncoderLayer(
+ self.encoder_hidden,
+ self.encoder_head,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ self.encoder_dropout,
+ self.use_cln,
+ use_skip_connection=True,
+ use_new_ffn=self.use_new_ffn,
+ add_diff_step=self.add_diff_step,
+ )
+ for i in range(
+ self.encoder_layer - (self.encoder_layer + 1) // 2
+ ) # 12 -> 6; 13 -> 6
+ ]
+ )
+ else:
+ self.layers.extend(
+ [
+ TransformerEncoderLayer(
+ self.encoder_hidden,
+ self.encoder_head,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ self.encoder_dropout,
+ self.use_cln,
+ use_new_ffn=self.use_new_ffn,
+ add_diff_step=self.add_diff_step,
+ use_skip_connection=False,
+ )
+ for i in range(self.encoder_layer)
+ ]
+ )
+
+ if self.use_cln:
+ self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
+ else:
+ self.last_ln = nn.LayerNorm(self.encoder_hidden)
+
+ if self.add_diff_step:
+ self.diff_step_emb = SinusoidalPosEmb(dim=self.encoder_hidden)
+ # self.diff_step_projection = nn.linear(self.encoder_hidden, self.encoder_hidden)
+ # self.encoder_hidden.weight.data.normal_(0.0, 0.02)
+ self.diff_step_projection = Linear2(self.encoder_hidden)
+
+ def forward(self, x, key_padding_mask, condition=None, diffusion_step=None):
+ if len(x.shape) == 2 and self.use_enc_emb:
+ x = self.enc_emb_tokens(x)
+ x = self.position_emb(x)
+ else:
+ x = self.position_emb(x) # (B, T, d)
+
+ if self.add_diff_step and diffusion_step != None:
+ diff_step_embedding = self.diff_step_emb(diffusion_step)
+ diff_step_embedding = self.diff_step_projection(diff_step_embedding)
+ x = x + diff_step_embedding.unsqueeze(1)
+
+ if self.use_skip_connection:
+ skip_res_list = []
+ # down
+ for layer in self.layers[: self.encoder_layer // 2]:
+ x = layer(x, key_padding_mask, condition)
+ res = x
+ skip_res_list.append(res)
+ # middle
+ for layer in self.layers[
+ self.encoder_layer // 2 : (self.encoder_layer + 1) // 2
+ ]:
+ x = layer(x, key_padding_mask, condition)
+ # up
+ for layer in self.layers[(self.encoder_layer + 1) // 2 :]:
+ skip_res = skip_res_list.pop()
+ x = layer(x, key_padding_mask, condition, skip_res)
+ else:
+ for layer in self.layers:
+ x = layer(x, key_padding_mask, condition)
+
+ if self.use_cln:
+ x = self.last_ln(x, condition)
+ else:
+ x = self.last_ln(x)
+
+ return x
+
+
+class DiffTransformer(nn.Module):
+ def __init__(
+ self,
+ encoder_layer=None,
+ encoder_hidden=None,
+ encoder_head=None,
+ conv_filter_size=None,
+ conv_kernel_size=None,
+ encoder_dropout=None,
+ use_cln=None,
+ use_skip_connection=None,
+ use_new_ffn=None,
+ add_diff_step=None,
+ cat_diff_step=None,
+ in_dim=None,
+ out_dim=None,
+ cond_dim=None,
+ cfg=None,
+ ):
+ super().__init__()
+
+ self.encoder_layer = (
+ encoder_layer if encoder_layer is not None else cfg.encoder_layer
+ )
+ self.encoder_hidden = (
+ encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
+ )
+ self.encoder_head = (
+ encoder_head if encoder_head is not None else cfg.encoder_head
+ )
+ self.conv_filter_size = (
+ conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
+ )
+ self.conv_kernel_size = (
+ conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
+ )
+ self.encoder_dropout = (
+ encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
+ )
+ self.use_cln = use_cln if use_cln is not None else cfg.use_cln
+ self.use_skip_connection = (
+ use_skip_connection
+ if use_skip_connection is not None
+ else cfg.use_skip_connection
+ )
+ self.use_new_ffn = use_new_ffn if use_new_ffn is not None else cfg.use_new_ffn
+ self.add_diff_step = (
+ add_diff_step if add_diff_step is not None else cfg.add_diff_step
+ )
+ self.cat_diff_step = (
+ cat_diff_step if cat_diff_step is not None else cfg.cat_diff_step
+ )
+ self.in_dim = in_dim if in_dim is not None else cfg.in_dim
+ self.out_dim = out_dim if out_dim is not None else cfg.out_dim
+ self.cond_dim = cond_dim if cond_dim is not None else cfg.cond_dim
+
+ if self.in_dim != self.encoder_hidden:
+ self.in_linear = nn.Linear(self.in_dim, self.encoder_hidden)
+ self.in_linear.weight.data.normal_(0.0, 0.02)
+ else:
+ self.in_linear = None
+
+ if self.out_dim != self.encoder_hidden:
+ self.out_linear = nn.Linear(self.encoder_hidden, self.out_dim)
+ self.out_linear.weight.data.normal_(0.0, 0.02)
+ else:
+ self.out_linear = None
+
+ assert not ((self.cat_diff_step == True) and (self.add_diff_step == True))
+
+ self.transformer_encoder = TransformerEncoder(
+ encoder_layer=self.encoder_layer,
+ encoder_hidden=self.encoder_hidden,
+ encoder_head=self.encoder_head,
+ conv_kernel_size=self.conv_kernel_size,
+ conv_filter_size=self.conv_filter_size,
+ encoder_dropout=self.encoder_dropout,
+ use_cln=self.use_cln,
+ use_skip_connection=self.use_skip_connection,
+ use_new_ffn=self.use_new_ffn,
+ add_diff_step=self.add_diff_step,
+ )
+
+ self.cond_project = nn.Linear(self.cond_dim, self.encoder_hidden)
+ self.cond_project.weight.data.normal_(0.0, 0.02)
+ self.cat_linear = nn.Linear(self.encoder_hidden * 2, self.encoder_hidden)
+ self.cat_linear.weight.data.normal_(0.0, 0.02)
+
+ if self.cat_diff_step:
+ self.diff_step_emb = SinusoidalPosEmb(dim=self.encoder_hidden)
+ self.diff_step_projection = Linear2(self.encoder_hidden)
+
+ def forward(
+ self,
+ x,
+ condition_embedding,
+ key_padding_mask=None,
+ reference_embedding=None,
+ diffusion_step=None,
+ ):
+ # x: shape is (B, T, d_x)
+ # key_padding_mask: shape is (B, T), mask is 0
+ # condition_embedding: from condition adapter, shape is (B, T, d_c)
+ # reference_embedding: from reference encoder, shape is (B, N, d_r), or (B, 1, d_r), or (B, d_r)
+
+ if self.in_linear != None:
+ x = self.in_linear(x)
+ condition_embedding = self.cond_project(condition_embedding)
+
+ x = torch.cat([x, condition_embedding], dim=-1)
+ x = self.cat_linear(x)
+
+ if self.cat_diff_step and diffusion_step != None:
+ diff_step_embedding = self.diff_step_emb(diffusion_step)
+ diff_step_embedding = self.diff_step_projection(
+ diff_step_embedding
+ ).unsqueeze(
+ 1
+ ) # (B, 1, d)
+ x = torch.cat([diff_step_embedding, x], dim=1)
+ if key_padding_mask != None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ torch.ones(key_padding_mask.shape[0], 1).to(
+ key_padding_mask.device
+ ),
+ ],
+ dim=1,
+ )
+
+ x = self.transformer_encoder(
+ x,
+ key_padding_mask=key_padding_mask,
+ condition=reference_embedding,
+ diffusion_step=diffusion_step,
+ )
+
+ if self.cat_diff_step and diffusion_step != None:
+ x = x[:, 1:, :]
+
+ if self.out_linear != None:
+ x = self.out_linear(x)
+
+ return x
+
+
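+# Reference (timbre) encoder: encodes the reference mel spectrogram with a
+# Transformer and, when use_query_emb is set, summarizes it into a fixed number
+# of learnable speaker query embeddings via cross-attention (the queries attend
+# to the encoded reference frames).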
+class ReferenceEncoder(nn.Module):
+ def __init__(
+ self,
+ encoder_layer=None,
+ encoder_hidden=None,
+ encoder_head=None,
+ conv_filter_size=None,
+ conv_kernel_size=None,
+ encoder_dropout=None,
+ use_skip_connection=None,
+ use_new_ffn=None,
+ ref_in_dim=None,
+ ref_out_dim=None,
+ use_query_emb=None,
+ num_query_emb=None,
+ cfg=None,
+ ):
+ super().__init__()
+
+ self.encoder_layer = (
+ encoder_layer if encoder_layer is not None else cfg.encoder_layer
+ )
+ self.encoder_hidden = (
+ encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
+ )
+ self.encoder_head = (
+ encoder_head if encoder_head is not None else cfg.encoder_head
+ )
+ self.conv_filter_size = (
+ conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
+ )
+ self.conv_kernel_size = (
+ conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
+ )
+ self.encoder_dropout = (
+ encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
+ )
+ self.use_skip_connection = (
+ use_skip_connection
+ if use_skip_connection is not None
+ else cfg.use_skip_connection
+ )
+ self.use_new_ffn = use_new_ffn if use_new_ffn is not None else cfg.use_new_ffn
+ self.in_dim = ref_in_dim if ref_in_dim is not None else cfg.ref_in_dim
+ self.out_dim = ref_out_dim if ref_out_dim is not None else cfg.ref_out_dim
+ self.use_query_emb = (
+ use_query_emb if use_query_emb is not None else cfg.use_query_emb
+ )
+ self.num_query_emb = (
+ num_query_emb if num_query_emb is not None else cfg.num_query_emb
+ )
+
+ if self.in_dim != self.encoder_hidden:
+ self.in_linear = nn.Linear(self.in_dim, self.encoder_hidden)
+ self.in_linear.weight.data.normal_(0.0, 0.02)
+ else:
+ self.in_linear = None
+
+ if self.out_dim != self.encoder_hidden:
+ self.out_linear = nn.Linear(self.encoder_hidden, self.out_dim)
+ self.out_linear.weight.data.normal_(0.0, 0.02)
+ else:
+ self.out_linear = None
+
+ self.transformer_encoder = TransformerEncoder(
+ encoder_layer=self.encoder_layer,
+ encoder_hidden=self.encoder_hidden,
+ encoder_head=self.encoder_head,
+ conv_kernel_size=self.conv_kernel_size,
+ conv_filter_size=self.conv_filter_size,
+ encoder_dropout=self.encoder_dropout,
+ use_new_ffn=self.use_new_ffn,
+ use_cln=False,
+ use_skip_connection=False,
+ add_diff_step=False,
+ )
+
+ if self.use_query_emb:
+ # learnable speaker queries: (num_query_emb, encoder_hidden), e.g. 32 x 512
+ self.query_embs = nn.Embedding(self.num_query_emb, self.encoder_hidden)
+ self.query_attn = nn.MultiheadAttention(
+ self.encoder_hidden, self.encoder_hidden // 64, batch_first=True
+ )
+
+ def forward(self, x_ref, key_padding_mask=None):
+ # x_ref: (B, T, d_ref)
+ # key_padding_mask: (B, T)
+ # return speaker embedding: x_spk
+ # if self.use_query_embs: shape is (B, N_query, d_out)
+ # else: shape is (B, T, d_out)
+
+ if self.in_linear != None:
+ x = self.in_linear(x_ref)
+ else:
+ x = x_ref
+
+ x = self.transformer_encoder(
+ x, key_padding_mask=key_padding_mask, condition=None, diffusion_step=None
+ ) # (B, T, encoder_hidden)
+
+ if self.use_query_emb:
+ spk_query_emb = self.query_embs(
+ torch.arange(self.num_query_emb).to(x.device)
+ ).repeat(x.shape[0], 1, 1)
+ # query: (B, N_query, d); key/value: (B, T, d)
+ spk_embs, _ = self.query_attn(
+ query=spk_query_emb,
+ key=x,
+ value=x,
+ key_padding_mask=(
+ ~(key_padding_mask.bool()) if key_padding_mask is not None else None
+ ),
+ ) # (B, N_query, d_out)
+ if self.out_linear != None:
+ spk_embs = self.out_linear(spk_embs)
+ else:
+ # without query embeddings, return the frame-level encoding as the
+ # speaker representation, matching the docstring above
+ spk_embs = self.out_linear(x) if self.out_linear != None else x
+
+ return spk_embs, x
+
+
+def pad(input_ele, mel_max_length=None):
+ if mel_max_length:
+ max_len = mel_max_length
+ else:
+ max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
+
+ out_list = list()
+ for i, batch in enumerate(input_ele):
+ if len(batch.shape) == 1:
+ one_batch_padded = F.pad(
+ batch, (0, max_len - batch.size(0)), "constant", 0.0
+ )
+ elif len(batch.shape) == 2:
+ one_batch_padded = F.pad(
+ batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
+ )
+ out_list.append(one_batch_padded)
+ out_padded = torch.stack(out_list)
+ return out_padded
+
+
+class FiLM(nn.Module):
+ def __init__(self, in_dim, cond_dim):
+ super().__init__()
+
+ self.gain = Linear(cond_dim, in_dim)
+ self.bias = Linear(cond_dim, in_dim)
+
+ nn.init.xavier_uniform_(self.gain.weight)
+ nn.init.constant_(self.gain.bias, 1)
+
+ nn.init.xavier_uniform_(self.bias.weight)
+ nn.init.constant_(self.bias.bias, 0)
+
+ def forward(self, x, condition):
+ gain = self.gain(condition)
+ bias = self.bias(condition)
+ if gain.dim() == 2:
+ gain = gain.unsqueeze(-1)
+ if bias.dim() == 2:
+ bias = bias.unsqueeze(-1)
+ return x * gain + bias
+
+
+class Mish(nn.Module):
+ def forward(self, x):
+ return x * torch.tanh(F.softplus(x))
+
+
+def Conv1d(*args, **kwargs):
+ layer = nn.Conv1d(*args, **kwargs)
+ layer.weight.data.normal_(0.0, 0.02)
+ return layer
+
+
+def Linear(*args, **kwargs):
+ layer = nn.Linear(*args, **kwargs)
+ layer.weight.data.normal_(0.0, 0.02)
+ return layer
+
+
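+# WaveNet-style residual block: a gated dilated Conv1d conditioned on the
+# content/pitch features; blocks with has_cattn additionally cross-attend to the
+# speaker query embeddings and modulate the activations with FiLM.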
+class ResidualBlock(nn.Module):
+ def __init__(self, hidden_dim, attn_head, dilation, drop_out, has_cattn=False):
+ super().__init__()
+
+ self.hidden_dim = hidden_dim
+ self.dilation = dilation
+ self.has_cattn = has_cattn
+ self.attn_head = attn_head
+ self.drop_out = drop_out
+
+ self.dilated_conv = Conv1d(
+ hidden_dim, 2 * hidden_dim, 3, padding=dilation, dilation=dilation
+ )
+ self.diffusion_proj = Linear(hidden_dim, hidden_dim)
+
+ self.cond_proj = Conv1d(hidden_dim, hidden_dim * 2, 1)
+ self.out_proj = Conv1d(hidden_dim, hidden_dim * 2, 1)
+
+ if self.has_cattn:
+ self.attn = nn.MultiheadAttention(
+ hidden_dim, attn_head, 0.1, batch_first=True
+ )
+ self.film = FiLM(hidden_dim * 2, hidden_dim)
+
+ self.ln = nn.LayerNorm(hidden_dim)
+
+ self.dropout = nn.Dropout(self.drop_out)
+
+ def forward(self, x, x_mask, cond, diffusion_step, spk_query_emb):
+ diffusion_step = self.diffusion_proj(diffusion_step).unsqueeze(-1) # (B, d, 1)
+ cond = self.cond_proj(cond) # (B, 2*d, T)
+
+ y = x + diffusion_step
+ if x_mask != None:
+ y = y * x_mask.to(y.dtype)[:, None, :] # (B, d, T)
+
+ if self.has_cattn:
+ y_ = y.transpose(1, 2)
+ y_ = self.ln(y_)
+
+ y_, _ = self.attn(y_, spk_query_emb, spk_query_emb) # (B, T, d)
+
+ y = self.dilated_conv(y) + cond # (B, 2*d, T)
+
+ if self.has_cattn:
+ y = self.film(y.transpose(1, 2), y_) # (B, T, 2*d)
+ y = y.transpose(1, 2) # (B, 2*d, T)
+
+ gate, filter_ = torch.chunk(y, 2, dim=1)
+ y = torch.sigmoid(gate) * torch.tanh(filter_)
+
+ y = self.out_proj(y)
+
+ residual, skip = torch.chunk(y, 2, dim=1)
+
+ if x_mask != None:
+ residual = residual * x_mask.to(y.dtype)[:, None, :]
+ skip = skip * x_mask.to(y.dtype)[:, None, :]
+
+ return (x + residual) / math.sqrt(2.0), skip
+
+
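+# Diffusion estimator built from the residual blocks above: the noisy mel is
+# projected to the hidden size, processed by num_layers dilated blocks whose
+# skip outputs are summed (scaled by 1/sqrt(num_layers)), and projected back to
+# the mel dimension. The diffusion step is embedded sinusoidally and injected
+# into every block.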
+class DiffWaveNet(nn.Module):
+ def __init__(
+ self,
+ cfg=None,
+ ):
+ super().__init__()
+
+ self.cfg = cfg
+ self.in_dim = cfg.input_size
+ self.hidden_dim = cfg.hidden_size
+ self.out_dim = cfg.out_size
+ self.num_layers = cfg.num_layers
+ self.cross_attn_per_layer = cfg.cross_attn_per_layer
+ self.dilation_cycle = cfg.dilation_cycle
+ self.attn_head = cfg.attn_head
+ self.drop_out = cfg.drop_out
+
+ self.in_proj = Conv1d(self.in_dim, self.hidden_dim, 1)
+ self.diffusion_embedding = SinusoidalPosEmb(self.hidden_dim)
+
+ self.mlp = nn.Sequential(
+ Linear(self.hidden_dim, self.hidden_dim * 4),
+ Mish(),
+ Linear(self.hidden_dim * 4, self.hidden_dim),
+ )
+
+ self.cond_ln = nn.LayerNorm(self.hidden_dim)
+
+ self.layers = nn.ModuleList(
+ [
+ ResidualBlock(
+ self.hidden_dim,
+ self.attn_head,
+ 2 ** (i % self.dilation_cycle),
+ self.drop_out,
+ has_cattn=(i % self.cross_attn_per_layer == 0),
+ )
+ for i in range(self.num_layers)
+ ]
+ )
+
+ self.skip_proj = Conv1d(self.hidden_dim, self.hidden_dim, 1)
+ self.out_proj = Conv1d(self.hidden_dim, self.out_dim, 1)
+
+ nn.init.zeros_(self.out_proj.weight)
+
+ def forward(
+ self,
+ x,
+ condition_embedding,
+ key_padding_mask=None,
+ reference_embedding=None,
+ diffusion_step=None,
+ ):
+ x = x.transpose(1, 2) # (B, T, d) -> (B, d, T)
+ x_mask = key_padding_mask
+ cond = condition_embedding
+ spk_query_emb = reference_embedding
+
+ cond = self.cond_ln(cond)
+ cond_input = cond.transpose(1, 2)
+
+ x_input = self.in_proj(x)
+
+ x_input = F.relu(x_input)
+
+ diffusion_step = self.diffusion_embedding(diffusion_step).to(x.dtype)
+ diffusion_step = self.mlp(diffusion_step)
+
+ skip = []
+ for _, layer in enumerate(self.layers):
+ x_input, skip_connection = layer(
+ x_input, x_mask, cond_input, diffusion_step, spk_query_emb
+ )
+ skip.append(skip_connection)
+
+ x_input = torch.sum(torch.stack(skip), dim=0) / math.sqrt(self.num_layers)
+
+ x_out = self.skip_proj(x_input)
+
+ x_out = F.relu(x_out)
+
+ x_out = self.out_proj(x_out) # (B, 80, T)
+
+ x_out = x_out.transpose(1, 2)
+
+ return x_out
+
+
+def override_config(base_config, new_config):
+ """Update new configurations in the original dict with the new dict
+
+ Args:
+ base_config (dict): original dict to be overridden
+ new_config (dict): dict with new configurations
+
+ Returns:
+ dict: updated configuration dict
+ """
+ for k, v in new_config.items():
+ if type(v) == dict:
+ if k not in base_config.keys():
+ base_config[k] = {}
+ base_config[k] = override_config(base_config[k], v)
+ else:
+ base_config[k] = v
+ return base_config
+
+
+def get_lowercase_keys_config(cfg):
+ """Change all keys in cfg to lower case
+
+ Args:
+ cfg (dict): dictionary that stores configurations
+
+ Returns:
+ dict: dictionary that stores configurations
+ """
+ updated_cfg = dict()
+ for k, v in cfg.items():
+ if type(v) == dict:
+ v = get_lowercase_keys_config(v)
+ updated_cfg[k.lower()] = v
+ return updated_cfg
+
+
+
+def save_config(save_path, cfg):
+ """Save configurations into a json file
+
+ Args:
+ save_path (str): path to save configurations
+ cfg (dict): dictionary that stores configurations
+ """
+ with open(save_path, "w") as f:
+ json5.dump(
+ cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True
+ )
+
+
+class JsonHParams:
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ if type(v) == dict:
+ v = JsonHParams(**v)
+ self[k] = v
+
+ def keys(self):
+ return self.__dict__.keys()
+
+ def items(self):
+ return self.__dict__.items()
+
+ def values(self):
+ return self.__dict__.values()
+
+ def __len__(self):
+ return len(self.__dict__)
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __setitem__(self, key, value):
+ return setattr(self, key, value)
+
+ def __contains__(self, key):
+ return key in self.__dict__
+
+ def __repr__(self):
+ return self.__dict__.__repr__()
+
+
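+# Top-level Noro voice conversion model: HuBERT content features concatenated
+# with utterance-normalized F0 form the condition; a ReferenceEncoder produces
+# speaker query embeddings from the reference mel (and, when use_ref_noise is
+# set, from a noisy reference as well); a WaveNet-based diffusion model then
+# generates the target mel. A minimal inference sketch (variable names are
+# illustrative):
+#   model = Noro_VCmodel(cfg=cfg.model, use_ref_noise=False)
+#   mel = model.inference(content_feature=content, pitch=f0_norm,
+#                         x_ref=ref_mel, x_ref_mask=ref_mask,
+#                         inference_steps=200, sigma=1.2)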
+class Noro_VCmodel(nn.Module):
+ def __init__(self, cfg, use_ref_noise=False):
+ super().__init__()
+ self.cfg = cfg
+ self.use_ref_noise = use_ref_noise
+ self.reference_encoder = ReferenceEncoder(cfg=cfg.reference_encoder)
+ if cfg.diffusion.diff_model_type == "WaveNet":
+ self.diffusion = Diffusion(
+ cfg=cfg.diffusion,
+ diff_model=DiffWaveNet(cfg=cfg.diffusion.diff_wavenet),
+ )
+ else:
+ raise NotImplementedError()
+ pitch_dim = 1
+ self.content_f0_enc = nn.Sequential(
+ nn.LayerNorm(
+ cfg.vc_feature.content_feature_dim + pitch_dim
+ ), # 768 (for mhubert) + 1 (for f0)
+ Rearrange("b t d -> b d t"),
+ nn.Conv1d(
+ cfg.vc_feature.content_feature_dim + pitch_dim,
+ cfg.vc_feature.hidden_dim,
+ kernel_size=3,
+ padding=1,
+ ),
+ Rearrange("b d t -> b t d"),
+ )
+
+ self.reset_parameters()
+
+ def forward(
+ self,
+ x=None,
+ content_feature=None,
+ pitch=None,
+ x_ref=None,
+ x_mask=None,
+ x_ref_mask=None,
+ noisy_x_ref=None
+ ):
+ noisy_reference_embedding = None
+ noisy_condition_embedding = None
+
+ reference_embedding, encoded_x = self.reference_encoder(
+ x_ref=x_ref, key_padding_mask=x_ref_mask
+ )
+
+ # content_feature: (B, T, D); pitch: (B, T)
+ # concatenate along the feature dimension -> (B, T, D+1), then project to the hidden size
+ condition_embedding = torch.cat([content_feature, pitch[:, :, None]], dim=-1)
+ condition_embedding = self.content_f0_enc(condition_embedding)
+
+ if self.use_ref_noise:
+ # noisy_reference
+ noisy_reference_embedding, _ = self.reference_encoder(
+ x_ref=noisy_x_ref, key_padding_mask=x_ref_mask
+ )
+ combined_reference_embedding = (noisy_reference_embedding + reference_embedding) / 2
+ else:
+ combined_reference_embedding = reference_embedding
+
+ combined_condition_embedding = condition_embedding
+
+ diff_out = self.diffusion(
+ x=x,
+ condition_embedding=combined_condition_embedding,
+ x_mask=x_mask,
+ reference_embedding=combined_reference_embedding,
+ )
+ return diff_out, (reference_embedding, noisy_reference_embedding), (condition_embedding, noisy_condition_embedding)
+
+ @torch.no_grad()
+ def inference(
+ self,
+ content_feature=None,
+ pitch=None,
+ x_ref=None,
+ x_ref_mask=None,
+ inference_steps=1000,
+ sigma=1.2,
+ ):
+ reference_embedding, _ = self.reference_encoder(
+ x_ref=x_ref, key_padding_mask=x_ref_mask
+ )
+
+ condition_embedding = torch.cat([content_feature, pitch[:, :, None]], dim=-1)
+ condition_embedding = self.content_f0_enc(condition_embedding)
+
+ bsz, l, _ = condition_embedding.shape
+ if self.cfg.diffusion.diff_model_type == "WaveNet":
+ z = (
+ torch.randn(bsz, l, self.cfg.diffusion.diff_wavenet.input_size).to(
+ condition_embedding.device
+ )
+ / sigma
+ )
+
+ x0 = self.diffusion.reverse_diffusion(
+ z=z,
+ condition_embedding=condition_embedding,
+ x_mask=None,
+ reference_embedding=reference_embedding,
+ n_timesteps=inference_steps,
+ )
+
+ return x0
+
+ def reset_parameters(self):
+ def _reset_parameters(m):
+ if isinstance(m, nn.MultiheadAttention):
+ if m._qkv_same_embed_dim:
+ nn.init.normal_(m.in_proj_weight, std=0.02)
+ else:
+ nn.init.normal_(m.q_proj_weight, std=0.02)
+ nn.init.normal_(m.k_proj_weight, std=0.02)
+ nn.init.normal_(m.v_proj_weight, std=0.02)
+
+ if m.in_proj_bias is not None:
+ nn.init.constant_(m.in_proj_bias, 0.0)
+ nn.init.constant_(m.out_proj.bias, 0.0)
+ if m.bias_k is not None:
+ nn.init.xavier_normal_(m.bias_k)
+ if m.bias_v is not None:
+ nn.init.xavier_normal_(m.bias_v)
+
+ elif (
+ isinstance(m, nn.Conv1d)
+ or isinstance(m, nn.ConvTranspose1d)
+ or isinstance(m, nn.Conv2d)
+ or isinstance(m, nn.ConvTranspose2d)
+ ):
+ m.weight.data.normal_(0.0, 0.02)
+
+ elif isinstance(m, nn.Linear):
+ m.weight.data.normal_(mean=0.0, std=0.02)
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ elif isinstance(m, nn.Embedding):
+ m.weight.data.normal_(mean=0.0, std=0.02)
+ if m.padding_idx is not None:
+ m.weight.data[m.padding_idx].zero_()
+
+ self.apply(_reset_parameters)
\ No newline at end of file
diff --git a/models/vc/Noro/noro_trainer.py b/models/vc/Noro/noro_trainer.py
new file mode 100644
index 00000000..79de602a
--- /dev/null
+++ b/models/vc/Noro/noro_trainer.py
@@ -0,0 +1,458 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import shutil
+import time
+import json5
+import torch
+import numpy as np
+from tqdm import tqdm
+from utils.util import ValueWindow
+from torch.utils.data import DataLoader
+from models.vc.Noro.noro_base_trainer import Noro_base_Trainer
+from torch.nn import functional as F
+from models.base.base_sampler import VariableSampler
+
+from diffusers import get_scheduler
+import accelerate
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+from models.vc.Noro.noro_model import Noro_VCmodel
+from models.vc.Noro.noro_dataset import VCCollator, VCDataset, batch_by_size
+from processors.content_extractor import HubertExtractor
+from models.vc.Noro.noro_loss import diff_loss, ConstractiveSpeakerLoss
+from utils.mel import mel_spectrogram_torch
+from utils.f0 import get_f0_features_using_dio, interpolate
+from torch.nn.utils.rnn import pad_sequence
+
+
+class NoroTrainer(Noro_base_Trainer):
+ def __init__(self, args, cfg):
+ self.args = args
+ self.cfg = cfg
+ cfg.exp_name = args.exp_name
+ self.content_extractor = "mhubert"
+
+ # Initialize accelerator and ensure all processes are ready
+ self._init_accelerator()
+ self.accelerator.wait_for_everyone()
+
+ # Initialize logger on the main process
+ if self.accelerator.is_main_process:
+ self.logger = get_logger(args.exp_name, log_level="INFO")
+
+ # Configure noise and speaker usage
+ self.use_ref_noise = self.cfg.trans_exp.use_ref_noise
+
+ # Log configuration on the main process
+ if self.accelerator.is_main_process:
+ self.logger.info(f"use_ref_noise: {self.use_ref_noise}")
+
+ # Initialize a time window for monitoring metrics
+ self.time_window = ValueWindow(50)
+
+ # Log the start of training
+ if self.accelerator.is_main_process:
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\tNew training process started.\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+ self.logger.info(f"Experiment name: {args.exp_name}")
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
+
+ # Initialize checkpoint directory
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # Initialize training counters
+ self.batch_count: int = 0
+ self.step: int = 0
+ self.epoch: int = 0
+ self.max_epoch = self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+ if self.accelerator.is_main_process:
+ self.logger.info(f"Max epoch: {self.max_epoch if self.max_epoch < float('inf') else 'Unlimited'}")
+
+ # Check basic configuration
+ if self.accelerator.is_main_process:
+ self._check_basic_configs()
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+ self.keep_last = [i if i > 0 else float("inf") for i in self.cfg.train.keep_last]
+ self.run_eval = self.cfg.train.run_eval
+
+ # Set random seed
+ with self.accelerator.main_process_first():
+ self._set_random_seed(self.cfg.train.random_seed)
+
+ # Setup data loader
+ with self.accelerator.main_process_first():
+ if self.accelerator.is_main_process:
+ self.logger.info("Building dataset...")
+ self.train_dataloader = self._build_dataloader()
+ self.speaker_num = len(self.train_dataloader.dataset.speaker2id)
+ if self.accelerator.is_main_process:
+ self.logger.info("Speaker num: {}".format(self.speaker_num))
+
+ # Build model
+ with self.accelerator.main_process_first():
+ if self.accelerator.is_main_process:
+ self.logger.info("Building model...")
+ self.model, self.w2v = self._build_model()
+
+ # Resume training if specified
+ with self.accelerator.main_process_first():
+ if self.accelerator.is_main_process:
+ self.logger.info("Resume training: {}".format(args.resume))
+ if args.resume:
+ if self.accelerator.is_main_process:
+ self.logger.info("Resuming from checkpoint...")
+ ckpt_path = self._load_model(
+ self.checkpoint_dir,
+ args.checkpoint_path,
+ resume_type=args.resume_type,
+ )
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ if self.accelerator.is_main_process:
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # Initialize optimizer & scheduler
+ with self.accelerator.main_process_first():
+ if self.accelerator.is_main_process:
+ self.logger.info("Building optimizer and scheduler...")
+ self.optimizer = self._build_optimizer()
+ self.scheduler = self._build_scheduler()
+
+ # Prepare model, w2v, optimizer, and scheduler for accelerator
+ self.model = self._prepare_for_accelerator(self.model)
+ self.w2v = self._prepare_for_accelerator(self.w2v)
+ self.optimizer = self._prepare_for_accelerator(self.optimizer)
+ self.scheduler = self._prepare_for_accelerator(self.scheduler)
+
+ # Build criterion
+ with self.accelerator.main_process_first():
+ if self.accelerator.is_main_process:
+ self.logger.info("Building criterion...")
+ self.criterion = self._build_criterion()
+
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
+ self.task_type = "VC"
+ self.contrastive_speaker_loss = ConstractiveSpeakerLoss()
+
+ if self.accelerator.is_main_process:
+ self.logger.info("Task type: {}".format(self.task_type))
+
+ def _init_accelerator(self):
+ self.exp_dir = os.path.join(
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
+ )
+ project_config = ProjectConfiguration(
+ project_dir=self.exp_dir,
+ logging_dir=os.path.join(self.exp_dir, "log"),
+ )
+ self.accelerator = accelerate.Accelerator(
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
+ log_with=self.cfg.train.tracker,
+ project_config=project_config,
+ )
+ if self.accelerator.is_main_process:
+ os.makedirs(project_config.project_dir, exist_ok=True)
+ os.makedirs(project_config.logging_dir, exist_ok=True)
+ self.accelerator.wait_for_everyone()
+ with self.accelerator.main_process_first():
+ self.accelerator.init_trackers(self.args.exp_name)
+
+ def _build_model(self):
+ w2v = HubertExtractor(self.cfg)
+ model = Noro_VCmodel(cfg=self.cfg.model, use_ref_noise=self.use_ref_noise)
+ return model, w2v
+
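+ # Dynamic batching: utterances are grouped by frame count so each batch stays
+ # under max_tokens / max_sentences, and the resulting batches are split evenly
+ # across the accelerator processes before being wrapped in a VariableSampler.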
+ def _build_dataloader(self):
+ np.random.seed(int(time.time()))
+ if self.accelerator.is_main_process:
+ self.logger.info("Use Dynamic Batchsize...")
+ train_dataset = VCDataset(self.cfg.trans_exp)
+ train_collate = VCCollator(self.cfg)
+ batch_sampler = batch_by_size(
+ train_dataset.num_frame_indices,
+ train_dataset.get_num_frames,
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
+ max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes,
+ required_batch_size_multiple=self.accelerator.num_processes,
+ )
+ np.random.shuffle(batch_sampler)
+ batches = [
+ x[self.accelerator.local_process_index :: self.accelerator.num_processes]
+ for x in batch_sampler
+ if len(x) % self.accelerator.num_processes == 0
+ ]
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=train_collate,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ batch_sampler=VariableSampler(
+ batches, drop_last=False, use_random_sampler=True
+ ),
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+ self.accelerator.wait_for_everyone()
+ return train_loader
+
+ def _build_optimizer(self):
+ optimizer = torch.optim.AdamW(
+ filter(lambda p: p.requires_grad, self.model.parameters()),
+ **self.cfg.train.adam,
+ )
+ return optimizer
+
+ def _build_scheduler(self):
+ lr_scheduler = get_scheduler(
+ self.cfg.train.lr_scheduler,
+ optimizer=self.optimizer,
+ num_warmup_steps=self.cfg.train.lr_warmup_steps,
+ num_training_steps=self.cfg.train.num_train_steps
+ )
+ return lr_scheduler
+
+ def _build_criterion(self):
+ criterion = torch.nn.L1Loss(reduction="mean")
+ return criterion
+
+ def _dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
+
+ def load_model(self, checkpoint):
+ self.step = checkpoint["step"]
+ self.epoch = checkpoint["epoch"]
+ self.model.load_state_dict(checkpoint["model"])
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ self.scheduler.load_state_dict(checkpoint["scheduler"])
+
+ def _prepare_for_accelerator(self, component):
+ if isinstance(component, dict):
+ for key in component.keys():
+ component[key] = self.accelerator.prepare(component[key])
+ else:
+ component = self.accelerator.prepare(component)
+ return component
+
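+ # One optimization step: mel and reference mel are computed on the fly, F0 is
+ # extracted with DIO and normalized per utterance, and HuBERT content features
+ # are extracted without gradients. The total loss is the diffusion x0 and noise
+ # losses plus, when use_ref_noise is set, a contrastive speaker loss (weighted
+ # by 0.25) over clean and noisy reference embeddings.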
+ def _train_step(self, batch):
+ total_loss = 0.0
+ train_losses = {}
+ device = self.accelerator.device
+
+ # Move all Tensor data to the specified device
+ batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+
+ speech = batch["speech"]
+ ref_speech = batch["ref_speech"]
+
+ with torch.set_grad_enabled(False):
+ # Extract features and spectrograms
+ mel = mel_spectrogram_torch(speech, self.cfg).transpose(1, 2)
+ ref_mel = mel_spectrogram_torch(ref_speech, self.cfg).transpose(1, 2)
+ mask = batch["mask"]
+ ref_mask = batch["ref_mask"]
+
+ # Extract pitch and content features
+ audio = speech.cpu().numpy()
+ f0s = []
+ for i in range(audio.shape[0]):
+ wav = audio[i]
+ f0 = get_f0_features_using_dio(wav, self.cfg.preprocess)
+ f0, _ = interpolate(f0)
+ frame_num = len(wav) // self.cfg.preprocess.hop_size
+ f0 = torch.from_numpy(f0[:frame_num]).to(speech.device)
+ f0s.append(f0)
+
+ pitch = pad_sequence(f0s, batch_first=True, padding_value=0).float()
+ pitch = (pitch - pitch.mean(dim=1, keepdim=True)) / (pitch.std(dim=1, keepdim=True) + 1e-6) # Normalize pitch (B,T)
+ _, content_feature = self.w2v.extract_content_features(speech) # semantic (B, T, 768)
+
+ if self.use_ref_noise:
+ noisy_ref_mel = mel_spectrogram_torch(batch["noisy_ref_speech"], self.cfg).transpose(1, 2)
+
+ if self.use_ref_noise:
+ diff_out, (ref_emb, noisy_ref_emb), (cond_emb, _) = self.model(
+ x=mel, content_feature=content_feature, pitch=pitch, x_ref=ref_mel,
+ x_mask=mask, x_ref_mask=ref_mask, noisy_x_ref=noisy_ref_mel
+ )
+ else:
+ diff_out, (ref_emb, _), (cond_emb, _) = self.model(
+ x=mel, content_feature=content_feature, pitch=pitch, x_ref=ref_mel,
+ x_mask=mask, x_ref_mask=ref_mask
+ )
+
+ if self.use_ref_noise:
+ # B x N_query x D
+ ref_emb = torch.mean(ref_emb, dim=1) # B x D
+ noisy_ref_emb = torch.mean(noisy_ref_emb, dim=1) # B x D
+ all_ref_emb = torch.cat([ref_emb, noisy_ref_emb], dim=0) # 2B x D
+ all_speaker_ids = torch.cat([batch["speaker_id"], batch["speaker_id"]], dim=0) # 2B
+ cs_loss = self.contrastive_speaker_loss(all_ref_emb, all_speaker_ids) * 0.25
+ total_loss += cs_loss
+ train_losses["ref_loss"] = cs_loss
+
+ diff_loss_x0 = diff_loss(diff_out["x0_pred"], mel, mask=mask)
+ total_loss += diff_loss_x0
+ train_losses["diff_loss_x0"] = diff_loss_x0
+
+ diff_loss_noise = diff_loss(diff_out["noise_pred"], diff_out["noise"], mask=mask)
+ total_loss += diff_loss_noise
+ train_losses["diff_loss_noise"] = diff_loss_noise
+ train_losses["total_loss"] = total_loss
+
+ self.optimizer.zero_grad()
+ self.accelerator.backward(total_loss)
+ if self.accelerator.sync_gradients:
+ self.accelerator.clip_grad_norm_(filter(lambda p: p.requires_grad, self.model.parameters()), 0.5)
+ self.optimizer.step()
+ self.scheduler.step()
+
+ for item in train_losses:
+ train_losses[item] = train_losses[item].item()
+
+ train_losses['learning_rate'] = f"{self.optimizer.param_groups[0]['lr']:.1e}"
+ train_losses["batch_size"] = batch["speaker_id"].shape[0]
+
+ return (train_losses["total_loss"], train_losses, None)
+
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key].train()
+ else:
+ self.model.train()
+ if isinstance(self.w2v, dict):
+ for key in self.w2v.keys():
+ self.w2v[key].eval()
+ else:
+ self.w2v.eval()
+
+ epoch_sum_loss: float = 0.0 # total loss
+ # Put the data to cuda device
+ device = self.accelerator.device
+ with device:
+ torch.cuda.empty_cache()
+ self.model = self.model.to(device)
+ self.w2v = self.w2v.to(device)
+
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ speech = batch["speech"].cpu().numpy()
+ speech = speech[0]
+ self.batch_count += 1
+ self.step += 1
+ if len(speech) >= 16000 * 25:
+ continue
+ with self.accelerator.accumulate(self.model):
+ total_loss, train_losses, _ = self._train_step(batch)
+
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ epoch_sum_loss += total_loss
+ self.current_loss = total_loss
+ if isinstance(train_losses, dict):
+ for key, loss in train_losses.items():
+ self.accelerator.log(
+ {"Epoch/Train {} Loss".format(key): loss},
+ step=self.step,
+ )
+ if self.accelerator.is_main_process and self.batch_count % 10 == 0:
+ self.echo_log(train_losses, mode="Training")
+
+ self.save_checkpoint()
+ self.accelerator.wait_for_everyone()
+
+ return epoch_sum_loss, None
+
+ def train_loop(self):
+ r"""Training loop. The public entry of training process."""
+ # Wait everyone to prepare before we move on
+ self.accelerator.wait_for_everyone()
+ # Dump config file
+ if self.accelerator.is_main_process:
+ self._dump_cfg(self.config_save_path)
+
+ # Wait to ensure good to go
+ self.accelerator.wait_for_everyone()
+ # Stop when meeting max epoch or self.cfg.train.num_train_steps
+ while self.epoch < self.max_epoch and self.step < self.cfg.train.num_train_steps:
+ if self.accelerator.is_main_process:
+ self.logger.info("\n")
+ self.logger.info("-" * 32)
+ self.logger.info("Epoch {}: ".format(self.epoch))
+ self.logger.info("Start training...")
+
+ train_total_loss, _ = self._train_epoch()
+
+ self.epoch += 1
+ self.accelerator.wait_for_everyone()
+ if isinstance(self.scheduler, dict):
+ for key in self.scheduler.keys():
+ self.scheduler[key].step()
+ else:
+ self.scheduler.step()
+
+ # Finish training and save final checkpoint
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process:
+ self.accelerator.save_state(
+ os.path.join(
+ self.checkpoint_dir,
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, train_total_loss
+ ),
+ )
+ )
+ self.accelerator.end_training()
+ if self.accelerator.is_main_process:
+ self.logger.info("Training finished...")
+
+ def save_checkpoint(self):
+ self.accelerator.wait_for_everyone()
+ # Main process only
+ if self.accelerator.is_main_process:
+ if self.batch_count % self.save_checkpoint_stride[0] == 0:
+ keep_last = self.keep_last[0]
+ # Read all folders in self.checkpoint_dir
+ all_ckpts = os.listdir(self.checkpoint_dir)
+ # Exclude non-folders
+ all_ckpts = [ckpt for ckpt in all_ckpts if os.path.isdir(os.path.join(self.checkpoint_dir, ckpt))]
+ if len(all_ckpts) > keep_last:
+ # Keep only the last keep_last folders in self.checkpoint_dir, sorted by step "epoch-{:04d}_step-{:07d}_loss-{:.6f}"
+ all_ckpts = sorted(all_ckpts, key=lambda x: int(x.split("_")[1].split('-')[1]))
+ for ckpt in all_ckpts[:-keep_last]:
+ shutil.rmtree(os.path.join(self.checkpoint_dir, ckpt))
+ checkpoint_filename = "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, self.current_loss
+ )
+ path = os.path.join(self.checkpoint_dir, checkpoint_filename)
+ self.logger.info("Saving state to {}...".format(path))
+ self.accelerator.save_state(path)
+ self.logger.info("Finished saving state.")
+ self.accelerator.wait_for_everyone()
diff --git a/processors/audio_features_extractor.py b/processors/audio_features_extractor.py
index 8e38bd5e..0018fcf6 100644
--- a/processors/audio_features_extractor.py
+++ b/processors/audio_features_extractor.py
@@ -25,6 +25,7 @@
WhisperExtractor,
ContentvecExtractor,
WenetExtractor,
+ HubertExtractor
)
@@ -155,3 +156,20 @@ def get_wenet_features(self, wavs, target_frame_len, wav_lens=None):
wenet_feats = self.wenet_extractor.extract_content_features(wavs, lens=wav_lens)
wenet_feats = self.wenet_extractor.ReTrans(wenet_feats, target_frame_len)
return wenet_feats
+
+ def get_hubert_features(self, wavs):
+ """Get HuBERT features and their k-means cluster ids
+
+ Args:
+ wavs: Tensor whose shape is (B, T)
+
+ Returns:
+ clusters: Tensor whose shape is (B, T'), k-means cluster ids
+ hubert_feats: Tensor whose shape is (B, T', D)
+ """
+ if not hasattr(self, "hubert_extractor"):
+ self.hubert_extractor = HubertExtractor(self.cfg)
+
+ clusters, hubert_feats = self.hubert_extractor.extract_content_features(wavs)
+ return clusters, hubert_feats
diff --git a/processors/content_extractor.py b/processors/content_extractor.py
index 34b54917..c45e587a 100644
--- a/processors/content_extractor.py
+++ b/processors/content_extractor.py
@@ -14,6 +14,10 @@
from torch.utils.data import DataLoader
from fairseq import checkpoint_utils
from transformers import AutoModel, Wav2Vec2FeatureExtractor
+import torch.nn as nn
+import torch.nn.functional as F
+import joblib
+from einops import repeat
from utils.io_optim import (
TorchaudioDataset,
@@ -493,6 +497,60 @@ def extract_content_features(self, wavs):
mert_features.append(feature)
return mert_features
+
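+# HuBERT content extractor with k-means quantization: frame features from a
+# fairseq HuBERT checkpoint are assigned to the nearest k-means cluster center,
+# and the cluster-center embeddings are returned as discretized content
+# features. The checkpoint and k-means paths are read from cfg.preprocess. A
+# minimal usage sketch (variable names are illustrative):
+#   extractor = HubertExtractor(cfg)
+#   clusters, feats = extractor.extract_content_features(wavs)  # wavs: (B, T)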
+class HubertExtractor(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.load_model()
+
+ def load_model(self):
+ kmeans_model_path = self.cfg.preprocess.kmeans_model_path
+ hubert_model_path = self.cfg.preprocess.hubert_model_path
+ print("Load Hubert Model...")
+ checkpoint = torch.load(hubert_model_path)
+ load_model_input = {hubert_model_path: checkpoint}
+ model, *_ = checkpoint_utils.load_model_ensemble_and_task(
+ load_model_input
+ )
+ self.model = model[0]
+ self.model.eval()
+
+ # Load KMeans cluster centers
+ kmeans = joblib.load(kmeans_model_path)
+ self.kmeans = kmeans
+
+ self.register_buffer(
+ "cluster_centers", torch.from_numpy(kmeans.cluster_centers_)
+ )
+
+ def extract_content_features(self, wav_input):
+ """
+ Extract HuBERT content features and quantize them with the k-means codebook.
+
+ Args:
+ wav_input: tensor (batch_size, T)
+
+ Returns:
+ clusters: tensor (batch_size, T'), nearest k-means cluster ids
+ quantize: tensor (batch_size, T', D), cluster-center embeddings
+ """
+ # Extract frame-level features using HuBERT
+ wav_input = F.pad(wav_input, (40, 40), "reflect")
+ embed = self.model(
+ wav_input,
+ features_only=True,
+ mask=False,
+ )["x"] # fairseq HubertModel returns a dict when features_only=True
+
+ batched_cluster_centers = repeat(
+ self.cluster_centers, "c d -> b c d", b=embed.shape[0]
+ )
+
+ # Nearest cluster center: argmax of the negative L2 distance
+ dists = -torch.cdist(embed, batched_cluster_centers, p=2)
+ clusters = dists.argmax(dim=-1) # (batch, seq_len)
+ quantize = F.embedding(clusters, self.cluster_centers)
+
+ return clusters, quantize
def extract_utt_content_features_dataloader(cfg, metadata, num_workers):