Skip to content

Commit

Permalink
Add webui
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue committed Dec 20, 2023
1 parent 3a08434 commit 681567d
Show file tree
Hide file tree
Showing 2 changed files with 283 additions and 47 deletions.
10 changes: 10 additions & 0 deletions docs/zh/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,13 @@ python -m zibai tools.api_server:app --listen 127.0.0.1:8000
随后, 你可以在 `http://127.0.0.1:8000/docs` 中查看并测试 API.
一般来说, 你需要先调用 `PUT /v1/models/default` 来加载模型, 然后调用 `POST /v1/models/default/invoke` 来进行推理. 具体的参数请参考 API 文档.


## WebUI 推理
在运行 WebUI 之前, 你需要先启动 HTTP 服务, 如上所述.

随后你可以使用以下命令来启动 WebUI:
```bash
python fish_speech/webui/app.py
```

祝大家玩得开心!
320 changes: 273 additions & 47 deletions fish_speech/webui/app.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import html
import io
import traceback

import gradio as gr
import librosa
import requests

from fish_speech.text import parse_text_to_segments, segments_to_phones

HEADER_MD = """
# Fish Speech
基于 VITS 和 GPT 的多语种语音合成. 项目很大程度上基于 Rcell 的 GPT-VITS.
基于 VQ-GAN 和 Llama 的多语种语音合成. 感谢 Rcell 的 GPT-VITS 提供的思路.
"""

