-
Notifications
You must be signed in to change notification settings - Fork 699
/
Copy pathquick_start.py
122 lines (93 loc) · 4.6 KB
/
quick_start.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
from models.fatchord_version import WaveRNN
from utils import hparams as hp
from utils.text.symbols import symbols
from models.tacotron import Tacotron
import argparse
from utils.text import text_to_sequence
from utils.display import save_attention, simple_table
import zipfile, os
# Unpack the bundled pretrained checkpoints into the directories the models
# below load from. NOTE(review): this runs at import time, not only when the
# script is executed directly.
_PRETRAINED_ARCHIVES = (
    ('pretrained/ljspeech.wavernn.mol.800k.zip', 'quick_start/voc_weights/'),
    ('pretrained/ljspeech.tacotron.r2.180k.zip', 'quick_start/tts_weights/'),
)
for _archive, _dest in _PRETRAINED_ARCHIVES:
    os.makedirs(_dest, exist_ok=True)
    # The context manager closes the archive even if extraction fails part-way.
    with zipfile.ZipFile(_archive, 'r') as zip_ref:
        zip_ref.extractall(_dest)
# Batched-generation window: the target and overlap sample counts passed to
# ``voc_model.generate``. They are also echoed in the summary table, and are
# defined once so the table always matches what actually ran.
TARGET_SAMPLES = 11_000
OVERLAP_SAMPLES = 550

# Characters that cannot safely appear in a file name: path separators plus
# the set reserved on Windows.
_UNSAFE_FILENAME_CHARS = frozenset('\\/:*?"<>|')


def _parse_args():
    """Parse the command line and return the resulting namespace.

    ``--batched`` and ``--unbatched`` both write the ``batched`` destination;
    when neither flag is given, batched generation is used.
    """
    parser = argparse.ArgumentParser(description='TTS Generator')
    parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
    parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation (lower quality)')
    parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slower Unbatched Generation (better quality)')
    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py',
                        help='The file to use for the hyperparameters')
    # Defaults only take effect when registered BEFORE parse_args(). If they
    # are set afterwards, the first ``batched`` action (store_true) supplies
    # its own default of False and generation silently runs unbatched.
    parser.set_defaults(batched=True, input_text=None)
    return parser.parse_args()


def _load_vocoder(device):
    """Build the MOL-mode WaveRNN from hparams on ``device`` and load weights."""
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode='MOL').to(device)
    voc_model.load('quick_start/voc_weights/latest_weights.pyt')
    return voc_model


def _load_tacotron(device):
    """Build the Tacotron model from hparams on ``device`` and load weights."""
    tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
                         num_chars=len(symbols),
                         encoder_dims=hp.tts_encoder_dims,
                         decoder_dims=hp.tts_decoder_dims,
                         n_mels=hp.num_mels,
                         # NOTE(review): fft_bins is deliberately num_mels here;
                         # presumably the pretrained checkpoint's layer shapes
                         # depend on it — confirm before changing.
                         fft_bins=hp.num_mels,
                         postnet_dims=hp.tts_postnet_dims,
                         encoder_K=hp.tts_encoder_K,
                         lstm_dims=hp.tts_lstm_dims,
                         postnet_K=hp.tts_postnet_K,
                         num_highways=hp.tts_num_highways,
                         dropout=hp.tts_dropout,
                         stop_threshold=hp.tts_stop_threshold).to(device)
    tts_model.load('quick_start/tts_weights/latest_weights.pyt')
    return tts_model


def _load_inputs(input_text):
    """Return the list of token sequences to synthesise.

    Uses ``input_text`` when it is non-empty. Otherwise reads
    ``sentences.txt`` (one utterance per line) and skips blank lines, which
    would otherwise be synthesised as empty utterances.
    """
    if input_text:
        return [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
    with open('sentences.txt', encoding='utf-8') as f:
        return [text_to_sequence(line.strip(), hp.tts_cleaner_names)
                for line in f if line.strip()]


def _output_path(index, input_text, batched, tts_k):
    """Return the .wav path for the ``index``-th (1-based) utterance."""
    if input_text:
        # The first 10 characters of the prompt become part of the file name.
        # Replace separators, reserved and non-printable characters so the
        # path cannot point into a missing directory or be rejected.
        stem = ''.join(
            ch if ch.isprintable() and ch not in _UNSAFE_FILENAME_CHARS else '_'
            for ch in input_text[:10]
        )
        return f'quick_start/__input_{stem}_{tts_k}k.wav'
    return f'quick_start/{index}_batched{str(batched)}_{tts_k}k.wav'


def main():
    """Synthesise speech for ``--input_text`` or every line of sentences.txt."""
    args = _parse_args()
    hp.configure(args.hp_file)  # Load hparams from file

    batched = args.batched
    input_text = args.input_text

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising WaveRNN Model...\n')
    voc_model = _load_vocoder(device)

    print('\nInitialising Tacotron Model...\n')
    tts_model = _load_tacotron(device)

    inputs = _load_inputs(input_text)

    voc_k = voc_model.get_step() // 1000
    tts_k = tts_model.get_step() // 1000
    r = tts_model.r

    simple_table([('WaveRNN', str(voc_k) + 'k'),
                  (f'Tacotron(r={r})', str(tts_k) + 'k'),
                  ('Generation Mode', 'Batched' if batched else 'Unbatched'),
                  ('Target Samples', TARGET_SAMPLES if batched else 'N/A'),
                  ('Overlap Samples', OVERLAP_SAMPLES if batched else 'N/A')])

    for i, x in enumerate(inputs, 1):
        print(f'\n| Generating {i}/{len(inputs)}')
        _, m, attention = tts_model.generate(x)
        save_path = _output_path(i, input_text, batched, tts_k)
        # Optional: also plot the alignment next to the audio file.
        # save_attention(attention, save_path)
        # Add a leading (presumably batch) dimension before vocoding.
        m = torch.tensor(m).unsqueeze(0)
        # NOTE(review): presumably rescales mels from [-4, 4] to [0, 1], the
        # range the vocoder was trained on — confirm against preprocessing.
        m = (m + 4) / 8
        voc_model.generate(m, save_path, batched, TARGET_SAMPLES,
                           OVERLAP_SAMPLES, hp.mu_law)

    print('\n\nDone.\n')


if __name__ == "__main__":
    main()