diff --git a/viettts/tts.py b/viettts/tts.py
index 13e2434..e4e6266 100644
--- a/viettts/tts.py
+++ b/viettts/tts.py
@@ -67,8 +67,8 @@ def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed
     def tts_to_wav(self, text, prompt_speech_16k, speed=1.0):
         wavs = []
         for output in self.inference_tts(text, prompt_speech_16k, stream=False, speed=speed):
-            wavs.append(output['tts_speech'])
-        return np.concatenate(wavs, axis=0).flatten()
+            wavs.append(output['tts_speech'].squeeze(0).numpy())
+        return np.concatenate(wavs, axis=0)
 
     def tts_to_file(self, text, prompt_speech_16k, speed, output_path):
         wav = self.tts_to_wav(text, prompt_speech_16k, speed)