TEXTBOX_PLACEHOLDER = """在启用自动音素的情况下, 模型默认会全自动将输入文本转换为音素. 例如:
Expand Down Expand Up @@ -66,7 +69,7 @@ def prepare_text(
else:
reference_text = ""

if input_mode != "自动音素转换":
if input_mode != "自动音素":
return [
[idx, reference_text + line, "-", "-"]
for idx, line in enumerate(lines)
Expand All @@ -92,69 +95,272 @@ def prepare_text(
return rows, None


def load_model(
    server_url,
    llama_ckpt_path,
    llama_config_name,
    tokenizer,
    vqgan_ckpt_path,
    vqgan_config_name,
    device,
    precision,
    compile_model,
    timeout=600,
):
    """Ask the inference server to (re)load the default model.

    Issues ``PUT {server_url}/v1/models/default`` with the Llama and VQGAN
    configuration assembled from the WebUI fields.

    Args:
        server_url: Base URL of the HTTP API server (no trailing slash).
        llama_ckpt_path: Filesystem path (on the server) to the Llama checkpoint.
        llama_config_name: Hydra/config name for the Llama model.
        tokenizer: Tokenizer identifier (e.g. a HuggingFace repo id).
        vqgan_ckpt_path: Filesystem path (on the server) to the VQGAN checkpoint.
        vqgan_config_name: Config name for the VQGAN model.
        device: Target device, e.g. "cpu" or "cuda".
        precision: Compute precision, e.g. "bfloat16" or "float16".
        compile_model: Whether the server should torch-compile the model.
        timeout: Seconds to wait for the server before giving up. Model
            loading (and optional compilation) can be slow, hence the large
            default; without it a stalled server would hang the UI forever.

    Returns:
        A success message string, or an HTML-formatted error message
        (via ``build_html_error_message``) on failure. Either way the value
        is rendered into the WebUI's error HTML component.
    """
    payload = {
        "device": device,
        "llama": {
            "config_name": llama_config_name,
            "checkpoint_path": llama_ckpt_path,
            "precision": precision,
            "tokenizer": tokenizer,
            "compile": compile_model,
        },
        "vqgan": {
            "config_name": vqgan_config_name,
            "checkpoint_path": vqgan_ckpt_path,
        },
    }

    try:
        resp = requests.put(
            f"{server_url}/v1/models/default", json=payload, timeout=timeout
        )
        resp.raise_for_status()
    except Exception:
        # Log the full traceback server-side and surface it in the UI.
        traceback.print_exc()
        err = traceback.format_exc()
        return build_html_error_message(f"加载模型时发生错误. \n\n{err}")

    return "模型加载成功."


def build_model_config_block():
    """Build the "model configuration" tab of the WebUI.

    Creates the Gradio widgets for the server URL, device/precision/compile
    options, Llama and VQGAN checkpoint paths, and a "load model" button
    wired to :func:`load_model`. Component creation order defines the
    rendered layout, so statements here must not be reordered.

    Returns:
        The server-URL Textbox component, so callers (the inference form)
        can pass the chosen server address into their own event handlers.
    """
    server_url = gr.Textbox(label="服务器地址", value="http://localhost:8000")

    # Device / precision / compile options side by side in one row.
    with gr.Row():
        with gr.Column(scale=1):
            device = gr.Dropdown(
                label="设备",
                choices=["cpu", "cuda"],
                value="cuda",
            )
        with gr.Column(scale=1):
            precision = gr.Dropdown(
                label="精度",
                choices=["bfloat16", "float16"],
                value="float16",
            )
        with gr.Column(scale=1):
            compile_model = gr.Checkbox(
                label="编译模型",
                value=True,
            )

    # Checkpoint paths / config names are interpreted on the server side,
    # not locally — they must exist where the API server runs.
    llama_ckpt_path = gr.Textbox(
        label="Llama 模型路径", value="checkpoints/text2semantic-400m-v0.2-4k.pth"
    )
    llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
    tokenizer = gr.Textbox(label="Tokenizer", value="fishaudio/speech-lm-v1")

    vqgan_ckpt_path = gr.Textbox(label="VQGAN 模型路径", value="checkpoints/vqgan-v1.pth")
    vqgan_config_name = gr.Textbox(label="VQGAN 配置文件", value="vqgan_pretrain")

    load_model_btn = gr.Button(value="加载模型", variant="primary")
    error = gr.HTML(label="错误信息")

    # Wire the button to load_model; inputs are passed positionally in the
    # same order as load_model's parameters, the returned message lands in
    # the error HTML component.
    load_model_btn.click(
        load_model,
        [
            server_url,
            llama_ckpt_path,
            llama_config_name,
            tokenizer,
            vqgan_ckpt_path,
            vqgan_config_name,
            device,
            precision,
            compile_model,
        ],
        [error],
    )

    return server_url


def inference(
    server_url,
    text,
    input_mode,
    language0,
    language1,
    language2,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    top_k,
    top_p,
    repetition_penalty,
    temperature,
    speaker,
    timeout=600,
):
    """Run TTS inference against the HTTP API and decode the returned audio.

    Issues ``POST {server_url}/v1/models/default/invoke`` with the sampling
    parameters gathered from the WebUI, then decodes the returned audio
    bytes with librosa.

    Args:
        server_url: Base URL of the HTTP API server.
        text: Input text (or phonemes) to synthesize.
        input_mode: "自动音素" enables server-side g2p; anything else sends
            the text through untouched.
        language0/language1/language2: Language priority labels (中文/日文/英文);
            each may appear at most once.
        enable_reference_audio: When False, the reference audio/text fields
            are not sent at all (``None``).
        reference_audio: Reference audio value forwarded as ``prompt_tokens``.
            # NOTE(review): forwarded verbatim from the gr.Audio component —
            # presumably a file path the server can read; confirm against the
            # API server implementation.
        reference_text: Transcript of the reference audio.
        max_new_tokens: Generation cap; 0 is sent as-is (server treats it as
            "no limit" per the UI).
        top_k: Top-k sampling; values <= 0 disable it (sent as ``None``).
        top_p, repetition_penalty, temperature: Usual sampling knobs.
        speaker: Speaker name; blank/whitespace means "unspecified" (``None``).
        timeout: Seconds to wait for the server; keeps the UI from hanging
            forever on a stalled inference server.

    Returns:
        ``((sample_rate, samples), None)`` on success — a numpy-style pair
        for ``gr.Audio`` — or ``([], error_html)`` on failure.
    """
    label_to_code = {
        "中文": "zh",
        "日文": "jp",
        "英文": "en",
    }
    languages = [label_to_code[lang] for lang in (language0, language1, language2)]

    # Each language may appear at most once in the priority list.
    if len(set(languages)) != len(languages):
        return [], build_html_error_message("语言优先级不能重复.")

    payload = {
        "text": text,
        "prompt_text": reference_text if enable_reference_audio else None,
        "prompt_tokens": reference_audio if enable_reference_audio else None,
        "max_new_tokens": int(max_new_tokens),
        # top_k <= 0 means "disabled" on the server side.
        "top_k": int(top_k) if top_k > 0 else None,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "temperature": temperature,
        "order": ",".join(languages),
        "use_g2p": input_mode == "自动音素",
        "seed": None,
        "speaker": speaker if speaker.strip() != "" else None,
    }

    try:
        resp = requests.post(
            f"{server_url}/v1/models/default/invoke", json=payload, timeout=timeout
        )
        resp.raise_for_status()
    except Exception:
        traceback.print_exc()
        err = traceback.format_exc()
        return [], build_html_error_message(f"推理时发生错误. \n\n{err}")

    # The response body is an encoded audio file; decode it to a mono
    # waveform at its native sample rate (sr=None keeps the original rate).
    audio_bytes = io.BytesIO(resp.content)
    samples, sample_rate = librosa.load(audio_bytes, sr=None, mono=True)

    return (sample_rate, samples), None


with gr.Blocks(theme=gr.themes.Base()) as app:
gr.Markdown(HEADER_MD)

# Use light theme by default
app.load(
None,
None,
js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
)

# Inference
with gr.Row():
with gr.Column(scale=3):
text = gr.Textbox(label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=3)

with gr.Row():
with gr.Tab(label="合成参数"):
gr.Markdown("配置常见合成参数.")
with gr.Tab(label="模型配置"):
server_url = build_model_config_block()

with gr.Tab(label="推理配置"):
text = gr.Textbox(
label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
)

with gr.Row():
with gr.Tab(label="合成参数"):
gr.Markdown("配置常见合成参数. 自动音素会在推理时自动将文本转换为音素.")

input_mode = gr.Dropdown(
choices=["文本", "自动音素"],
value="文本",
label="输入模式",
)

input_mode = gr.Dropdown(
choices=["手动输入音素/文本", "自动音素转换"],
value="手动输入音素/文本",
label="输入模式",
)
max_new_tokens = gr.Slider(
label="最大生成 Token 数",
minimum=0,
maximum=4096,
value=0, # 0 means no limit
step=8,
)

with gr.Tab(label="语言优先级"):
gr.Markdown("该参数只在自动音素转换时生效.")
top_k = gr.Slider(
label="Top-K", minimum=0, maximum=100, value=0, step=1
)

with gr.Column(scale=1):
language0 = gr.Dropdown(
choices=["中文", "日文", "英文"],
label="语言 1",
value="中文",
top_p = gr.Slider(
label="Top-P", minimum=0, maximum=1, value=0.5, step=0.01
)

with gr.Column(scale=1):
language1 = gr.Dropdown(
choices=["中文", "日文", "英文"],
label="语言 2",
value="日文",
repetition_penalty = gr.Slider(
label="重复惩罚", minimum=0, maximum=2, value=1.5, step=0.01
)

with gr.Column(scale=1):
language2 = gr.Dropdown(
choices=["中文", "日文", "英文"],
label="语言 3",
value="英文",
temperature = gr.Slider(
label="温度", minimum=0, maximum=2, value=0.7, step=0.01
)

with gr.Tab(label="参考音频"):
gr.Markdown("3 秒左右的参考音频, 适用于无微调直接推理.")
speaker = gr.Textbox(
label="说话人",
placeholder="说话人",
lines=1,
)

enable_reference_audio = gr.Checkbox(label="启用参考音频", value=False)
reference_audio = gr.Audio(label="参考音频")
reference_text = gr.Textbox(
label="参考文本",
placeholder="参考文本",
lines=1,
value="万一他很崇拜我们呢? 嘿嘿.",
)
with gr.Tab(label="语言优先级"):
gr.Markdown("该参数只在自动音素转换时生效.")

with gr.Column(scale=1):
language0 = gr.Dropdown(
choices=["中文", "日文", "英文"],
label="语言 1",
value="中文",
)

with gr.Column(scale=1):
language1 = gr.Dropdown(
choices=["中文", "日文", "英文"],
label="语言 2",
value="日文",
)

with gr.Column(scale=1):
language2 = gr.Dropdown(
choices=["中文", "日文", "英文"],
label="语言 3",
value="英文",
)

with gr.Tab(label="参考音频"):
gr.Markdown("5-10 秒的参考音频, 适用于指定音色.")

enable_reference_audio = gr.Checkbox(
label="启用参考音频", value=False
)
reference_audio = gr.Audio(
label="参考音频",
value="docs/assets/audios/0_input.wav",
type="filepath",
)
reference_text = gr.Textbox(
label="参考文本",
placeholder="参考文本",
lines=1,
value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
)

with gr.Row():
with gr.Column(scale=2):
generate = gr.Button(value="合成", variant="primary")
with gr.Column(scale=1):
clear = gr.Button(value="清空")
with gr.Row():
with gr.Column(scale=2):
generate = gr.Button(value="合成", variant="primary")
with gr.Column(scale=1):
clear = gr.Button(value="清空")

with gr.Column(scale=3):
error = gr.HTML(label="错误信息")
parsed_text = gr.Dataframe(label="解析结果", headers=["ID", "文本", "语言", "音素"])
audio = gr.Audio(label="合成音频")
parsed_text = gr.Dataframe(
label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
)
audio = gr.Audio(label="合成音频", type="numpy")

# Language & Text Parsing
kwargs = dict(
Expand All @@ -178,7 +384,27 @@ def prepare_text(
enable_reference_audio.change(prepare_text, **kwargs)

# Submit
generate.click(lambda: None, outputs=[audio])
generate.click(
inference,
[
server_url,
text,
input_mode,
language0,
language1,
language2,
enable_reference_audio,
reference_audio,
reference_text,
max_new_tokens,
top_k,
top_p,
repetition_penalty,
temperature,
speaker,
],
[audio, error],
)


if __name__ == "__main__":
Expand Down

0 comments on commit 681567d

Please sign in to comment.