From 681567df5808fef96e95bbacd3a9f4d10be935df Mon Sep 17 00:00:00 2001
From: Lengyue
Date: Wed, 20 Dec 2023 02:21:02 +0000
Subject: [PATCH] Add webui

---
 docs/zh/inference.md     |  10 ++
 fish_speech/webui/app.py | 320 +++++++++++++++++++++++++++++++++------
 2 files changed, 283 insertions(+), 47 deletions(-)

diff --git a/docs/zh/inference.md b/docs/zh/inference.md
index e6aac263..e767fe63 100644
--- a/docs/zh/inference.md
+++ b/docs/zh/inference.md
@@ -76,3 +76,13 @@ python -m zibai tools.api_server:app --listen 127.0.0.1:8000
 随后, 你可以在 `http://127.0.0.1:8000/docs` 中查看并测试 API.
 
 一般来说, 你需要先调用 `PUT /v1/models/default` 来加载模型, 然后调用 `POST /v1/models/default/invoke` 来进行推理. 具体的参数请参考 API 文档.
+
+## WebUI 推理
+在运行 WebUI 之前, 你需要先启动 HTTP 服务, 如上所述.
+
+随后你可以使用以下命令来启动 WebUI:
+```bash
+python fish_speech/webui/app.py
+```
+
+祝大家玩得开心!
diff --git a/fish_speech/webui/app.py b/fish_speech/webui/app.py
index 99fc858b..ee9e51f5 100644
--- a/fish_speech/webui/app.py
+++ b/fish_speech/webui/app.py
@@ -1,14 +1,17 @@
 import html
+import io
 import traceback
 
 import gradio as gr
+import librosa
+import requests
 
 from fish_speech.text import parse_text_to_segments, segments_to_phones
 
 HEADER_MD = """
 # Fish Speech
 
-基于 VITS 和 GPT 的多语种语音合成. 项目很大程度上基于 Rcell 的 GPT-VITS.
+基于 VQ-GAN 和 Llama 的多语种语音合成. 感谢 Rcell 的 GPT-VITS 提供的思路.
 """
 
 TEXTBOX_PLACEHOLDER = """在启用自动音素的情况下, 模型默认会全自动将输入文本转换为音素. 例如:
@@ -66,7 +69,7 @@ def prepare_text(
     else:
         reference_text = ""
 
-    if input_mode != "自动音素转换":
+    if input_mode != "自动音素":
         return [
             [idx, reference_text + line, "-", "-"]
             for idx, line in enumerate(lines)
@@ -92,69 +95,272 @@ def prepare_text(
     return rows, None
 
 
+def load_model(
+    server_url,
+    llama_ckpt_path,
+    llama_config_name,
+    tokenizer,
+    vqgan_ckpt_path,
+    vqgan_config_name,
+    device,
+    precision,
+    compile_model,
+):
+    payload = {
+        "device": device,
+        "llama": {
+            "config_name": llama_config_name,
+            "checkpoint_path": llama_ckpt_path,
+            "precision": precision,
+            "tokenizer": tokenizer,
+            "compile": compile_model,
+        },
+        "vqgan": {
+            "config_name": vqgan_config_name,
+            "checkpoint_path": vqgan_ckpt_path,
+        },
+    }
+
+    try:
+        resp = requests.put(f"{server_url}/v1/models/default", json=payload)
+        resp.raise_for_status()
+    except Exception:
+        traceback.print_exc()
+        err = traceback.format_exc()
+        return build_html_error_message(f"加载模型时发生错误. \n\n{err}")
+
+    return "模型加载成功."
+
+
+def build_model_config_block():
+    server_url = gr.Textbox(label="服务器地址", value="http://localhost:8000")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            device = gr.Dropdown(
+                label="设备",
+                choices=["cpu", "cuda"],
+                value="cuda",
+            )
+        with gr.Column(scale=1):
+            precision = gr.Dropdown(
+                label="精度",
+                choices=["bfloat16", "float16"],
+                value="float16",
+            )
+        with gr.Column(scale=1):
+            compile_model = gr.Checkbox(
+                label="编译模型",
+                value=True,
+            )
+
+    llama_ckpt_path = gr.Textbox(
+        label="Llama 模型路径", value="checkpoints/text2semantic-400m-v0.2-4k.pth"
+    )
+    llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
+    tokenizer = gr.Textbox(label="Tokenizer", value="fishaudio/speech-lm-v1")
+
+    vqgan_ckpt_path = gr.Textbox(label="VQGAN 模型路径", value="checkpoints/vqgan-v1.pth")
+    vqgan_config_name = gr.Textbox(label="VQGAN 配置文件", value="vqgan_pretrain")
+
+    load_model_btn = gr.Button(value="加载模型", variant="primary")
+    error = gr.HTML(label="错误信息")
+
+    load_model_btn.click(
+        load_model,
+        [
+            server_url,
+            llama_ckpt_path,
+            llama_config_name,
+            tokenizer,
+            vqgan_ckpt_path,
+            vqgan_config_name,
+            device,
+            precision,
+            compile_model,
+        ],
+        [error],
+    )
+
+    return server_url
+
+
+def inference(
+    server_url,
+    text,
+    input_mode,
+    language0,
+    language1,
+    language2,
+    enable_reference_audio,
+    reference_audio,
+    reference_text,
+    max_new_tokens,
+    top_k,
+    top_p,
+    repetition_penalty,
+    temperature,
+    speaker,
+):
+    languages = [language0, language1, language2]
+    languages = [
+        {
+            "中文": "zh",
+            "日文": "jp",
+            "英文": "en",
+        }[language]
+        for language in languages
+    ]
+
+    if len(set(languages)) != len(languages):
+        return [], build_html_error_message("语言优先级不能重复.")
+
+    order = ",".join(languages)
+    payload = {
+        "text": text,
+        "prompt_text": reference_text if enable_reference_audio else None,
+        "prompt_tokens": reference_audio if enable_reference_audio else None,
+        "max_new_tokens": int(max_new_tokens),
+        "top_k": int(top_k) if top_k > 0 else None,
+        "top_p": top_p,
+        "repetition_penalty": repetition_penalty,
+        "temperature": temperature,
+        "order": order,
+        "use_g2p": input_mode == "自动音素",
+        "seed": None,
+        "speaker": speaker if speaker.strip() != "" else None,
+    }
+
+    try:
+        resp = requests.post(f"{server_url}/v1/models/default/invoke", json=payload)
+        resp.raise_for_status()
+    except Exception:
+        traceback.print_exc()
+        err = traceback.format_exc()
+        return [], build_html_error_message(f"推理时发生错误. \n\n{err}")
+
+    content = io.BytesIO(resp.content)
+    content.seek(0)
+    content, sr = librosa.load(content, sr=None, mono=True)
+
+    return (sr, content), None
+
+
 with gr.Blocks(theme=gr.themes.Base()) as app:
     gr.Markdown(HEADER_MD)
 
+    # Use light theme by default
+    app.load(
+        None,
+        None,
+        js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
+    )
+
+    # Inference
     with gr.Row():
         with gr.Column(scale=3):
-            text = gr.Textbox(label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=3)
-
-            with gr.Row():
-                with gr.Tab(label="合成参数"):
-                    gr.Markdown("配置常见合成参数.")
+            with gr.Tab(label="模型配置"):
+                server_url = build_model_config_block()
+
+            with gr.Tab(label="推理配置"):
+                text = gr.Textbox(
+                    label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
+                )
+
+                with gr.Row():
+                    with gr.Tab(label="合成参数"):
+                        gr.Markdown("配置常见合成参数. 自动音素会在推理时自动将文本转换为音素.")
+
+                        input_mode = gr.Dropdown(
+                            choices=["文本", "自动音素"],
+                            value="文本",
+                            label="输入模式",
+                        )
 
-                    input_mode = gr.Dropdown(
-                        choices=["手动输入音素/文本", "自动音素转换"],
-                        value="手动输入音素/文本",
-                        label="输入模式",
-                    )
+                        max_new_tokens = gr.Slider(
+                            label="最大生成 Token 数",
+                            minimum=0,
+                            maximum=4096,
+                            value=0,  # 0 means no limit
+                            step=8,
+                        )
 
-                with gr.Tab(label="语言优先级"):
-                    gr.Markdown("该参数只在自动音素转换时生效.")
+                        top_k = gr.Slider(
+                            label="Top-K", minimum=0, maximum=100, value=0, step=1
+                        )
 
-                    with gr.Column(scale=1):
-                        language0 = gr.Dropdown(
-                            choices=["中文", "日文", "英文"],
-                            label="语言 1",
-                            value="中文",
+                        top_p = gr.Slider(
+                            label="Top-P", minimum=0, maximum=1, value=0.5, step=0.01
                         )
 
-                    with gr.Column(scale=1):
-                        language1 = gr.Dropdown(
-                            choices=["中文", "日文", "英文"],
-                            label="语言 2",
-                            value="日文",
+                        repetition_penalty = gr.Slider(
+                            label="重复惩罚", minimum=0, maximum=2, value=1.5, step=0.01
                         )
 
-                    with gr.Column(scale=1):
-                        language2 = gr.Dropdown(
-                            choices=["中文", "日文", "英文"],
-                            label="语言 3",
-                            value="英文",
+                        temperature = gr.Slider(
+                            label="温度", minimum=0, maximum=2, value=0.7, step=0.01
                         )
 
-                with gr.Tab(label="参考音频"):
-                    gr.Markdown("3 秒左右的参考音频, 适用于无微调直接推理.")
+                        speaker = gr.Textbox(
+                            label="说话人",
+                            placeholder="说话人",
+                            lines=1,
+                        )
 
-                    enable_reference_audio = gr.Checkbox(label="启用参考音频", value=False)
-                    reference_audio = gr.Audio(label="参考音频")
-                    reference_text = gr.Textbox(
-                        label="参考文本",
-                        placeholder="参考文本",
-                        lines=1,
-                        value="万一他很崇拜我们呢? 嘿嘿.",
-                    )
+                    with gr.Tab(label="语言优先级"):
+                        gr.Markdown("该参数只在自动音素转换时生效.")
+
+                        with gr.Column(scale=1):
+                            language0 = gr.Dropdown(
+                                choices=["中文", "日文", "英文"],
+                                label="语言 1",
+                                value="中文",
+                            )
+
+                        with gr.Column(scale=1):
+                            language1 = gr.Dropdown(
+                                choices=["中文", "日文", "英文"],
+                                label="语言 2",
+                                value="日文",
+                            )
+
+                        with gr.Column(scale=1):
+                            language2 = gr.Dropdown(
+                                choices=["中文", "日文", "英文"],
+                                label="语言 3",
+                                value="英文",
+                            )
+
+                    with gr.Tab(label="参考音频"):
+                        gr.Markdown("5-10 秒的参考音频, 适用于指定音色.")
+
+                        enable_reference_audio = gr.Checkbox(
+                            label="启用参考音频", value=False
+                        )
+                        reference_audio = gr.Audio(
+                            label="参考音频",
+                            value="docs/assets/audios/0_input.wav",
+                            type="filepath",
+                        )
+                        reference_text = gr.Textbox(
+                            label="参考文本",
+                            placeholder="参考文本",
+                            lines=1,
+                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                        )
 
-            with gr.Row():
-                with gr.Column(scale=2):
-                    generate = gr.Button(value="合成", variant="primary")
-                with gr.Column(scale=1):
-                    clear = gr.Button(value="清空")
+                with gr.Row():
+                    with gr.Column(scale=2):
+                        generate = gr.Button(value="合成", variant="primary")
+                    with gr.Column(scale=1):
+                        clear = gr.Button(value="清空")
 
         with gr.Column(scale=3):
             error = gr.HTML(label="错误信息")
-            parsed_text = gr.Dataframe(label="解析结果", headers=["ID", "文本", "语言", "音素"])
-            audio = gr.Audio(label="合成音频")
+            parsed_text = gr.Dataframe(
+                label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
+            )
+            audio = gr.Audio(label="合成音频", type="numpy")
 
     # Language & Text Parsing
     kwargs = dict(
@@ -178,7 +384,27 @@ def prepare_text(
     enable_reference_audio.change(prepare_text, **kwargs)
 
     # Submit
-    generate.click(lambda: None, outputs=[audio])
+    generate.click(
+        inference,
+        [
+            server_url,
+            text,
+            input_mode,
+            language0,
+            language1,
+            language2,
+            enable_reference_audio,
+            reference_audio,
+            reference_text,
+            max_new_tokens,
+            top_k,
+            top_p,
+            repetition_penalty,
+            temperature,
+            speaker,
+        ],
+        [audio, error],
+    )
 
 
 if __name__ == "__main__":
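
For scripting against the same HTTP API without the WebUI, below is a minimal sketch that mirrors the payloads sent by `load_model()` and `inference()` in the patch above. It assumes the API server described in `docs/zh/inference.md` is already listening on `http://localhost:8000`; the example text, the `output.wav` filename, and the checkpoint paths (copied from the WebUI defaults) are illustrative only, and the response body is assumed to be an audio file, since the WebUI decodes `resp.content` with librosa.

```python
# Sketch of driving the HTTP API that the WebUI talks to.
# Assumes the server from docs/zh/inference.md is running on localhost:8000;
# payload fields mirror load_model()/inference() in fish_speech/webui/app.py.
import requests

SERVER = "http://localhost:8000"

# 1. Load the models once (PUT /v1/models/default), as load_model() does.
load_payload = {
    "device": "cuda",
    "llama": {
        "config_name": "text2semantic_finetune",
        "checkpoint_path": "checkpoints/text2semantic-400m-v0.2-4k.pth",
        "precision": "float16",
        "tokenizer": "fishaudio/speech-lm-v1",
        "compile": True,
    },
    "vqgan": {
        "config_name": "vqgan_pretrain",
        "checkpoint_path": "checkpoints/vqgan-v1.pth",
    },
}
resp = requests.put(f"{SERVER}/v1/models/default", json=load_payload)
resp.raise_for_status()

# 2. Synthesize (POST /v1/models/default/invoke), as inference() does.
invoke_payload = {
    "text": "你好, 世界!",        # example input text
    "prompt_text": None,          # set together with prompt_tokens for reference audio
    "prompt_tokens": None,
    "max_new_tokens": 0,          # 0 means no limit, matching the WebUI default
    "top_k": None,
    "top_p": 0.5,
    "repetition_penalty": 1.5,
    "temperature": 0.7,
    "order": "zh,jp,en",          # language priority, comma-joined as in the WebUI
    "use_g2p": True,              # automatic phoneme conversion
    "seed": None,
    "speaker": None,
}
resp = requests.post(f"{SERVER}/v1/models/default/invoke", json=invoke_payload)
resp.raise_for_status()

# The WebUI decodes resp.content with librosa, i.e. the body is an audio file;
# here the raw bytes are simply written to disk.
with open("output.wav", "wb") as f:
    f.write(resp.content)
```

As with the WebUI, the `PUT /v1/models/default` call must succeed once before any `invoke` request, matching the order described in the docs section of this patch.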