Skip to content

Commit

Permalink
🚀 ✍ Update Config for FastSpeech2_v2 small, add test multi-speaker fo…
Browse files Browse the repository at this point in the history
…r FastSpeech2.
  • Loading branch information
dathudeptrai committed Jul 9, 2020
1 parent f65547c commit a9a48a7
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 20 deletions.
36 changes: 18 additions & 18 deletions examples/fastspeech2/conf/fastspeech2.v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,31 @@ format: "npy"
###########################################################
fastspeech_params:
n_speakers: 1
hidden_size: 384
num_hidden_layers: 4
num_attention_heads: 2
attention_head_size: 8 # in v1, attention_head_size = hidden_size // num_attention_heads
intermediate_size: 1024
intermediate_kernel_size: 3
num_duration_conv_layers: 2
duration_predictor_filters: 256
duration_predictor_kernel_sizes: 3
encoder_hidden_size: 256
encoder_num_hidden_layers: 3
encoder_num_attention_heads: 2
encoder_attention_head_size: 8 # in v1, = 384//2
encoder_intermediate_size: 1024
encoder_intermediate_kernel_size: 3
encoder_hidden_act: "mish"
decoder_hidden_size: 256
decoder_num_hidden_layers: 3
decoder_num_attention_heads: 2
decoder_attention_head_size: 8 # in v1, = 384//2
decoder_intermediate_size: 768
decoder_intermediate_kernel_size: 3
decoder_hidden_act: "mish"
variant_prediction_num_conv_layers: 2
variant_predictor_filter: 256
variant_predictor_kernel_size: 3
variant_predictor_dropout_rate: 0.5
num_mels: 80
hidden_act: "mish"
hidden_dropout_prob: 0.2
attention_probs_dropout_prob: 0.1
duration_predictor_dropout_probs: 0.5
max_position_embeddings: 2048
initializer_range: 0.02
output_attentions: False
output_hidden_states: False
f0_energy_predictor_filters: 256
f0_energy_predictor_kernel_sizes: 3
f0_energy_predictor_dropout_probs: 0.5
f0_kernel_size: 9
energy_kernel_size: 9
f0_dropout_rate: 0.5
energy_dropout_rate: 0.5

###########################################################
# DATA LOADER SETTING #
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_tts/models/fastspeech2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def __init__(self, config, **kwargs):
if config.n_speakers > 1:
self.decoder_speaker_embeddings = tf.keras.layers.Embedding(
config.n_speakers,
config.hidden_size,
config.encoder_self_attention_params.hidden_size,
embeddings_initializer=get_initializer(config.initializer_range),
name="speaker_embeddings",
)
self.speaker_fc = tf.keras.layers.Dense(
units=config.hidden_size, name="speaker_fc"
units=config.encoder_self_attention_params.hidden_size, name="speaker_fc"
)

self.config = config
Expand Down
75 changes: 75 additions & 0 deletions test/test_fastspeech2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import pytest
import tensorflow as tf

from tensorflow_tts.models import TFFastSpeech2
from tensorflow_tts.configs import FastSpeech2Config

# Clear CUDA_VISIBLE_DEVICES for this process; presumably this keeps the test
# on CPU so it can run on machines without a GPU -- TODO confirm intent.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Configure root logging at DEBUG level; each record shows the timestamp,
# originating module and line number, level name and message.
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


@pytest.mark.parametrize("num_hidden_layers,n_speakers", [(2, 1), (3, 2), (4, 3)])
def test_fastspeech_trainable(num_hidden_layers, n_speakers):
    """Smoke-test that TFFastSpeech2 can run optimisation steps.

    Builds a model from a ``FastSpeech2Config`` with the given encoder depth
    (the decoder gets one extra layer) and number of speakers, then runs two
    training steps on a single fake utterance of length 10. The test passes
    if both steps execute without raising; no loss value is asserted.

    Args:
        num_hidden_layers: Number of encoder hidden layers; the decoder uses
            ``num_hidden_layers + 1``.
        n_speakers: Number of speakers passed to the config (values above 1
            exercise the multi-speaker code path).
    """
    import time

    config = FastSpeech2Config(
        encoder_num_hidden_layers=num_hidden_layers,
        decoder_num_hidden_layers=num_hidden_layers + 1,
        n_speakers=n_speakers,
    )

    fastspeech2 = TFFastSpeech2(config, name="fastspeech")
    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and rejected by newer Keras optimizers.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    # Fake inputs: one utterance with ten tokens, each lasting one frame.
    input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
    attention_mask = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)
    speaker_ids = tf.convert_to_tensor([0], tf.int32)
    duration_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)
    f0_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.float32)
    energy_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.float32)

    mel_gts = tf.random.uniform(shape=[1, 10, 80], dtype=tf.float32)

    @tf.function
    def one_step_training():
        """Run one forward/backward pass and apply the gradients."""
        with tf.GradientTape() as tape:
            mel_outputs_before, _, duration_outputs, _, _ = fastspeech2(
                input_ids,
                attention_mask,
                speaker_ids,
                duration_gts,
                f0_gts,
                energy_gts,
                training=True,
            )
            duration_loss = tf.keras.losses.MeanSquaredError()(
                duration_gts, duration_outputs
            )
            mel_loss = tf.keras.losses.MeanSquaredError()(mel_gts, mel_outputs_before)
            loss = duration_loss + mel_loss
        gradients = tape.gradient(loss, fastspeech2.trainable_variables)
        optimizer.apply_gradients(zip(gradients, fastspeech2.trainable_variables))

        tf.print(loss)

    # Warm-up call (the first invocation of a tf.function includes tracing),
    # so only the second call is timed.
    one_step_training()

    start = time.time()
    one_step_training()
    print(time.time() - start)

0 comments on commit a9a48a7

Please sign in to comment.