Commit

Streaming support (#150)
* Fix button height

* Streaming support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Convert to 1 channel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
AnyaCoder and pre-commit-ci[bot] authored May 4, 2024
1 parent 7c7eced commit d89a0d4
Showing 7 changed files with 167 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -18,6 +18,7 @@ filelists
/data
/.idea
ffmpeg.exe
ffprobe.exe
asr-label-win-x64.exe
/.cache
/fishenv
3 changes: 3 additions & 0 deletions fish_speech/i18n/locale/en_US.json
@@ -74,6 +74,9 @@
"Speaker": "Speaker",
"Speaker is identified by the folder name": "Speaker is identified by the folder name",
"Start Training": "Start Training",
"Streaming": "Streaming",
"Streaming Audio": "Streaming Audio",
"Streaming Generate": "Streaming Generate",
"Tensorboard Host": "Tensorboard Host",
"Tensorboard Log Path": "Tensorboard Log Path",
"Tensorboard Port": "Tensorboard Port",
3 changes: 3 additions & 0 deletions fish_speech/i18n/locale/es_ES.json
@@ -74,6 +74,9 @@
"Speaker": "Hablante",
"Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
"Start Training": "Iniciar Entrenamiento",
"Streaming": "streaming",
"Streaming Audio": "transmisión de audio",
"Streaming Generate": "síntesis en flujo",
"Tensorboard Host": "Host de Tensorboard",
"Tensorboard Log Path": "Ruta de Registro de Tensorboard",
"Tensorboard Port": "Puerto de Tensorboard",
3 changes: 3 additions & 0 deletions fish_speech/i18n/locale/ja_JP.json
@@ -74,6 +74,9 @@
"Speaker": "話者",
"Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
"Start Training": "トレーニング開始",
"Streaming": "ストリーミング",
"Streaming Audio": "ストリーミングオーディオ",
"Streaming Generate": "ストリーミング合成",
"Tensorboard Host": "Tensorboardホスト",
"Tensorboard Log Path": "Tensorboardログパス",
"Tensorboard Port": "Tensorboardポート",
3 changes: 3 additions & 0 deletions fish_speech/i18n/locale/zh_CN.json
@@ -74,6 +74,9 @@
"Speaker": "说话人",
"Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
"Start Training": "开始训练",
"Streaming": "流式输出",
"Streaming Audio": "流式音频",
"Streaming Generate": "流式合成",
"Tensorboard Host": "Tensorboard 监听地址",
"Tensorboard Log Path": "Tensorboard 日志路径",
"Tensorboard Port": "Tensorboard 端口",
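The same three keys are added to every locale file, and the webui resolves them at render time. A minimal stand-in sketch of that lookup (the real helper is the i18n() function in fish_speech/i18n; this demo just reads one JSON table and falls back to the key itself when a translation is missing):

import json


def i18n_demo(key, locale_path="fish_speech/i18n/locale/ja_JP.json"):
    # Stand-in for fish_speech.i18n's i18n(): look the key up in the
    # active locale table, fall back to the English key if absent.
    with open(locale_path, encoding="utf-8") as f:
        table = json.load(f)
    return table.get(key, key)


print(i18n_demo("Streaming Generate"))  # -> "ストリーミング合成"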
16 changes: 16 additions & 0 deletions fish_speech/webui/manage.py
@@ -252,6 +252,17 @@ def show_selected(options):
        return i18n("No selected options")


from pydub import AudioSegment


def convert_to_mono_in_place(audio_path):
    audio = AudioSegment.from_file(audio_path)
    if audio.channels > 1:
        mono_audio = audio.set_channels(1)
        # Export in the clip's own container so a .wav stays a .wav on disk
        mono_audio.export(audio_path, format=Path(audio_path).suffix.lstrip("."))
        logger.info(f"Converted {audio_path} to mono")


def list_copy(list_file_path, method):
    wav_root = data_pre_output
    lst = []
@@ -266,6 +277,7 @@ def list_copy(list_file_path, method):
            if target_wav_path.is_file():
                continue
            target_wav_path.parent.mkdir(parents=True, exist_ok=True)
            convert_to_mono_in_place(original_wav_path)
            if method == i18n("Copy"):
                shutil.copy(original_wav_path, target_wav_path)
            else:
@@ -300,6 +312,10 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
        tar_path = data_path / item_path.name

        if content["type"] == "folder" and item_path.is_dir():
            for suf in ["wav", "flac", "mp3"]:
                for audio_path in item_path.glob(f"**/*.{suf}"):
                    convert_to_mono_in_place(audio_path)

            cur_lang = content["label_lang"]
            if cur_lang != "IGNORE":
                try:
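The hook matches the commit message "Convert to 1 channel": the streaming path below loads reference audio with mono=True, so multi-channel training clips are downmixed in place before being copied into the dataset. A hypothetical usage sketch (the dataset path is made up), mirroring the glob loop in check_files:

from pathlib import Path

# Downmix every clip under a (hypothetical) dataset folder to one channel
for suf in ("wav", "flac", "mp3"):
    for clip in Path("data/demo_speaker0").glob(f"**/*.{suf}"):
        convert_to_mono_in_place(clip)  # no-op for clips already in mono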
141 changes: 138 additions & 3 deletions tools/webui.py
@@ -1,12 +1,15 @@
import gc
import html
import io
import os
import queue
import wave
from argparse import ArgumentParser
from pathlib import Path

import gradio as gr
import librosa
import numpy as np
import pyrootutils
import torch
from loguru import logger
@@ -155,6 +158,113 @@ def inference(
    return (vqgan_model.sampling_rate, fake_audios), None


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)
    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes


@torch.inference_mode
def inference_stream(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    speaker,
):
    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
        yield (
            None,
            i18n("Text is too long, please keep it under {} characters.").format(
                args.max_gradio_length
            ),
        )
        return  # stop the generator instead of falling through to inference

    # Parse reference audio aka prompt
    prompt_tokens = None
    if enable_reference_audio and reference_audio is not None:
        # reference_audio_sr, reference_audio_content = reference_audio
        reference_audio_content, _ = librosa.load(
            reference_audio, sr=vqgan_model.sampling_rate, mono=True
        )
        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
            None, None, :
        ]

        logger.info(
            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
        )
        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]

    # LLAMA Inference
    request = dict(
        tokenizer=llama_tokenizer,
        device=vqgan_model.device,
        max_new_tokens=max_new_tokens,
        text=text,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=args.compile,
        iterative_prompt=chunk_length > 0,
        chunk_length=chunk_length,
        max_length=args.max_length,
        speaker=speaker if speaker else None,
        prompt_tokens=prompt_tokens if enable_reference_audio else None,
        prompt_text=reference_text if enable_reference_audio else None,
        is_streaming=True,
    )

    payload = dict(
        response_queue=queue.Queue(),
        request=request,
    )
    llama_queue.put(payload)

    yield wav_chunk_header(), None
    while True:
        result = payload["response_queue"].get()
        if result == "next":
            # TODO: handle next sentence
            continue

        if result == "done":
            if payload["success"] is False:
                yield None, build_html_error_message(payload["response"])
            break

        # VQGAN Inference
        feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
        fake_audios = vqgan_model.decode(
            indices=result[None], feature_lengths=feature_lengths, return_audios=True
        )[0, 0]
        fake_audios = fake_audios.float().cpu().numpy()
        yield (
            np.concatenate([fake_audios, np.zeros((11025,))], axis=0) * 32768
        ).astype(np.int16).tobytes(), None

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
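The header helper is the core of the streaming trick: the wave module writes a complete 44-byte PCM header for a zero-frame file, inference_stream yields it exactly once, and every later yield is raw 16-bit PCM. The declared RIFF sizes are therefore stale, but browsers keep decoding as bytes arrive, which is what the streaming gr.Audio output relies on. A small sketch of the same idea under those assumptions (the file name and test tone are made up):

import numpy as np

header = wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1)
assert len(header) == 44  # standard PCM WAV header, zero data frames declared

# One second of 440 Hz, scaled to 16-bit PCM much like inference_stream does
t = np.arange(44100) / 44100
tone = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)

with open("stream_demo.wav", "wb") as f:
    f.write(header)          # header first, exactly once
    f.write(tone.tobytes())  # then PCM chunks, appended as they are produced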


def build_app():
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)
@@ -243,13 +353,22 @@ def build_app():
                    error = gr.HTML(label=i18n("Error Message"))
                with gr.Row():
                    audio = gr.Audio(label=i18n("Generated Audio"), type="numpy")

                with gr.Row():
                    stream_audio = gr.Audio(
                        label=i18n("Streaming Audio"),
                        streaming=True,
                        autoplay=True,
                        interactive=False,
                    )
                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001F3A7 " + i18n("Generate"), variant="primary"
                        )

                        generate_stream = gr.Button(
                            value="\U0001F3A7 " + i18n("Streaming Generate"),
                            variant="primary",
                        )
        # # Submit
        generate.click(
            inference,
@@ -268,7 +387,23 @@ def build_app():
            [audio, error],
            concurrency_limit=1,
        )

        generate_stream.click(
            inference_stream,
            [
                text,
                enable_reference_audio,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                speaker,
            ],
            [stream_audio, error],
            concurrency_limit=10,
        )
    return app


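On the producer side, inference_stream and the LLAMA worker share a simple per-request contract: the request carries is_streaming=True plus a fresh response_queue, and the worker answers with semantic-token tensors chunk by chunk, the string "next" between text chunks, and "done" once it has set payload["success"]. A toy stand-in for that contract (the real worker lives elsewhere in the repo; the shapes and vocabulary size here are illustrative only):

import queue
import threading

import torch


def toy_llama_worker(jobs: queue.Queue):
    # Mimics the streaming protocol consumed by inference_stream above
    while True:
        payload = jobs.get()
        resp = payload["response_queue"]
        for _ in range(3):  # pretend the text was split into three chunks
            # fake VQ indices, shape (num_codebooks, frames)
            resp.put(torch.randint(0, 1024, (4, 21)))
            resp.put("next")
        payload["success"] = True
        resp.put("done")


jobs = queue.Queue()
threading.Thread(target=toy_llama_worker, args=(jobs,), daemon=True).start()

payload = dict(response_queue=queue.Queue(), request={"text": "hello"})
jobs.put(payload)
while (result := payload["response_queue"].get()) != "done":
    if result != "next":
        print("chunk with", result.shape[1], "frames")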
