From 62eae262c20d5a83db5c825383993d548f82b813 Mon Sep 17 00:00:00 2001 From: Picus303 <135560346+Picus303@users.noreply.github.com> Date: Sat, 7 Dec 2024 07:13:19 +0100 Subject: [PATCH] Make WebUI and API code cleaner (+ 1.5 fixes) (#703) * rename webui.py to run_webui.py * remove unused imports * remove unused code * move inference code and fix all warnings * move web app code * make code easier to read * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove unused function * remove msgpack_api.py * rename API files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * finish updating the doc with the new file names * finish updating the doc with the new file names * fix CPU use in the API * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor WebUI inference into a class with submodules * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * re-enable streaming in webui inference code * generalize inference code in webui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix * make a unique inference engine class * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix * cleaning code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement new structure of the API (not working) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor API * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * reimplement chat endpoint * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/ISSUE_TEMPLATE/bug_report.yml | 2 +- docs/en/index.md | 2 +- docs/en/inference.md | 10 +- docs/en/start_agent.md | 2 +- docs/ja/index.md | 2 +- docs/ja/inference.md | 8 +- docs/ja/start_agent.md | 2 +- docs/ko/index.md | 2 +- docs/ko/inference.md | 10 +- docs/ko/start_agent.md | 2 +- docs/pt/index.md | 2 +- docs/pt/inference.md | 8 +- docs/pt/start_agent.md | 2 +- docs/zh/index.md | 2 +- docs/zh/inference.md | 10 +- docs/zh/start_agent.md | 2 +- entrypoint.sh | 2 +- fish_speech/webui/manage.py | 2 +- inference.ipynb | 2 +- start.bat | 2 +- tools/api.py | 951 --------------------- tools/{post_api.py => api_client.py} | 18 +- tools/api_server.py | 98 +++ tools/fish_e2e.py | 4 +- tools/inference_engine/__init__.py | 193 +++++ tools/inference_engine/reference_loader.py | 128 +++ tools/inference_engine/utils.py | 42 + tools/inference_engine/vq_manager.py | 57 ++ tools/msgpack_api.py | 95 -- tools/run_webui.py | 101 +++ tools/schema.py | 38 +- tools/server/agent/__init__.py | 57 ++ tools/server/agent/generate.py | 119 +++ tools/server/agent/generation_utils.py | 122 +++ tools/server/agent/pre_generation_utils.py | 72 ++ tools/server/api_utils.py | 75 ++ tools/server/exception_handler.py | 27 + tools/server/inference.py | 41 + tools/server/model_manager.py | 119 +++ tools/server/model_utils.py | 129 +++ tools/server/views.py | 246 ++++++ tools/webui.py | 570
------------ tools/webui/__init__.py | 173 ++++ tools/webui/inference.py | 91 ++ tools/webui/variables.py | 14 + 45 files changed, 1959 insertions(+), 1697 deletions(-) delete mode 100644 tools/api.py rename tools/{post_api.py => api_client.py} (91%) create mode 100644 tools/api_server.py create mode 100644 tools/inference_engine/__init__.py create mode 100644 tools/inference_engine/reference_loader.py create mode 100644 tools/inference_engine/utils.py create mode 100644 tools/inference_engine/vq_manager.py delete mode 100644 tools/msgpack_api.py create mode 100644 tools/run_webui.py create mode 100644 tools/server/agent/__init__.py create mode 100644 tools/server/agent/generate.py create mode 100644 tools/server/agent/generation_utils.py create mode 100644 tools/server/agent/pre_generation_utils.py create mode 100644 tools/server/api_utils.py create mode 100644 tools/server/exception_handler.py create mode 100644 tools/server/inference.py create mode 100644 tools/server/model_manager.py create mode 100644 tools/server/model_utils.py create mode 100644 tools/server/views.py delete mode 100644 tools/webui.py create mode 100644 tools/webui/__init__.py create mode 100644 tools/webui/inference.py create mode 100644 tools/webui/variables.py diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 8dc4ea10..c557f1ff 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -45,7 +45,7 @@ body: description: | Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks. placeholder: | - 1. Run the command `python -m tools.post_api -t "xxxxx"` + 1. Run the command `python -m tools.api_client -t "xxxxx"` 2. Observe the console output error: `ModuleNotFoundError: No module named 'pyaudio'` (with screenshots or logs will be better) validations: required: true diff --git a/docs/en/index.md b/docs/en/index.md index 667e7fc0..5e5308f5 100644 --- a/docs/en/index.md +++ b/docs/en/index.md @@ -185,7 +185,7 @@ pip install -e .[stable] 4. Configure environment variables and access WebUI In the terminal inside the docker container, enter `export GRADIO_SERVER_NAME="0.0.0.0"` to allow external access to the gradio service inside docker. - Then in the terminal inside the docker container, enter `python tools/webui.py` to start the WebUI service. + Then in the terminal inside the docker container, enter `python tools/run_webui.py` to start the WebUI service. If you're using WSL or MacOS, visit [http://localhost:7860](http://localhost:7860) to open the WebUI interface. diff --git a/docs/en/inference.md b/docs/en/inference.md index 91f5ce4c..d3e05563 100644 --- a/docs/en/inference.md +++ b/docs/en/inference.md @@ -67,7 +67,7 @@ python tools/vqgan/inference.py \ We provide a HTTP API for inference. You can use the following command to start the server: ```bash -python -m tools.api \ +python -m tools.api_server \ --listen 0.0.0.0:8080 \ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ @@ -78,10 +78,10 @@ python -m tools.api \ After that, you can view and test the API at http://127.0.0.1:8080/. -Below is an example of sending a request using `tools/post_api.py`. +Below is an example of sending a request using `tools/api_client.py`. 
```bash -python -m tools.post_api \ +python -m tools.api_client \ --text "Text to be input" \ --reference_audio "Path to reference audio" \ --reference_text "Text content of the reference audio" \ @@ -93,7 +93,7 @@ The above command indicates synthesizing the desired audio according to the refe The following example demonstrates that you can use **multiple** reference audio paths and reference audio texts at once. Separate them with spaces in the command. ```bash -python -m tools.post_api \ +python -m tools.api_client \ --text "Text to input" \ --reference_audio "reference audio path1" "reference audio path2" \ --reference_text "reference audio text1" "reference audio text2"\ @@ -109,7 +109,7 @@ The currently supported reference audio has a maximum total duration of 90 secon !!! info - To learn more about available parameters, you can use the command `python -m tools.post_api -h` + To learn more about available parameters, you can use the command `python -m tools.api_client -h` ## GUI Inference [Download client](https://github.com/AnyaCoder/fish-speech-gui/releases) diff --git a/docs/en/start_agent.md b/docs/en/start_agent.md index d2524904..b0e32a30 100644 --- a/docs/en/start_agent.md +++ b/docs/en/start_agent.md @@ -44,7 +44,7 @@ pip install -e .[stable] To build fish-agent, please use the command below under the main folder: ```bash -python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile +python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile ``` The `--compile` args only support Python < 3.12 , which will greatly speed up the token generation. diff --git a/docs/ja/index.md b/docs/ja/index.md index 5f81c5fc..a7889055 100644 --- a/docs/ja/index.md +++ b/docs/ja/index.md @@ -184,7 +184,7 @@ pip install -e .[stable] 4. 環境変数の設定と WebUI へのアクセス Docker コンテナ内のターミナルで、`export GRADIO_SERVER_NAME="0.0.0.0"` と入力して、外部から Docker 内の gradio サービスにアクセスできるようにします。 - 次に、Docker コンテナ内のターミナルで `python tools/webui.py` と入力して WebUI サービスを起動します。 + 次に、Docker コンテナ内のターミナルで `python tools/run_webui.py` と入力して WebUI サービスを起動します。 WSL または MacOS の場合は、[http://localhost:7860](http://localhost:7860) にアクセスして WebUI インターフェースを開くことができます。 diff --git a/docs/ja/inference.md b/docs/ja/inference.md index 29476d7c..ed558c9d 100644 --- a/docs/ja/inference.md +++ b/docs/ja/inference.md @@ -67,7 +67,7 @@ python tools/vqgan/inference.py \ 推論のための HTTP API を提供しています。次のコマンドを使用してサーバーを起動できます: ```bash -python -m tools.api \ +python -m tools.api_server \ --listen 0.0.0.0:8080 \ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ @@ -78,10 +78,10 @@ python -m tools.api \ その後、`http://127.0.0.1:8080/`で API を表示およびテストできます。 -以下は、`tools/post_api.py` を使用してリクエストを送信する例です。 +以下は、`tools/api_client.py` を使用してリクエストを送信する例です。 ```bash -python -m tools.post_api \ +python -m tools.api_client \ --text "入力するテキスト" \ --reference_audio "参照音声へのパス" \ --reference_text "参照音声テキスト" \ @@ -91,7 +91,7 @@ python -m tools.post_api \ 上記のコマンドは、参照音声の情報に基づいて必要な音声を合成し、ストリーミング方式で返すことを示しています。 !!! 
info - 使用可能なパラメータの詳細については、コマンド` python -m tools.post_api -h `を使用してください + 使用可能なパラメータの詳細については、コマンド` python -m tools.api_client -h `を使用してください ## WebUI 推論 diff --git a/docs/ja/start_agent.md b/docs/ja/start_agent.md index 10cabed6..79b1b2c9 100644 --- a/docs/ja/start_agent.md +++ b/docs/ja/start_agent.md @@ -47,7 +47,7 @@ pip install -e .[stable] fish-agentを構築するには、メインフォルダで以下のコマンドを使用してください: ```bash -python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile +python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile ``` `--compile`引数はPython < 3.12でのみサポートされており、トークン生成を大幅に高速化します。 diff --git a/docs/ko/index.md b/docs/ko/index.md index d2d8dd91..f65974f8 100644 --- a/docs/ko/index.md +++ b/docs/ko/index.md @@ -185,7 +185,7 @@ pip install -e .[stable] 4. 환경 변수 설정 및 WebUI 접근 Docker 컨테이너 내부의 터미널에서 `export GRADIO_SERVER_NAME="0.0.0.0"`를 입력하여 Docker 내부에서 Gradio 서비스에 외부 접근을 허용합니다. - 이후, 터미널에서 `python tools/webui.py` 명령어를 입력하여 WebUI 서비스를 시작합니다. + 이후, 터미널에서 `python tools/run_webui.py` 명령어를 입력하여 WebUI 서비스를 시작합니다. WSL 또는 macOS를 사용하는 경우 [http://localhost:7860](http://localhost:7860)에서 WebUI 인터페이스를 열 수 있습니다. diff --git a/docs/ko/inference.md b/docs/ko/inference.md index 201f8a89..9e639c23 100644 --- a/docs/ko/inference.md +++ b/docs/ko/inference.md @@ -67,7 +67,7 @@ python tools/vqgan/inference.py \ 추론을 위한 HTTP API를 제공하고 있습니다. 아래의 명령어로 서버를 시작할 수 있습니다: ```bash -python -m tools.api \ +python -m tools.api_server \ --listen 0.0.0.0:8080 \ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ @@ -78,10 +78,10 @@ python -m tools.api \ 이후, http://127.0.0.1:8080/ 에서 API를 확인하고 테스트할 수 있습니다. -아래는 `tools/post_api.py`를 사용하여 요청을 보내는 예시입니다. +아래는 `tools/api_client.py`를 사용하여 요청을 보내는 예시입니다. ```bash -python -m tools.post_api \ +python -m tools.api_client \ --text "입력할 텍스트" \ --reference_audio "참고 음성 경로" \ --reference_text "참고 음성의 텍스트 내용" \ @@ -93,7 +93,7 @@ python -m tools.post_api \ 다음 예시는 여러 개의 참고 음성 경로와 텍스트를 한꺼번에 사용할 수 있음을 보여줍니다. 명령에서 공백으로 구분하여 입력합니다. ```bash -python -m tools.post_api \ +python -m tools.api_client \ --text "입력할 텍스트" \ --reference_audio "참고 음성 경로1" "참고 음성 경로2" \ --reference_text "참고 음성 텍스트1" "참고 음성 텍스트2"\ @@ -107,7 +107,7 @@ python -m tools.post_api \ `--reference_audio`와 `--reference_text` 대신에 `--reference_id`(하나만 사용 가능)를 사용할 수 있습니다. 프로젝트 루트 디렉토리에 `references/` 폴더를 만들어 해당 음성과 주석 텍스트를 넣어야 합니다. 참고 음성은 최대 90초까지 지원됩니다. !!! info - 제공되는 파라미터는 `python -m tools.post_api -h`를 사용하여 확인할 수 있습니다. + 제공되는 파라미터는 `python -m tools.api_client -h`를 사용하여 확인할 수 있습니다. ## GUI 추론 [클라이언트 다운로드](https://github.com/AnyaCoder/fish-speech-gui/releases) diff --git a/docs/ko/start_agent.md b/docs/ko/start_agent.md index dedc7698..c4d085de 100644 --- a/docs/ko/start_agent.md +++ b/docs/ko/start_agent.md @@ -47,7 +47,7 @@ pip install -e .[stable] fish-agent를 구축하려면 메인 폴더에서 아래 명령어를 사용하세요: ```bash -python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile +python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile ``` `--compile` 인자는 Python < 3.12에서만 지원되며, 토큰 생성 속도를 크게 향상시킵니다. diff --git a/docs/pt/index.md b/docs/pt/index.md index 60cab972..46fbc376 100644 --- a/docs/pt/index.md +++ b/docs/pt/index.md @@ -181,7 +181,7 @@ pip install -e .[stable] 4. 
Configure as variáveis de ambiente e acesse a WebUI No terminal do contêiner Docker, digite `export GRADIO_SERVER_NAME="0.0.0.0"` para permitir o acesso externo ao serviço gradio dentro do Docker. - Em seguida, no terminal do contêiner Docker, digite `python tools/webui.py` para iniciar o serviço WebUI. + Em seguida, no terminal do contêiner Docker, digite `python tools/run_webui.py` para iniciar o serviço WebUI. Se estiver usando WSL ou MacOS, acesse [http://localhost:7860](http://localhost:7860) para abrir a interface WebUI. diff --git a/docs/pt/inference.md b/docs/pt/inference.md index 5223b625..e5d2e802 100644 --- a/docs/pt/inference.md +++ b/docs/pt/inference.md @@ -67,7 +67,7 @@ python tools/vqgan/inference.py \ Fornecemos uma API HTTP para inferência. O seguinte comando pode ser usado para iniciar o servidor: ```bash -python -m tools.api \ +python -m tools.api_server \ --listen 0.0.0.0:8080 \ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ @@ -78,10 +78,10 @@ python -m tools.api \ Depois disso, é possível visualizar e testar a API em http://127.0.0.1:8080/. -Abaixo está um exemplo de envio de uma solicitação usando `tools/post_api.py`. +Abaixo está um exemplo de envio de uma solicitação usando `tools/api_client.py`. ```bash -python -m tools.post_api \ +python -m tools.api_client \ --text "Texto a ser inserido" \ --reference_audio "Caminho para o áudio de referência" \ --reference_text "Conteúdo de texto do áudio de referência" \ @@ -91,7 +91,7 @@ python -m tools.post_api \ O comando acima indica a síntese do áudio desejada de acordo com as informações do áudio de referência e a retorna em modo de streaming. !!! info - Para aprender mais sobre parâmetros disponíveis, você pode usar o comando `python -m tools.post_api -h` + Para aprender mais sobre parâmetros disponíveis, você pode usar o comando `python -m tools.api_client -h` ## Inferência por WebUI diff --git a/docs/pt/start_agent.md b/docs/pt/start_agent.md index a17321bc..da6eed54 100644 --- a/docs/pt/start_agent.md +++ b/docs/pt/start_agent.md @@ -47,7 +47,7 @@ pip install -e .[stable] Para construir o fish-agent, use o comando abaixo na pasta principal: ```bash -python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile +python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile ``` O argumento `--compile` só suporta Python < 3.12, o que aumentará muito a velocidade de geração de tokens. diff --git a/docs/zh/index.md b/docs/zh/index.md index d12835dc..830258bc 100644 --- a/docs/zh/index.md +++ b/docs/zh/index.md @@ -188,7 +188,7 @@ pip install -e .[stable] 4. 
配置环境变量,访问 WebUI 在 docker 容器内的终端,输入 `export GRADIO_SERVER_NAME="0.0.0.0"` ,从而让外部可以访问 docker 内的 gradio 服务。 - 接着在 docker 容器内的终端,输入 `python tools/webui.py` 即可开启 WebUI 服务。 + 接着在 docker 容器内的终端,输入 `python tools/run_webui.py` 即可开启 WebUI 服务。 如果是 WSL 或者是 MacOS ,访问 [http://localhost:7860](http://localhost:7860) 即可打开 WebUI 界面。 diff --git a/docs/zh/inference.md b/docs/zh/inference.md index 6426d218..4106b9bc 100644 --- a/docs/zh/inference.md +++ b/docs/zh/inference.md @@ -73,7 +73,7 @@ python tools/vqgan/inference.py \ 运行以下命令来启动 HTTP 服务: ```bash -python -m tools.api \ +python -m tools.api_server \ --listen 0.0.0.0:8080 \ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ @@ -88,10 +88,10 @@ HF_ENDPOINT=https://hf-mirror.com python -m ...(同上) 随后, 你可以在 `http://127.0.0.1:8080/` 中查看并测试 API. -下面是使用`tools/post_api.py`发送请求的示例。 +下面是使用`tools/api_client.py`发送请求的示例。 ```bash -python -m tools.post_api \ +python -m tools.api_client \ --text "要输入的文本" \ --reference_audio "参考音频路径" \ --reference_text "参考音频的文本内容" \ @@ -102,7 +102,7 @@ python -m tools.post_api \ 下面的示例展示了, 可以一次使用**多个** `参考音频路径` 和 `参考音频的文本内容`。在命令里用空格隔开即可。 ```bash -python -m tools.post_api \ +python -m tools.api_client \ --text "要输入的文本" \ --reference_audio "参考音频路径1" "参考音频路径2" \ --reference_text "参考音频的文本内容1" "参考音频的文本内容2"\ @@ -117,7 +117,7 @@ python -m tools.post_api \ 里面放上任意对音频与标注文本。 目前支持的参考音频最多加起来总时长90s。 !!! info - 要了解有关可用参数的更多信息,可以使用命令`python -m tools.post_api -h` + 要了解有关可用参数的更多信息,可以使用命令`python -m tools.api_client -h` ## GUI 推理 [下载客户端](https://github.com/AnyaCoder/fish-speech-gui/releases) diff --git a/docs/zh/start_agent.md b/docs/zh/start_agent.md index 799cfadc..c93b9068 100644 --- a/docs/zh/start_agent.md +++ b/docs/zh/start_agent.md @@ -49,7 +49,7 @@ pip install -e .[stable] 你需要使用以下指令来构建 fish-agent ```bash -python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile +python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile ``` `--compile`只能在小于 3.12 版本的 Python 使用,这个功能可以极大程度上提高生成速度。 diff --git a/entrypoint.sh b/entrypoint.sh index d9e93142..eb4564e0 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -7,4 +7,4 @@ if [ "${CUDA_ENABLED}" != "true" ]; then DEVICE="--device cpu" fi -exec python tools/webui.py ${DEVICE} +exec python tools/run_webui.py ${DEVICE} diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py index c21233ee..1d52a35e 100644 --- a/fish_speech/webui/manage.py +++ b/fish_speech/webui/manage.py @@ -176,7 +176,7 @@ def change_infer( p_infer = subprocess.Popen( [ PYTHON, - "tools/webui.py", + "tools/run_webui.py", "--decoder-checkpoint-path", infer_decoder_model, "--decoder-config-name", diff --git a/inference.ipynb b/inference.ipynb index e690a80d..3bd94ebe 100644 --- a/inference.ipynb +++ b/inference.ipynb @@ -83,7 +83,7 @@ }, "outputs": [], "source": [ - "!python tools/webui.py \\\n", + "!python tools/run_webui.py \\\n", " --llama-checkpoint-path checkpoints/fish-speech-1.4 \\\n", " --decoder-checkpoint-path checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n", " # --compile" diff --git a/start.bat b/start.bat index 40c7f4d3..c4e27014 100644 --- a/start.bat +++ b/start.bat @@ -82,7 +82,7 @@ if not "!flags!"=="" set "flags=!flags:~1!" echo Debug: flags = !flags! if "!mode!"=="api" ( - %PYTHON_CMD% -m tools.api !flags! + %PYTHON_CMD% -m tools.api_server !flags! 
) else if "!mode!"=="infer" ( %PYTHON_CMD% -m tools.webui !flags! ) diff --git a/tools/api.py b/tools/api.py deleted file mode 100644 index 4fae53ce..00000000 --- a/tools/api.py +++ /dev/null @@ -1,951 +0,0 @@ -import io -import json -import os -import queue -import re -import time -import traceback -import wave -from argparse import ArgumentParser -from http import HTTPStatus -from pathlib import Path -from typing import Annotated, Any - -import librosa -import numpy as np -import ormsgpack -import pyrootutils -import soundfile as sf -import torch -import torchaudio -from baize.datastructures import ContentType -from kui.asgi import ( - Body, - FactoryClass, - HTTPException, - HttpRequest, - HttpView, - JSONResponse, - Kui, - OpenAPI, - StreamResponse, - request, -) -from kui.asgi.routing import MultimethodRoutes -from loguru import logger - -pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -import struct -from threading import Lock - -import httpx -from cachetools import LRUCache, cached -from funasr import AutoModel -from silero_vad import get_speech_timestamps, load_silero_vad - -from fish_speech.models.text2semantic.llama import BaseModelArgs - -# from fish_speech.models.vqgan.lit_module import VQGAN -from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture -from fish_speech.text.chn_text_norm.text import Text as ChnNormedText - -# from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN -from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer -from fish_speech.utils import autocast_exclude_mps, set_seed -from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text -from tools.llama.generate import ( - GenerateRequest, - GenerateResponse, - WrappedGenerateResponse, - launch_thread_safe_queue, - launch_thread_safe_queue_agent, -) -from tools.schema import ( - GLOBAL_NUM_SAMPLES, - ASRPackRequest, - ServeASRRequest, - ServeASRResponse, - ServeASRSegment, - ServeAudioPart, - ServeForwardMessage, - ServeMessage, - ServeRequest, - ServeResponse, - ServeStreamDelta, - ServeStreamResponse, - ServeTextPart, - ServeTimedASRResponse, - ServeTTSRequest, - ServeVQGANDecodeRequest, - ServeVQGANDecodeResponse, - ServeVQGANEncodeRequest, - ServeVQGANEncodeResponse, - ServeVQPart, -) -from tools.vqgan.inference import load_model as load_decoder_model - -global_lock = Lock() - -# Whether to disable keepalive (which is helpful if the server is in the same cluster) -DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true" -async_client = httpx.AsyncClient( - timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None) -) -backends = torchaudio.list_audio_backends() - -if "ffmpeg" in backends: - backend = "ffmpeg" -else: - backend = "soundfile" - - -def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): - buffer = io.BytesIO() - - with wave.open(buffer, "wb") as wav_file: - wav_file.setnchannels(channels) - wav_file.setsampwidth(bit_depth // 8) - wav_file.setframerate(sample_rate) - - wav_header_bytes = buffer.getvalue() - buffer.close() - return wav_header_bytes - - -# Define utils for web server -async def http_execption_handler(exc: HTTPException): - return JSONResponse( - dict( - statusCode=exc.status_code, - message=exc.content, - error=HTTPStatus(exc.status_code).phrase, - ), - exc.status_code, - exc.headers, - ) - - -async def other_exception_handler(exc: "Exception"): - traceback.print_exc() - - status = HTTPStatus.INTERNAL_SERVER_ERROR - return JSONResponse( - 
dict(statusCode=status, message=str(exc), error=status.phrase), - status, - ) - - -def load_audio(reference_audio, sr): - if len(reference_audio) > 255 or not Path(reference_audio).exists(): - audio_data = reference_audio - reference_audio = io.BytesIO(audio_data) - - waveform, original_sr = torchaudio.load(reference_audio, backend=backend) - - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) - - if original_sr != sr: - resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr) - waveform = resampler(waveform) - - audio = waveform.squeeze().numpy() - return audio - - -def encode_reference(*, decoder_model, reference_audio, enable_reference_audio): - if enable_reference_audio and reference_audio is not None: - # Load audios, and prepare basic info here - reference_audio_content = load_audio( - reference_audio, decoder_model.spec_transform.sample_rate - ) - - audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[ - None, None, : - ] - audio_lengths = torch.tensor( - [audios.shape[2]], device=decoder_model.device, dtype=torch.long - ) - logger.info( - f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds" - ) - - # VQ Encoder - if isinstance(decoder_model, FireflyArchitecture): - prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0] - - logger.info(f"Encoded prompt: {prompt_tokens.shape}") - else: - prompt_tokens = None - logger.info("No reference audio provided") - - return prompt_tokens - - -def decode_vq_tokens( - *, - decoder_model, - codes, -): - feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device) - logger.info(f"VQ features: {codes.shape}") - - if isinstance(decoder_model, FireflyArchitecture): - # VQGAN Inference - return decoder_model.decode( - indices=codes[None], - feature_lengths=feature_lengths, - )[0].squeeze() - - raise ValueError(f"Unknown model type: {type(decoder_model)}") - - -routes = MultimethodRoutes(base_class=HttpView) - - -def get_content_type(audio_format): - if audio_format == "wav": - return "audio/wav" - elif audio_format == "flac": - return "audio/flac" - elif audio_format == "mp3": - return "audio/mpeg" - else: - return "application/octet-stream" - - -@torch.no_grad() -@torch.autocast(device_type="cuda", dtype=torch.half) -def batch_encode(model, audios: list[bytes | torch.Tensor]): - audios = [ - ( - torch.from_numpy( - librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0] - )[None] - if isinstance(audio, bytes) - else audio - ) - for audio in audios - ] - - # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios): - # raise ValueError("Single audio length is too long (>120s)") - - max_length = max(audio.shape[-1] for audio in audios) - print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s") - - lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device) - max_length = lengths.max().item() - padded = torch.stack( - [ - torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1])) - for audio in audios - ] - ).to(model.device) - - features, feature_lengths = model.encode(padded, audio_lengths=lengths) - features, feature_lengths = features.cpu(), feature_lengths.cpu() - - return [feature[..., :length] for feature, length in zip(features, feature_lengths)] - - -@cached( - cache=LRUCache(maxsize=10000), - key=lambda model, audios: (model.device, tuple(audios)), -) -def cached_vqgan_batch_encode(model, audios: 
list[bytes]): - return batch_encode(model, audios) - - -@routes.http.post("/v1/vqgan/encode") -def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]): - - start_time = time.time() - tokens = cached_vqgan_batch_encode(decoder_model, payload.audios) - logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms") - - return ormsgpack.packb( - ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]), - option=ormsgpack.OPT_SERIALIZE_PYDANTIC, - ) - - -@torch.no_grad() -@torch.autocast(device_type="cuda", dtype=torch.half) -def vqgan_decode(model, features): - lengths = torch.tensor( - [feature.shape[-1] for feature in features], device=model.device - ) - max_length = lengths.max().item() - padded = torch.stack( - [ - torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1])) - for feature in features - ] - ).to(model.device) - - # If bs too large, we do micro batch decode - audios, audio_lengths = [], [] - for i in range(0, padded.shape[0], 8): - audio, audio_length = model.decode( - padded[i : i + 8], feature_lengths=lengths[i : i + 8] - ) - audios.append(audio) - audio_lengths.append(audio_length) - audios = torch.cat(audios, dim=0) - audio_lengths = torch.cat(audio_lengths, dim=0) - audios, audio_lengths = audios.cpu(), audio_lengths.cpu() - - return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)] - - -@routes.http.post("/v1/vqgan/decode") -def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]): - tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens] - start_time = time.time() - audios = vqgan_decode(decoder_model, tokens) - logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms") - audios = [audio.astype(np.float16).tobytes() for audio in audios] - return ormsgpack.packb( - ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC - ) - - -@torch.no_grad() -def batch_asr(model, audios, sr, language="auto"): - resampled_audios = [] - for audio in audios: - audio = torchaudio.functional.resample(audio, sr, 16000) - assert audio.ndim == 1 - resampled_audios.append(audio) - - with global_lock: - res = model.generate( - input=resampled_audios, - batch_size=len(resampled_audios), - language=language, - use_itn=True, - ) - - results = [] - for r, audio in zip(res, audios): - text = r["text"] - text = re.sub(r"<\|.*?\|>", "", text) - duration = len(audio) / sr * 1000 - huge_gap = False - - if "timestamp" in r and len(r["timestamp"]) > 2: - for timestamp_a, timestamp_b in zip( - r["timestamp"][:-1], r["timestamp"][1:] - ): - # If there is a gap of more than 5 seconds, we consider it as a huge gap - if timestamp_b[0] - timestamp_a[1] > 5000: - huge_gap = True - break - - # Doesn't make sense to have a huge gap at the end - if duration - r["timestamp"][-1][1] > 3000: - huge_gap = True - - results.append( - { - "text": text, - "duration": duration, - "huge_gap": huge_gap, - } - ) - - return results - - -@routes.http.post("/v1/asr") -def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]): - start_time = time.time() - audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios] - audios = [torch.from_numpy(audio).float() for audio in audios] - - if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios): - raise HTTPException(status_code=400, detail="Audio length is too long") - - transcriptions = batch_asr( - asr_model, audios=audios, 
sr=payload.sample_rate, language=payload.language - ) - logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms") - - return ormsgpack.packb( - ServeASRResponse(transcriptions=transcriptions), - option=ormsgpack.OPT_SERIALIZE_PYDANTIC, - ) - - -from fish_speech.conversation import Conversation, Message - - -def execute_request( - input_queue: queue.Queue, - tokenizer: FishTokenizer, - config: BaseModelArgs, - request: ServeRequest, - device: str = "cuda:0", -): - - im_end_id = tokenizer.get_token_id(IM_END_TOKEN) - messages = [] - for message in request.messages: - messages.append(message.to_conversation_message()) - - assert len(messages) >= 1, "At least one message is required" - # assert messages[-1].role == "user", "The last message must be from the user" - - if messages[-1].role == "user": - messages.append( - Message(role="assistant", parts=[], add_im_end=False, modality="voice") - ) - elif messages[-1].role == "raw": - messages[-1].add_im_start = False - messages[-1].add_im_end = False - messages[-1].modality = "voice" - else: - assert ( - messages[-1].role == "assistant" - ), "The last message must be from the assistant" - messages[-1].add_im_end = False - - conv = Conversation(messages=messages) - - # conv.visualize(tokenizer) - prompt = conv.encode_for_inference( - tokenizer=tokenizer, num_codebooks=config.num_codebooks - ).to(device) - - if request.streaming: - for i in range(request.num_samples): - yield ServeStreamResponse( - sample_id=i, - delta=ServeStreamDelta( - role="assistant", - ), - ) - - req = { - "prompt": prompt, - "max_new_tokens": request.max_new_tokens, - "im_end_id": im_end_id, - "temperature": request.temperature, - "top_p": request.top_p, - "repetition_penalty": request.repetition_penalty, - "num_samples": request.num_samples, - "early_stop_threshold": request.early_stop_threshold, - } - - start = time.time() - response_queue = queue.Queue() - input_queue.put(GenerateRequest(req, response_queue)) - - # Decoding - decode_buffer = [[] for _ in range(request.num_samples)] - parts = [[] for _ in range(request.num_samples)] - - def send_reset_buffer(sample_id): - nonlocal decode_buffer - if len(decode_buffer[sample_id]) == 0: - return - - decoded = tokenizer.decode(decode_buffer[sample_id]) - part = ServeTextPart(text=decoded) - - if request.streaming: - yield ServeStreamResponse(delta=ServeStreamDelta(part=part)) - else: - parts[sample_id].append(part) - - decode_buffer[sample_id] = [] - - # Decode process - finished = [False for _ in range(request.num_samples)] - stats = {} - idx = 0 - while True: - response = response_queue.get() - - if response in ["stop", "error"]: - break - - for sample_id, tokens in enumerate(response): - if finished[sample_id]: - continue - - if tokens[0] == im_end_id: - finished[sample_id] = True - if request.streaming: - yield from send_reset_buffer(sample_id) - yield ServeStreamResponse( - sample_id=sample_id, - finish_reason="stop", - stats=stats, - ) - continue - - is_semantic = ( - tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id - ) - if is_semantic and request.streaming: - yield from send_reset_buffer(sample_id) - # Streaming vq - _tokens = tokens[1:].clone() - - if config.share_codebook_embeddings is False: - for i in range(len(_tokens)): - _tokens[i] -= config.codebook_size * i - - yield ServeStreamResponse( - sample_id=sample_id, - delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())), - ) - continue - - # Not streaming vq - if is_semantic: - yield from send_reset_buffer(sample_id) 
- # None streaming vq - if len(parts[sample_id]) == 0 or not isinstance( - parts[sample_id][-1], ServeVQPart - ): - _tokens = tokens[1:].clone() - - if config.share_codebook_embeddings is False: - for i in range(len(_tokens)): - _tokens[i] -= config.codebook_size * i - - parts[sample_id].append(ServeVQPart(codes=_tokens.tolist())) - else: - for codebook_id, value in enumerate(tokens[1:, :]): - val = value.item() - if config.share_codebook_embeddings is False: - val -= config.codebook_size * codebook_id - - parts[sample_id][-1].codes[codebook_id].append(val) - continue - - if not is_semantic: - # Stream text decode is not supported now - decode_buffer[sample_id].append(tokens[0, 0]) - - if idx == 0: - stats["time_to_first_token"] = (time.time() - start) * 1000 - - idx += 1 - - for sample_id in range(request.num_samples): - yield from send_reset_buffer(sample_id) - - stats["total_time"] = (time.time() - start) * 1000 - stats["total_tokens"] = idx - - if request.streaming: - for sample_id in range(request.num_samples): - if finished[sample_id]: - continue - yield ServeStreamResponse( - finish_reason=response, stats=stats, sample_id=sample_id - ) - return - - yield ServeResponse( - messages=[ - ServeMessage(role="assistant", parts=parts[i]) - for i in range(request.num_samples) - ], - finish_reason=response, - stats=stats, - ) - - -@routes.http.post("/v1/chat") -def api_invoke_chat( - req: Annotated[ServeRequest, Body(exclusive=True)], -): - """ - Invoke model and generate audio - """ - - # This makes torch compile happy - assert ( - req.num_samples == GLOBAL_NUM_SAMPLES - ), f"num_samples must be {GLOBAL_NUM_SAMPLES}" - - content_type = request.headers.get("Content-Type", "application/json") - json_mode = "application/json" in content_type - - async def wrapped_generator(): - generator = execute_request(llama_queue, tokenizer, config, req, args.device) - - for i in generator: - if json_mode: - body = i.model_dump_json().encode("utf-8") - yield b"data: " + body + b"\n\n" - else: - body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) - yield struct.pack("I", len(body)) + body - - # Naive mode - if req.streaming is False: - result = next(execute_request(llama_queue, tokenizer, config, req, args.device)) - - if json_mode: - return JSONResponse(result.model_dump()) - else: - return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) - - return StreamResponse( - iterable=wrapped_generator(), content_type="text/event-stream" - ) - - -@torch.inference_mode() -def inference(req: ServeTTSRequest): - - idstr: str | None = req.reference_id - if idstr is not None: - ref_folder = Path("references") / idstr - ref_folder.mkdir(parents=True, exist_ok=True) - ref_audios = list_files( - ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False - ) - - if req.use_memory_cache == "never" or ( - req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 - ): - prompt_tokens = [ - encode_reference( - decoder_model=decoder_model, - reference_audio=audio_to_bytes(str(ref_audio)), - enable_reference_audio=True, - ) - for ref_audio in ref_audios - ] - prompt_texts = [ - read_ref_text(str(ref_audio.with_suffix(".lab"))) - for ref_audio in ref_audios - ] - else: - logger.info("Use same references") - - else: - # Parse reference audio aka prompt - refs = req.references - - if req.use_memory_cache == "never" or ( - req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 - ): - prompt_tokens = [ - encode_reference( - decoder_model=decoder_model, - reference_audio=ref.audio, - 
enable_reference_audio=True, - ) - for ref in refs - ] - prompt_texts = [ref.text for ref in refs] - else: - logger.info("Use same references") - - if req.seed is not None: - set_seed(req.seed) - logger.warning(f"set seed: {req.seed}") - - # LLAMA Inference - request = dict( - device=decoder_model.device, - max_new_tokens=req.max_new_tokens, - text=( - req.text - if not req.normalize - else ChnNormedText(raw_text=req.text).normalize() - ), - top_p=req.top_p, - repetition_penalty=req.repetition_penalty, - temperature=req.temperature, - compile=args.compile, - iterative_prompt=req.chunk_length > 0, - chunk_length=req.chunk_length, - max_length=4096, - prompt_tokens=prompt_tokens, - prompt_text=prompt_texts, - ) - - response_queue = queue.Queue() - llama_queue.put( - GenerateRequest( - request=request, - response_queue=response_queue, - ) - ) - - if req.streaming: - yield wav_chunk_header() - - segments = [] - while True: - result: WrappedGenerateResponse = response_queue.get() - if result.status == "error": - raise result.response - break - - result: GenerateResponse = result.response - if result.action == "next": - break - - with autocast_exclude_mps( - device_type=decoder_model.device.type, dtype=args.precision - ): - fake_audios = decode_vq_tokens( - decoder_model=decoder_model, - codes=result.codes, - ) - - fake_audios = fake_audios.float().cpu().numpy() - - if req.streaming: - yield (fake_audios * 32768).astype(np.int16).tobytes() - else: - segments.append(fake_audios) - - if req.streaming: - return - - if len(segments) == 0: - raise HTTPException( - HTTPStatus.INTERNAL_SERVER_ERROR, - content="No audio generated, please check the input text.", - ) - - fake_audios = np.concatenate(segments, axis=0) - yield fake_audios - - -async def inference_async(req: ServeTTSRequest): - for chunk in inference(req): - yield chunk - - -async def buffer_to_async_generator(buffer): - yield buffer - - -@routes.http.post("/v1/tts") -async def api_invoke_model( - req: Annotated[ServeTTSRequest, Body(exclusive=True)], -): - """ - Invoke model and generate audio - """ - - if args.max_text_length > 0 and len(req.text) > args.max_text_length: - raise HTTPException( - HTTPStatus.BAD_REQUEST, - content=f"Text is too long, max length is {args.max_text_length}", - ) - - if req.streaming and req.format != "wav": - raise HTTPException( - HTTPStatus.BAD_REQUEST, - content="Streaming only supports WAV format", - ) - - if req.streaming: - return StreamResponse( - iterable=inference_async(req), - headers={ - "Content-Disposition": f"attachment; filename=audio.{req.format}", - }, - content_type=get_content_type(req.format), - ) - else: - fake_audios = next(inference(req)) - buffer = io.BytesIO() - sf.write( - buffer, - fake_audios, - decoder_model.spec_transform.sample_rate, - format=req.format, - ) - - return StreamResponse( - iterable=buffer_to_async_generator(buffer.getvalue()), - headers={ - "Content-Disposition": f"attachment; filename=audio.{req.format}", - }, - content_type=get_content_type(req.format), - ) - - -@routes.http.post("/v1/health") -async def api_health(): - """ - Health check - """ - return JSONResponse({"status": "ok"}) - - -def parse_args(): - parser = ArgumentParser() - parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts") - parser.add_argument("--load-asr-model", action="store_true") - parser.add_argument( - "--llama-checkpoint-path", - type=str, - default="checkpoints/fish-speech-1.4", - ) - parser.add_argument( - "--decoder-checkpoint-path", - type=str, - 
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", - ) - parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") - parser.add_argument("--device", type=str, default="cuda") - parser.add_argument("--half", action="store_true") - parser.add_argument("--compile", action="store_true") - parser.add_argument("--max-text-length", type=int, default=0) - parser.add_argument("--listen", type=str, default="127.0.0.1:8080") - parser.add_argument("--workers", type=int, default=1) - - return parser.parse_args() - - -# Define Kui app -openapi = OpenAPI( - { - "title": "Fish Speech API", - "version": "1.4.2", - }, -).routes - - -class MsgPackRequest(HttpRequest): - async def data( - self, - ) -> Annotated[ - Any, ContentType("application/msgpack"), ContentType("application/json") - ]: - if self.content_type == "application/msgpack": - return ormsgpack.unpackb(await self.body) - - elif self.content_type == "application/json": - return await self.json - - raise HTTPException( - HTTPStatus.UNSUPPORTED_MEDIA_TYPE, - headers={"Accept": "application/msgpack, application/json"}, - ) - - -app = Kui( - routes=routes + openapi[1:], # Remove the default route - exception_handlers={ - HTTPException: http_execption_handler, - Exception: other_exception_handler, - }, - factory_class=FactoryClass(http=MsgPackRequest), - cors_config={}, -) - - -def load_asr_model(*, device="cuda", hub="ms"): - return AutoModel( - model="iic/SenseVoiceSmall", - device=device, - disable_pbar=True, - hub=hub, - ) - - -# Each worker process created by Uvicorn has its own memory space, -# meaning that models and variables are not shared between processes. -# Therefore, any global variables (like `llama_queue` or `decoder_model`) -# will not be shared across workers. - - -# Multi-threading for deep learning can cause issues, such as inconsistent -# outputs if multiple threads access the same buffers simultaneously. -# Instead, it's better to use multiprocessing or independent models per thread. 
-@app.on_startup -def initialize_app(app: Kui): - - global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts - - prompt_tokens, prompt_texts = [], [] - - args = parse_args() # args same as ones in other processes - args.precision = torch.half if args.half else torch.bfloat16 - - if args.load_asr_model: - logger.info(f"Loading ASR model...") - asr_model = load_asr_model(device=args.device) - - logger.info("Loading Llama model...") - - if args.mode == "tts": - llama_queue = launch_thread_safe_queue( - checkpoint_path=args.llama_checkpoint_path, - device=args.device, - precision=args.precision, - compile=args.compile, - ) - else: - llama_queue, tokenizer, config = launch_thread_safe_queue_agent( - checkpoint_path=args.llama_checkpoint_path, - device=args.device, - precision=args.precision, - compile=args.compile, - ) - - logger.info("Llama model loaded, loading VQ-GAN model...") - - decoder_model = load_decoder_model( - config_name=args.decoder_config_name, - checkpoint_path=args.decoder_checkpoint_path, - device=args.device, - ) - - logger.info("VQ-GAN model loaded, warming up...") - - vad_model = load_silero_vad() - - logger.info("VAD model loaded, warming up...") - - if args.mode == "tts": - # Dry run to ensure models work and avoid first-time latency - list( - inference( - ServeTTSRequest( - text="Hello world.", - references=[], - reference_id=None, - max_new_tokens=0, - chunk_length=200, - top_p=0.7, - repetition_penalty=1.5, - temperature=0.7, - emotion=None, - format="wav", - ) - ) - ) - - logger.info(f"Warming up done, starting server at http://{args.listen}") - - -if __name__ == "__main__": - - import uvicorn - - args = parse_args() - host, port = args.listen.split(":") - uvicorn.run( - "tools.api:app", - host=host, - port=int(port), - workers=args.workers, - log_level="info", - ) diff --git a/tools/post_api.py b/tools/api_client.py similarity index 91% rename from tools/post_api.py rename to tools/api_client.py index f319d12f..90d7b29b 100644 --- a/tools/post_api.py +++ b/tools/api_client.py @@ -69,10 +69,6 @@ def parse_args(): parser.add_argument( "--format", type=str, choices=["wav", "mp3", "flac"], default="wav" ) - parser.add_argument( - "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz" - ) - parser.add_argument("--opus_bitrate", type=int, default=-1000) parser.add_argument( "--latency", type=str, @@ -112,11 +108,9 @@ def parse_args(): parser.add_argument( "--use_memory_cache", type=str, - default="never", - choices=["on-demand", "never"], - help="Cache encoded references codes in memory.\n" - "If `on-demand`, the server will use cached encodings\n " - "instead of encoding reference audio again.", + default="off", + choices=["on", "off"], + help="Cache encoded references codes in memory.\n", ) parser.add_argument( "--seed", @@ -154,14 +148,14 @@ def parse_args(): data = { "text": args.text, "references": [ - ServeReferenceAudio(audio=ref_audio, text=ref_text) + ServeReferenceAudio( + audio=ref_audio if ref_audio is not None else b"", text=ref_text + ) for ref_text, ref_audio in zip(ref_texts, byte_audios) ], "reference_id": idstr, "normalize": args.normalize, "format": args.format, - "mp3_bitrate": args.mp3_bitrate, - "opus_bitrate": args.opus_bitrate, "max_new_tokens": args.max_new_tokens, "chunk_length": args.chunk_length, "top_p": args.top_p, diff --git a/tools/api_server.py b/tools/api_server.py new file mode 100644 index 00000000..7b5d26fc --- /dev/null +++ b/tools/api_server.py @@ -0,0 +1,98 @@ 
+from threading import Lock + +import pyrootutils +import uvicorn +from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes +from loguru import logger + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from tools.server.api_utils import MsgPackRequest, parse_args +from tools.server.exception_handler import ExceptionHandler +from tools.server.model_manager import ModelManager +from tools.server.views import ( + ASRView, + ChatView, + HealthView, + TTSView, + VQGANDecodeView, + VQGANEncodeView, +) + + +class API(ExceptionHandler): + def __init__(self): + self.args = parse_args() + self.routes = [ + ("/v1/health", HealthView), + ("/v1/vqgan/encode", VQGANEncodeView), + ("/v1/vqgan/decode", VQGANDecodeView), + ("/v1/asr", ASRView), + ("/v1/tts", TTSView), + ("/v1/chat", ChatView), + ] + self.routes = Routes([HttpRoute(path, view) for path, view in self.routes]) + + self.openapi = OpenAPI( + { + "title": "Fish Speech API", + "version": "1.5.0", + }, + ).routes + + # Initialize the app + self.app = Kui( + routes=self.routes + self.openapi[1:], # Remove the default route + exception_handlers={ + HTTPException: self.http_exception_handler, + Exception: self.other_exception_handler, + }, + factory_class=FactoryClass(http=MsgPackRequest), + cors_config={}, + ) + + # Add the state variables + self.app.state.lock = Lock() + self.app.state.device = self.args.device + self.app.state.max_text_length = self.args.max_text_length + + # Associate the app with the model manager + self.app.on_startup(self.initialize_app) + + async def initialize_app(self, app: Kui): + # Make the ModelManager available to the views + app.state.model_manager = ModelManager( + mode=self.args.mode, + device=self.args.device, + half=self.args.half, + compile=self.args.compile, + asr_enabled=self.args.load_asr_model, + llama_checkpoint_path=self.args.llama_checkpoint_path, + decoder_checkpoint_path=self.args.decoder_checkpoint_path, + decoder_config_name=self.args.decoder_config_name, + ) + + logger.info(f"Startup done, listening server at http://{self.args.listen}") + + +# Each worker process created by Uvicorn has its own memory space, +# meaning that models and variables are not shared between processes. +# Therefore, any variables (like `llama_queue` or `decoder_model`) +# will not be shared across workers. + +# Multi-threading for deep learning can cause issues, such as inconsistent +# outputs if multiple threads access the same buffers simultaneously. +# Instead, it's better to use multiprocessing or independent models per thread. 
+ +if __name__ == "__main__": + + api = API() + host, port = api.args.listen.split(":") + + uvicorn.run( + api.app, + host=host, + port=int(port), + workers=api.args.workers, + log_level="info", + ) diff --git a/tools/fish_e2e.py b/tools/fish_e2e.py index 34b531f0..4a44fca3 100644 --- a/tools/fish_e2e.py +++ b/tools/fish_e2e.py @@ -14,8 +14,8 @@ import soundfile as sf from .schema import ( + ServeChatRequest, ServeMessage, - ServeRequest, ServeTextPart, ServeVQGANDecodeRequest, ServeVQGANEncodeRequest, @@ -163,7 +163,7 @@ async def stream( else: user_codes = None - request = ServeRequest( + request = ServeChatRequest( messages=prev_messages + ( [ diff --git a/tools/inference_engine/__init__.py b/tools/inference_engine/__init__.py new file mode 100644 index 00000000..2eb3396a --- /dev/null +++ b/tools/inference_engine/__init__.py @@ -0,0 +1,193 @@ +import gc +import queue +from typing import Generator + +import numpy as np +import torch +from loguru import logger + +from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture +from fish_speech.text.chn_text_norm.text import Text as ChnNormedText +from fish_speech.utils import autocast_exclude_mps, set_seed +from tools.inference_engine.reference_loader import ReferenceLoader +from tools.inference_engine.utils import InferenceResult, wav_chunk_header +from tools.inference_engine.vq_manager import VQManager +from tools.llama.generate import ( + GenerateRequest, + GenerateResponse, + WrappedGenerateResponse, +) +from tools.schema import ServeTTSRequest + + +class TTSInferenceEngine(ReferenceLoader, VQManager): + + def __init__( + self, + llama_queue: queue.Queue, + decoder_model: FireflyArchitecture, + precision: torch.dtype, + compile: bool, + ) -> None: + + super().__init__() + + self.llama_queue = llama_queue + self.decoder_model = decoder_model + self.precision = precision + self.compile = compile + + @torch.inference_mode() + def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]: + """ + Main inference function: + - Loads the reference audio and text. + - Calls the LLAMA model for inference. + - Decodes the VQ tokens to audio. 
+ """ + + ref_id: str | None = req.reference_id + prompt_tokens, prompt_texts = [], [] + # Load the reference audio and text based on id or hash + if ref_id is not None: + prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache) + + elif req.references: + prompt_tokens, prompt_texts = self.load_by_hash( + req.references, req.use_memory_cache + ) + + # Set the random seed if provided + if req.seed is not None: + set_seed(req.seed) + logger.warning(f"set seed: {req.seed}") + + # Get the symbolic tokens from the LLAMA model + response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts) + + # Get the sample rate from the decoder model + sample_rate = self.decoder_model.spec_transform.sample_rate + + # If streaming, send the header + if req.streaming: + yield InferenceResult( + code="header", + audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)), + error=None, + ) + + segments = [] + + while True: + # Get the response from the LLAMA model + wrapped_result: WrappedGenerateResponse = response_queue.get() + if wrapped_result.status == "error": + yield InferenceResult( + code="error", + audio=None, + error=( + wrapped_result.response + if isinstance(wrapped_result.response, Exception) + else Exception("Unknown error") + ), + ) + break + + # Check the response type + if not isinstance(wrapped_result.response, GenerateResponse): + raise TypeError( + "Expected GenerateResponse, got {type(wrapped_result.response).__name__}" + ) + + result: GenerateResponse = wrapped_result.response + if result.action != "next": + segment = self.get_audio_segment(result) + + if req.streaming: # Used only by the API server + yield InferenceResult( + code="segment", + audio=(sample_rate, segment), + error=None, + ) + else: + segments.append(segment) + else: + break + + # Clean up the memory + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # Edge case: no audio generated + if len(segments) == 0: + yield InferenceResult( + code="error", + audio=None, + error=RuntimeError("No audio generated, please check the input text."), + ) + else: + # Streaming or not, return the final audio + audio = np.concatenate(segments, axis=0) + yield InferenceResult( + code="final", + audio=(sample_rate, audio), + error=None, + ) + + return None + + def send_Llama_request( + self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list + ) -> queue.Queue: + """ + Send a request to the LLAMA model to generate the symbolic tokens. + """ + + # Prepare the request + request = dict( + device=self.decoder_model.device, + max_new_tokens=req.max_new_tokens, + text=( + req.text + if not req.normalize + else ChnNormedText(raw_text=req.text).normalize() + ), + top_p=req.top_p, + repetition_penalty=req.repetition_penalty, + temperature=req.temperature, + compile=self.compile, + iterative_prompt=req.chunk_length > 0, + chunk_length=req.chunk_length, + max_length=4096, + prompt_tokens=prompt_tokens, + prompt_text=prompt_texts, + ) + + # Create a queue to get the response + response_queue = queue.Queue() + + # Send the request to the LLAMA model + self.llama_queue.put( + GenerateRequest( + request=request, + response_queue=response_queue, + ) + ) + + return response_queue + + def get_audio_segment(self, result: GenerateResponse) -> np.ndarray: + """ + Decode the VQ tokens to audio. 
+ """ + + # Don't use autocast on MPS devices + with autocast_exclude_mps( + device_type=self.decoder_model.device.type, dtype=self.precision + ): + # Decode the symbolic tokens to audio + segment = self.decode_vq_tokens(codes=result.codes) + + # Convert the audio to numpy + return segment.float().cpu().numpy() diff --git a/tools/inference_engine/reference_loader.py b/tools/inference_engine/reference_loader.py new file mode 100644 index 00000000..91232eef --- /dev/null +++ b/tools/inference_engine/reference_loader.py @@ -0,0 +1,128 @@ +import io +from hashlib import sha256 +from pathlib import Path +from typing import Callable, Literal, Tuple + +import torch +import torchaudio +from loguru import logger + +from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture +from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text +from tools.schema import ServeReferenceAudio + + +class ReferenceLoader: + + def __init__(self) -> None: + """ + Component of the TTSInferenceEngine class. + Loads and manages the cache for the reference audio and text. + """ + self.ref_by_id: dict = {} + self.ref_by_hash: dict = {} + + # Make Pylance happy (attribut/method not defined...) + self.decoder_model: FireflyArchitecture + self.encode_reference: Callable + + # Define the torchaudio backend + backends = torchaudio.list_audio_backends() + if "ffmpeg" in backends: + self.backend = "ffmpeg" + else: + self.backend = "soundfile" + + def load_by_id( + self, + id: str, + use_cache: Literal["on", "off"], + ) -> Tuple: + + # Load the references audio and text by id + ref_folder = Path("references") / id + ref_folder.mkdir(parents=True, exist_ok=True) + ref_audios = list_files( + ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False + ) + + if use_cache == "off" or id not in self.ref_by_id: + # If the references are not already loaded, encode them + prompt_tokens = [ + self.encode_reference( + decoder_model=self.decoder_model, + reference_audio=audio_to_bytes(str(ref_audio)), + enable_reference_audio=True, + ) + for ref_audio in ref_audios + ] + prompt_texts = [ + read_ref_text(str(ref_audio.with_suffix(".lab"))) + for ref_audio in ref_audios + ] + self.ref_by_id[id] = (prompt_tokens, prompt_texts) + + else: + # Reuse already encoded references + logger.info("Use same references") + prompt_tokens, prompt_texts = self.ref_by_id[id] + + return prompt_tokens, prompt_texts + + def load_by_hash( + self, + references: list[ServeReferenceAudio], + use_cache: Literal["on", "off"], + ) -> Tuple: + + # Load the references audio and text by hash + audio_hashes = [sha256(ref.audio).hexdigest() for ref in references] + + cache_used = False + prompt_tokens, prompt_texts = [], [] + for i, ref in enumerate(references): + if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash: + # If the references are not already loaded, encode them + prompt_tokens.append( + self.encode_reference( + decoder_model=self.decoder_model, + reference_audio=ref.audio, + enable_reference_audio=True, + ) + ) + prompt_texts.append(ref.text) + self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts) + + else: + # Reuse already encoded references + prompt_text, prompt_token = self.ref_by_hash[audio_hashes[i]] + prompt_texts.append(prompt_text) + prompt_tokens.append(prompt_token) + cache_used = True + + if cache_used: + logger.info("Use same references") + + return prompt_tokens, prompt_texts + + def load_audio(self, reference_audio, sr): + """ + Load the audio data from a file or bytes. 
+ """ + if len(reference_audio) > 255 or not Path(reference_audio).exists(): + audio_data = reference_audio + reference_audio = io.BytesIO(audio_data) + + waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend) + + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + if original_sr != sr: + resampler = torchaudio.transforms.Resample( + orig_freq=original_sr, new_freq=sr + ) + waveform = resampler(waveform) + + audio = waveform.squeeze().numpy() + return audio diff --git a/tools/inference_engine/utils.py b/tools/inference_engine/utils.py new file mode 100644 index 00000000..b49e37bf --- /dev/null +++ b/tools/inference_engine/utils.py @@ -0,0 +1,42 @@ +import io +import wave +from dataclasses import dataclass +from typing import Literal, Optional, Tuple + +import numpy as np + +from fish_speech.text.chn_text_norm.text import Text as ChnNormedText + + +@dataclass +class InferenceResult: + code: Literal["header", "segment", "error", "final"] + audio: Optional[Tuple[int, np.ndarray]] + error: Optional[Exception] + + +def normalize_text(user_input: str, use_normalization: bool) -> str: + """Normalize user input text if needed.""" + if use_normalization: + return ChnNormedText(raw_text=user_input).normalize() + else: + return user_input + + +def wav_chunk_header( + sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1 +) -> np.ndarray: + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(bit_depth // 8) + wav_file.setframerate(sample_rate) + + wav_header_bytes = buffer.getvalue() + buffer.close() + + # Convert to numpy array + wav_header = np.frombuffer(wav_header_bytes, dtype=np.uint8) + + return wav_header diff --git a/tools/inference_engine/vq_manager.py b/tools/inference_engine/vq_manager.py new file mode 100644 index 00000000..07b5cb6d --- /dev/null +++ b/tools/inference_engine/vq_manager.py @@ -0,0 +1,57 @@ +from typing import Callable + +import torch +from loguru import logger + +from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture + + +class VQManager: + + def __init__(self): + # Make Pylance happy (attribut/method not defined...) 
+ self.decoder_model: FireflyArchitecture + self.load_audio: Callable + + def decode_vq_tokens(self, codes): + feature_lengths = torch.tensor( + [codes.shape[1]], device=self.decoder_model.device + ) + logger.info(f"VQ features: {codes.shape}") + + if isinstance(self.decoder_model, FireflyArchitecture): + return self.decoder_model.decode( + indices=codes[None], + feature_lengths=feature_lengths, + )[0].squeeze() + + raise ValueError(f"Unknown model type: {type(self.decoder_model)}") + + def encode_reference(self, reference_audio, enable_reference_audio): + if enable_reference_audio and reference_audio is not None: + # Load audios, and prepare basic info here + reference_audio_content = self.load_audio( + reference_audio, self.decoder_model.spec_transform.sample_rate + ) + + audios = torch.from_numpy(reference_audio_content).to( + self.decoder_model.device + )[None, None, :] + audio_lengths = torch.tensor( + [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long + ) + logger.info( + f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds" + ) + + # VQ Encoder + if isinstance(self.decoder_model, FireflyArchitecture): + prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0] + logger.info(f"Encoded prompt: {prompt_tokens.shape}") + else: + raise ValueError(f"Unknown model type: {type(self.decoder_model)}") + else: + prompt_tokens = None + logger.info("No reference audio provided") + + return prompt_tokens diff --git a/tools/msgpack_api.py b/tools/msgpack_api.py deleted file mode 100644 index 896cccf7..00000000 --- a/tools/msgpack_api.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -from argparse import ArgumentParser -from pathlib import Path - -import httpx -import ormsgpack - -from tools.schema import ServeReferenceAudio, ServeTTSRequest - -api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY") - - -def audio_request(): - # priority: ref_id > references - request = ServeTTSRequest( - text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.", - # reference_id="114514", - references=[ - ServeReferenceAudio( - audio=open("lengyue.wav", "rb").read(), - text=open("lengyue.lab", "r", encoding="utf-8").read(), - ) - ], - streaming=True, - ) - - api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY") - - with ( - httpx.Client() as client, - open("hello.wav", "wb") as f, - ): - with client.stream( - "POST", - "http://127.0.0.1:8080/v1/tts", - content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC), - headers={ - "authorization": f"Bearer {api_key}", - "content-type": "application/msgpack", - }, - timeout=None, - ) as response: - for chunk in response.iter_bytes(): - f.write(chunk) - - -def asr_request(audio_path: Path): - - # Read the audio file - with open( - str(audio_path), - "rb", - ) as audio_file: - audio_data = audio_file.read() - - # Prepare the request data - request_data = { - "audio": audio_data, - "language": "en", # Optional: specify the language - "ignore_timestamps": False, # Optional: set to True to ignore precise timestamps - } - - # Send the request - with httpx.Client() as client: - response = client.post( - "https://api.fish.audio/v1/asr", - headers={ - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/msgpack", - }, - content=ormsgpack.packb(request_data), - ) - - # Parse the response - result = response.json() - - print(f"Transcribed text: {result['text']}") - print(f"Audio duration: {result['duration']} seconds") - - for segment in result["segments"]: - print(f"Segment: {segment['text']}") 
- print(f"Start time: {segment['start']}, End time: {segment['end']}") - - -def parse_args(): - parser = ArgumentParser() - parser.add_argument("--audio_path", type=Path, default="audio/ref/trump.mp3") - - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - - asr_request(args.audio_path) diff --git a/tools/run_webui.py b/tools/run_webui.py new file mode 100644 index 00000000..6b0ab490 --- /dev/null +++ b/tools/run_webui.py @@ -0,0 +1,101 @@ +import os +from argparse import ArgumentParser +from pathlib import Path + +import pyrootutils +import torch +from loguru import logger + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from tools.inference_engine import TTSInferenceEngine +from tools.llama.generate import launch_thread_safe_queue +from tools.schema import ServeTTSRequest +from tools.vqgan.inference import load_model as load_decoder_model +from tools.webui import build_app +from tools.webui.inference import get_inference_wrapper + +# Make einx happy +os.environ["EINX_FILTER_TRACEBACK"] = "false" + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--llama-checkpoint-path", + type=Path, + default="checkpoints/fish-speech-1.5", + ) + parser.add_argument( + "--decoder-checkpoint-path", + type=Path, + default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ) + parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--half", action="store_true") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--max-gradio-length", type=int, default=0) + parser.add_argument("--theme", type=str, default="light") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + args.precision = torch.half if args.half else torch.bfloat16 + + # Check if CUDA is available + if not torch.cuda.is_available(): + logger.info("CUDA is not available, running on CPU.") + args.device = "cpu" + + logger.info("Loading Llama model...") + llama_queue = launch_thread_safe_queue( + checkpoint_path=args.llama_checkpoint_path, + device=args.device, + precision=args.precision, + compile=args.compile, + ) + + logger.info("Loading VQ-GAN model...") + decoder_model = load_decoder_model( + config_name=args.decoder_config_name, + checkpoint_path=args.decoder_checkpoint_path, + device=args.device, + ) + + logger.info("Decoder model loaded, warming up...") + + # Create the inference engine + inference_engine = TTSInferenceEngine( + llama_queue=llama_queue, + decoder_model=decoder_model, + compile=args.compile, + precision=args.precision, + ) + + # Dry run to check if the model is loaded correctly and avoid the first-time latency + list( + inference_engine.inference( + ServeTTSRequest( + text="Hello world.", + references=[], + reference_id=None, + max_new_tokens=0, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.5, + temperature=0.7, + format="wav", + ) + ) + ) + + logger.info("Warming up done, launching the web UI...") + + # Get the inference function with the immutable arguments + inference_fct = get_inference_wrapper(inference_engine) + + app = build_app(inference_fct, args.theme) + app.launch(show_api=True) diff --git a/tools/schema.py b/tools/schema.py index c8813fc6..4ce91600 100644 --- a/tools/schema.py +++ b/tools/schema.py @@ -1,16 +1,14 @@ import os import queue from dataclasses import dataclass -from typing import Annotated, Literal, Optional 
+from typing import Annotated, Literal import torch -from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist +from pydantic import BaseModel, Field, conint, conlist from pydantic.functional_validators import SkipValidation from fish_speech.conversation import Message, TextPart, VQPart -GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1)) - class ServeVQPart(BaseModel): type: Literal["vq"] = "vq" @@ -64,7 +62,7 @@ class ServeASRResponse(BaseModel): class ServeMessage(BaseModel): - role: Literal["system", "assistant", "user", "raw"] + role: Literal["system", "assistant", "user"] parts: list[ServeVQPart | ServeTextPart] def to_conversation_message(self): @@ -85,7 +83,7 @@ def to_conversation_message(self): return new_message -class ServeRequest(BaseModel): +class ServeChatRequest(BaseModel): messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)] max_new_tokens: int = 1024 top_p: float = 0.7 @@ -114,11 +112,6 @@ class ServeVQGANDecodeResponse(BaseModel): audios: list[bytes] -class ServeReferenceAudio(BaseModel): - audio: bytes - text: str - - class ServeForwardMessage(BaseModel): role: str content: str @@ -150,24 +143,11 @@ def __repr__(self) -> str: return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})" -class ServeChatRequestV1(BaseModel): - model: str = "llama3-8b" - messages: list[ServeForwardMessage] = [] - audio: bytes | None = None - temperature: float = 1.0 - top_p: float = 1.0 - max_tokens: int = 256 - voice: str = "jessica" - tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3" - tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128 - - class ServeTTSRequest(BaseModel): text: str chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200 # Audio format format: Literal["wav", "pcm", "mp3"] = "wav" - mp3_bitrate: Literal[64, 128, 192] = 128 # References audios for in-context learning references: list[ServeReferenceAudio] = [] # Reference id @@ -175,16 +155,16 @@ class ServeTTSRequest(BaseModel): # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1 reference_id: str | None = None seed: int | None = None - use_memory_cache: Literal["on-demand", "never"] = "never" + use_memory_cache: Literal["on", "off"] = "off" # Normalize text for en & zh, this increase stability for numbers normalize: bool = True - mp3_bitrate: Optional[int] = 64 - opus_bitrate: Optional[int] = -1000 - # Balance mode will reduce latency to 300ms, but may decrease stability - latency: Literal["normal", "balanced"] = "normal" # not usually used below streaming: bool = False max_new_tokens: int = 1024 top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2 temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 + + class Config: + # Allow arbitrary types for pytorch related types + arbitrary_types_allowed = True diff --git a/tools/server/agent/__init__.py b/tools/server/agent/__init__.py new file mode 100644 index 00000000..a4b0a2ec --- /dev/null +++ b/tools/server/agent/__init__.py @@ -0,0 +1,57 @@ +import struct +from functools import partial + +import ormsgpack + +from tools.server.agent.generate import generate_responses +from tools.server.agent.pre_generation_utils import prepare_messages + + +def execute_request(input_queue, tokenizer, config, request, device): + """ + This function prepares the conversation, encodes the request, + sends the generation request, and handles decoding/streaming. 
+ It returns a response generator (ServeResponse or ServeStreamResponse). + """ + prompt, im_end_id = prepare_messages(request, tokenizer, config) + yield from generate_responses( + input_queue, tokenizer, config, request, prompt, im_end_id, device + ) + + +def response_generator(req, llama_queue, tokenizer, config, device): + """ + Non-streaming response wrapper for the chat endpoint. + Only returns the final result. + """ + generator = execute_request(llama_queue, tokenizer, config, req, device) + return next(generator) + + +async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode): + """ + Streaming response wrapper for the chat endpoint. + Returns the response in chunks. + """ + generator = execute_request(llama_queue, tokenizer, config, req, device) + for i in generator: + if json_mode: + body = i.model_dump_json().encode("utf-8") + yield b"data: " + body + b"\n\n" + else: + body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) + yield struct.pack("I", len(body)) + body + + +def get_response_generator( + llama_queue, tokenizer, config, req, device, json_mode +) -> partial: + """ + Get the correct response generator based on the request. + """ + if not req.streaming: + return partial(response_generator, req, llama_queue, tokenizer, config, device) + else: + return partial( + streaming_generator, req, llama_queue, tokenizer, config, device, json_mode + ) diff --git a/tools/server/agent/generate.py b/tools/server/agent/generate.py new file mode 100644 index 00000000..ef4ae9a7 --- /dev/null +++ b/tools/server/agent/generate.py @@ -0,0 +1,119 @@ +import time + +from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse +from tools.server.agent.generation_utils import ( + initialize_decode_buffers, + process_response_tokens, + send_reset_buffer, +) +from tools.server.agent.pre_generation_utils import ( + create_generation_request, + send_generation_request, +) + + +def generate_responses( + input_queue, tokenizer, config, request, prompt, im_end_id, device +): + """ + Main generation function that handles the conversation, encodes the request, + sends the generation request, and handles decoding/streaming. + It returns a response generator (ServeResponse or ServeStreamResponse). 
+ """ + stats = {} + start = time.time() + stats["start_time"] = start + stats["tokens_count"] = 0 + + # Prepare and send the generation request + req = create_generation_request(prompt, request, im_end_id, device) + response_queue = send_generation_request(input_queue, req) + decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples) + + while True: + response = response_queue.get() + + # Handle abnormal finish or error + if response in ["stop", "error"]: + finish_reason = response + break + + # Process the response tokens + is_first_token = stats["tokens_count"] == 0 + responses = process_response_tokens( + response, + tokenizer, + config, + request, + decode_buffer, + parts, + finished, + im_end_id, + stats, + start, + is_first_token, + ) + + # Yield the responses if streaming + if request.streaming and responses: + for r in responses: + yield r + + stats["tokens_count"] += 1 + + # Check if all samples are finished + if all(finished): + finish_reason = "stop" + break + + # Finalize the response + final_responses = finalize_response( + request, finished, decode_buffer, tokenizer, parts, stats, finish_reason + ) + for fr in final_responses: + yield fr + + +def finalize_response( + request, finished, decode_buffer, tokenizer, parts, stats, finish_reason +): + """ + Finalize the response by sending the remaining text buffers. + """ + responses = [] + + # Send the remaining text buffers + for sample_id in range(request.num_samples): + responses.extend( + send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request) + ) + + # Calculate the final stats + stats["total_time"] = (time.time() - stats["start_time"]) * 1000 + stats["total_tokens"] = stats["tokens_count"] + + # If streaming, send the final chunks for each sample + if request.streaming: + for sample_id in range(request.num_samples): + if finished[sample_id]: + continue + responses.append( + ServeStreamResponse( + finish_reason=finish_reason, stats=stats, sample_id=sample_id + ) + ) + else: + # If not streaming, send the full messages for each sample + full_messages = [ + ServeMessage(role="assistant", parts=parts[i]) + for i in range(request.num_samples) + ] + responses.append( + ServeResponse( + messages=full_messages, + finish_reason=finish_reason, + stats=stats, + ) + ) + + return responses diff --git a/tools/server/agent/generation_utils.py b/tools/server/agent/generation_utils.py new file mode 100644 index 00000000..dc2dd4e4 --- /dev/null +++ b/tools/server/agent/generation_utils.py @@ -0,0 +1,122 @@ +import time + +from tools.schema import ( + ServeStreamDelta, + ServeStreamResponse, + ServeTextPart, + ServeVQPart, +) + + +def initialize_decode_buffers(num_samples): + """Initialise the decode buffers for each sample.""" + decode_buffer = [[] for _ in range(num_samples)] + parts = [[] for _ in range(num_samples)] + finished = [False for _ in range(num_samples)] + return decode_buffer, parts, finished + + +def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request): + """Send the remaining text buffer for a sample.""" + if len(decode_buffer[sample_id]) == 0: + return [] + + decoded = tokenizer.decode(decode_buffer[sample_id]) + part = ServeTextPart(text=decoded) + + responses = [] + if request.streaming: + responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part))) + else: + parts[sample_id].append(part) + + decode_buffer[sample_id] = [] + return responses + + +def handle_semantic_tokens(tokens, config, sample_id, parts, request): + """Handle the semantic tokens returned by 
the model.""" + responses = [] + _tokens = tokens[1:].clone() + + if not config.share_codebook_embeddings: + for i in range(len(_tokens)): + _tokens[i] -= config.codebook_size * i + + # If streaming, send the VQ parts directly + if request.streaming: + responses.append( + ServeStreamResponse( + sample_id=sample_id, + delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())), + ) + ) + else: + # If not streaming, accumulate the VQ parts + if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart): + parts[sample_id].append(ServeVQPart(codes=_tokens.tolist())) + else: + # Accumulate the codes + for codebook_id, value in enumerate(_tokens): + parts[sample_id][-1].codes[codebook_id].append(value.item()) + + return responses + + +def process_response_tokens( + response, + tokenizer, + config, + request, + decode_buffer, + parts, + finished, + im_end_id, + stats, + start, + is_first_token, +): + """Process the response tokens returned by the model.""" + responses = [] + for sample_id, tokens in enumerate(response): + if finished[sample_id]: + continue + + # End of the conversation + if tokens[0] == im_end_id: + finished[sample_id] = True + # Send the remaining text buffer + responses.extend( + send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request) + ) + if request.streaming: + responses.append( + ServeStreamResponse( + sample_id=sample_id, + finish_reason="stop", + stats=stats, + ) + ) + continue + + # Check if the token is semantic + is_semantic = ( + tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id + ) + + if is_semantic: + # Before the semantic tokens, send the remaining text buffer + responses.extend( + send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request) + ) + responses.extend( + handle_semantic_tokens(tokens, config, sample_id, parts, request) + ) + else: + # Accumulate the text tokens (not implemented?) + decode_buffer[sample_id].append(tokens[0, 0]) + + if is_first_token: + stats["time_to_first_token"] = (time.time() - start) * 1000 + + return responses diff --git a/tools/server/agent/pre_generation_utils.py b/tools/server/agent/pre_generation_utils.py new file mode 100644 index 00000000..135a72e3 --- /dev/null +++ b/tools/server/agent/pre_generation_utils.py @@ -0,0 +1,72 @@ +import queue + +from fish_speech.conversation import Conversation, Message +from fish_speech.tokenizer import IM_END_TOKEN +from tools.llama.generate import GenerateRequest + + +def prepare_messages(request, tokenizer, config): + """ + Reorganise the provided list of messages into a conversation. + Encode the conversation for inference. 
+ """ + # Convert the messages to ConversationMessage objects + messages = [msg.to_conversation_message() for msg in request.messages] + + if len(messages) < 1: + raise ValueError("At least one message is required") + + # Check the last message to determine the next step + last_role = messages[-1].role + match last_role: + case "user": + # The last message is from the user, ask the assistant to respond with a new message + messages.append( + Message(role="assistant", parts=[], add_im_end=False, modality="voice") + ) + case "raw": + # The last message is raw text, ask the assistant to complete it + messages[-1].add_im_start = False + messages[-1].add_im_end = False + messages[-1].modality = "voice" + case "assistant": + # The last message is from the assistant, ask the assistant to continue + messages[-1].add_im_end = False + case _: + # We expect it to be assistant if not user or raw + raise ValueError("The last message must be from the assistant, user or raw") + + # Create a conversation object and encode it for inference + conv = Conversation(messages=messages) + prompt = conv.encode_for_inference( + tokenizer=tokenizer, num_codebooks=config.num_codebooks + ) + im_end_id = tokenizer.get_token_id(IM_END_TOKEN) + + return prompt, im_end_id + + +def create_generation_request(prompt, request, im_end_id, device): + """ + Convert the request into a dictionary that can be sent to the model for generation. + """ + req = { + "prompt": prompt.to(device), + "max_new_tokens": request.max_new_tokens, + "im_end_id": im_end_id, + "temperature": request.temperature, + "top_p": request.top_p, + "repetition_penalty": request.repetition_penalty, + "num_samples": request.num_samples, + "early_stop_threshold": request.early_stop_threshold, + } + return req + + +def send_generation_request(input_queue, req): + """ + Send the generation request to the model and return a queue to get the response. 
+ """ + response_queue = queue.Queue() + input_queue.put(GenerateRequest(req, response_queue)) + return response_queue diff --git a/tools/server/api_utils.py b/tools/server/api_utils.py new file mode 100644 index 00000000..5cfe4c3a --- /dev/null +++ b/tools/server/api_utils.py @@ -0,0 +1,75 @@ +from argparse import ArgumentParser +from http import HTTPStatus +from typing import Annotated, Any + +import ormsgpack +from baize.datastructures import ContentType +from kui.asgi import HTTPException, HttpRequest + +from tools.inference_engine import TTSInferenceEngine +from tools.schema import ServeTTSRequest +from tools.server.inference import inference_wrapper as inference + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts") + parser.add_argument("--load-asr-model", action="store_true") + parser.add_argument( + "--llama-checkpoint-path", + type=str, + default="checkpoints/fish-speech-1.5", + ) + parser.add_argument( + "--decoder-checkpoint-path", + type=str, + default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ) + parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--half", action="store_true") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--max-text-length", type=int, default=0) + parser.add_argument("--listen", type=str, default="127.0.0.1:8080") + parser.add_argument("--workers", type=int, default=1) + + return parser.parse_args() + + +class MsgPackRequest(HttpRequest): + async def data( + self, + ) -> Annotated[ + Any, ContentType("application/msgpack"), ContentType("application/json") + ]: + if self.content_type == "application/msgpack": + return ormsgpack.unpackb(await self.body) + + elif self.content_type == "application/json": + return await self.json + + raise HTTPException( + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + headers={"Accept": "application/msgpack, application/json"}, + ) + + +async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine): + for chunk in inference(req, engine): + if isinstance(chunk, bytes): + yield chunk + + +async def buffer_to_async_generator(buffer): + yield buffer + + +def get_content_type(audio_format): + if audio_format == "wav": + return "audio/wav" + elif audio_format == "flac": + return "audio/flac" + elif audio_format == "mp3": + return "audio/mpeg" + else: + return "application/octet-stream" diff --git a/tools/server/exception_handler.py b/tools/server/exception_handler.py new file mode 100644 index 00000000..07d595fa --- /dev/null +++ b/tools/server/exception_handler.py @@ -0,0 +1,27 @@ +import traceback +from http import HTTPStatus + +from kui.asgi import HTTPException, JSONResponse + + +class ExceptionHandler: + + async def http_exception_handler(self, exc: HTTPException): + return JSONResponse( + dict( + statusCode=exc.status_code, + message=exc.content, + error=HTTPStatus(exc.status_code).phrase, + ), + exc.status_code, + exc.headers, + ) + + async def other_exception_handler(self, exc: Exception): + traceback.print_exc() + + status = HTTPStatus.INTERNAL_SERVER_ERROR + return JSONResponse( + dict(statusCode=status, message=str(exc), error=status.phrase), + status, + ) diff --git a/tools/server/inference.py b/tools/server/inference.py new file mode 100644 index 00000000..d5e95483 --- /dev/null +++ b/tools/server/inference.py @@ -0,0 +1,41 @@ +from http import HTTPStatus + +import 
numpy as np +from kui.asgi import HTTPException + +from tools.inference_engine import TTSInferenceEngine +from tools.schema import ServeTTSRequest + +AMPLITUDE = 32768 # Needs an explaination + + +def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine): + """ + Wrapper for the inference function. + Used in the API server. + """ + for result in engine.inference(req): + match result.code: + case "header": + if isinstance(result.audio, tuple): + yield result.audio[1] + + case "error": + raise HTTPException( + HTTPStatus.INTERNAL_SERVER_ERROR, + content=str(result.error), + ) + + case "segment": + if isinstance(result.audio, tuple): + yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes() + + case "final": + if isinstance(result.audio, tuple): + yield result.audio[1] + return None # Stop the generator + + raise HTTPException( + HTTPStatus.INTERNAL_SERVER_ERROR, + content="No audio generated, please check the input text.", + ) diff --git a/tools/server/model_manager.py b/tools/server/model_manager.py new file mode 100644 index 00000000..549ad8d4 --- /dev/null +++ b/tools/server/model_manager.py @@ -0,0 +1,119 @@ +import torch +from funasr import AutoModel +from loguru import logger + +from tools.inference_engine import TTSInferenceEngine +from tools.llama.generate import ( + launch_thread_safe_queue, + launch_thread_safe_queue_agent, +) +from tools.schema import ServeTTSRequest +from tools.server.inference import inference_wrapper as inference +from tools.vqgan.inference import load_model as load_decoder_model + +ASR_MODEL_NAME = "iic/SenseVoiceSmall" + + +class ModelManager: + def __init__( + self, + mode: str, + device: str, + half: bool, + compile: bool, + asr_enabled: bool, + llama_checkpoint_path: str, + decoder_checkpoint_path: str, + decoder_config_name: str, + ) -> None: + + self.mode = mode + self.device = device + self.half = half + self.compile = compile + + self.precision = torch.half if half else torch.bfloat16 + + # Check if CUDA is available + if not torch.cuda.is_available(): + self.device = "cpu" + logger.info("CUDA is not available, running on CPU.") + + # Load the ASR model if enabled + if asr_enabled: + self.load_asr_model(self.device) + + # Load the TTS models + self.load_llama_model( + llama_checkpoint_path, self.device, self.precision, self.compile, self.mode + ) + self.load_decoder_model( + decoder_config_name, decoder_checkpoint_path, self.device + ) + self.tts_inference_engine = TTSInferenceEngine( + llama_queue=self.llama_queue, + decoder_model=self.decoder_model, + precision=self.precision, + compile=self.compile, + ) + + # Warm up the models + if self.mode == "tts": + self.warm_up(self.tts_inference_engine) + + def load_asr_model(self, device, hub="ms") -> None: + self.asr_model = AutoModel( + model=ASR_MODEL_NAME, + device=device, + disable_pbar=True, + hub=hub, + ) + logger.info("ASR model loaded.") + + def load_llama_model( + self, checkpoint_path, device, precision, compile, mode + ) -> None: + + if mode == "tts": + self.llama_queue = launch_thread_safe_queue( + checkpoint_path=checkpoint_path, + device=device, + precision=precision, + compile=compile, + ) + elif mode == "agent": + self.llama_queue, self.tokenizer, self.config = ( + launch_thread_safe_queue_agent( + checkpoint_path=checkpoint_path, + device=device, + precision=precision, + compile=compile, + ) + ) + else: + raise ValueError(f"Invalid mode: {mode}") + + logger.info("LLAMA model loaded.") + + def load_decoder_model(self, config_name, checkpoint_path, device) -> None: + 
self.decoder_model = load_decoder_model( + config_name=config_name, + checkpoint_path=checkpoint_path, + device=device, + ) + logger.info("Decoder model loaded.") + + def warm_up(self, tts_inference_engine) -> None: + request = ServeTTSRequest( + text="Hello world.", + references=[], + reference_id=None, + max_new_tokens=0, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.5, + temperature=0.7, + format="wav", + ) + list(inference(request, tts_inference_engine)) + logger.info("Models warmed up.") diff --git a/tools/server/model_utils.py b/tools/server/model_utils.py new file mode 100644 index 00000000..c5a4c3a0 --- /dev/null +++ b/tools/server/model_utils.py @@ -0,0 +1,129 @@ +import io +import re + +import librosa +import torch +import torchaudio +from cachetools import LRUCache, cached + +CACHE_MAXSIZE = 10000 +MICRO_BATCH_SIZE = 8 +ASR_SAMPLE_RATE = 16000 +HUGE_GAP_THRESHOLD = 4000 + + +@torch.no_grad() +@torch.autocast(device_type="cuda", dtype=torch.half) +def batch_encode(model, audios_list: list[bytes]): + audios: list[torch.Tensor] = [ + ( + torch.from_numpy( + librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0] + )[None] + if isinstance(audio, bytes) + else audio + ) + for audio in audios_list + ] + + lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device) + max_length = lengths.max().item() + + print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s") + + padded = torch.stack( + [ + torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1]))) + for audio in audios + ] + ).to(model.device) + + features, feature_lengths = model.encode(padded, audio_lengths=lengths) + features, feature_lengths = features.cpu(), feature_lengths.cpu() + + return [feature[..., :length] for feature, length in zip(features, feature_lengths)] + + +@cached( + cache=LRUCache(maxsize=CACHE_MAXSIZE), + key=lambda model, audios: (model.device, tuple(audios)), +) +def cached_vqgan_batch_encode(model, audios: list[bytes]): + return batch_encode(model, audios) + + +@torch.no_grad() +@torch.autocast(device_type="cuda", dtype=torch.half) +def vqgan_decode(model, features): + lengths = torch.tensor( + [feature.shape[-1] for feature in features], device=model.device + ) + max_length = lengths.max().item() + padded = torch.stack( + [ + torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1])) + for feature in features + ] + ).to(model.device) + + # If bs too large, we do micro batch decode + audios, audio_lengths = [], [] + for i in range(0, padded.shape[0], MICRO_BATCH_SIZE): + audio, audio_length = model.decode( + padded[i : i + MICRO_BATCH_SIZE], + feature_lengths=lengths[i : i + MICRO_BATCH_SIZE], + ) + audios.append(audio) + audio_lengths.append(audio_length) + audios = torch.cat(audios, dim=0) + audio_lengths = torch.cat(audio_lengths, dim=0) + audios, audio_lengths = audios.cpu(), audio_lengths.cpu() + + return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)] + + +@torch.no_grad() +def batch_asr(model, lock, audios, sr, language="auto"): + resampled_audios = [] + for audio in audios: + audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE) + assert audio.ndim == 1 + resampled_audios.append(audio) + + with lock: + res = model.generate( + input=resampled_audios, + batch_size=len(resampled_audios), + language=language, + use_itn=True, + ) + + results = [] + for r, audio in zip(res, audios): + text = r["text"] + text = re.sub(r"<\|.*?\|>", "", text) + duration = len(audio) / sr 
* 1000 + huge_gap = False + + if "timestamp" in r and len(r["timestamp"]) > 2: + for timestamp_a, timestamp_b in zip( + r["timestamp"][:-1], r["timestamp"][1:] + ): + # If there is a gap of more than 4 seconds, we consider it as a huge gap + if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD: + huge_gap = True + break + + # Doesn't make sense to have a huge gap at the end + if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD: + huge_gap = True + + results.append( + { + "text": text, + "duration": duration, + "huge_gap": huge_gap, + } + ) + + return results diff --git a/tools/server/views.py b/tools/server/views.py new file mode 100644 index 00000000..5f54fa0c --- /dev/null +++ b/tools/server/views.py @@ -0,0 +1,246 @@ +import io +import os +import time +from http import HTTPStatus + +import numpy as np +import ormsgpack +import soundfile as sf +import torch +from kui.asgi import HTTPException, HttpView, JSONResponse, StreamResponse, request +from loguru import logger + +from tools.schema import ( + ServeASRRequest, + ServeASRResponse, + ServeChatRequest, + ServeTTSRequest, + ServeVQGANDecodeRequest, + ServeVQGANDecodeResponse, + ServeVQGANEncodeRequest, + ServeVQGANEncodeResponse, +) +from tools.server.agent import get_response_generator +from tools.server.api_utils import ( + buffer_to_async_generator, + get_content_type, + inference_async, +) +from tools.server.inference import inference_wrapper as inference +from tools.server.model_manager import ModelManager +from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan_decode + +MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1)) + + +class HealthView(HttpView): + """ + Return the health status of the server. + """ + + @classmethod + async def post(cls): + return JSONResponse({"status": "ok"}) + + +class VQGANEncodeView(HttpView): + """ + Encode the audio into symbolic tokens. + """ + + @classmethod + async def post(cls): + # Decode the request + payload = await request.data() + req = ServeVQGANEncodeRequest(**payload) + + # Get the model from the app + model_manager: ModelManager = request.app.state.model_manager + decoder_model = model_manager.decoder_model + + # Encode the audio + start_time = time.time() + tokens = cached_vqgan_batch_encode(decoder_model, req.audios) + logger.info( + f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms" + ) + + # Return the response + return ormsgpack.packb( + ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]), + option=ormsgpack.OPT_SERIALIZE_PYDANTIC, + ) + + +class VQGANDecodeView(HttpView): + """ + Decode the symbolic tokens into audio. + """ + + @classmethod + async def post(cls): + # Decode the request + payload = await request.data() + req = ServeVQGANDecodeRequest(**payload) + + # Get the model from the app + model_manager: ModelManager = request.app.state.model_manager + decoder_model = model_manager.decoder_model + + # Decode the audio + tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens] + start_time = time.time() + audios = vqgan_decode(decoder_model, tokens) + logger.info( + f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms" + ) + audios = [audio.astype(np.float16).tobytes() for audio in audios] + + # Return the response + return ormsgpack.packb( + ServeVQGANDecodeResponse(audios=audios), + option=ormsgpack.OPT_SERIALIZE_PYDANTIC, + ) + + +class ASRView(HttpView): + """ + Perform automatic speech recognition on the audio. 
+ """ + + @classmethod + async def post(cls): + # Decode the request + payload = await request.data() + req = ServeASRRequest(**payload) + + # Get the model from the app + model_manager: ModelManager = request.app.state.model_manager + asr_model = model_manager.asr_model + lock = request.app.state.lock + + # Perform ASR + start_time = time.time() + audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios] + audios = [torch.from_numpy(audio).float() for audio in audios] + + if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios): + raise HTTPException(status_code=400, content="Audio length is too long") + + transcriptions = batch_asr( + asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language + ) + logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms") + + # Return the response + return ormsgpack.packb( + ServeASRResponse(transcriptions=transcriptions), + option=ormsgpack.OPT_SERIALIZE_PYDANTIC, + ) + + +class TTSView(HttpView): + """ + Perform text-to-speech on the input text. + """ + + @classmethod + async def post(cls): + # Decode the request + payload = await request.data() + req = ServeTTSRequest(**payload) + + # Get the model from the app + app_state = request.app.state + model_manager: ModelManager = app_state.model_manager + engine = model_manager.tts_inference_engine + sample_rate = engine.decoder_model.spec_transform.sample_rate + + # Check if the text is too long + if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length: + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content=f"Text is too long, max length is {app_state.max_text_length}", + ) + + # Check if streaming is enabled + if req.streaming and req.format != "wav": + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content="Streaming only supports WAV format", + ) + + # Perform TTS + if req.streaming: + return StreamResponse( + iterable=inference_async(req, engine), + headers={ + "Content-Disposition": f"attachment; filename=audio.{req.format}", + }, + content_type=get_content_type(req.format), + ) + else: + fake_audios = next(inference(req, engine)) + buffer = io.BytesIO() + sf.write( + buffer, + fake_audios, + sample_rate, + format=req.format, + ) + + return StreamResponse( + iterable=buffer_to_async_generator(buffer.getvalue()), + headers={ + "Content-Disposition": f"attachment; filename=audio.{req.format}", + }, + content_type=get_content_type(req.format), + ) + + +class ChatView(HttpView): + """ + Perform chatbot inference on the input text. 
+ """ + + @classmethod + async def post(cls): + # Decode the request + payload = await request.data() + req = ServeChatRequest(**payload) + + # Check that the number of samples requested is correct + if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES: + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}", + ) + + # Get the type of content provided + content_type = request.headers.get("Content-Type", "application/json") + json_mode = "application/json" in content_type + + # Get the models from the app + model_manager: ModelManager = request.app.state.model_manager + llama_queue = model_manager.llama_queue + tokenizer = model_manager.tokenizer + config = model_manager.config + + device = request.app.state.device + + # Get the response generators + response_generator = get_response_generator( + llama_queue, tokenizer, config, req, device, json_mode + ) + + # Return the response in the correct format + if req.streaming is False: + result = response_generator() + if json_mode: + return JSONResponse(result.model_dump()) + else: + return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) + + return StreamResponse( + iterable=response_generator(), content_type="text/event-stream" + ) diff --git a/tools/webui.py b/tools/webui.py deleted file mode 100644 index c2685f67..00000000 --- a/tools/webui.py +++ /dev/null @@ -1,570 +0,0 @@ -import gc -import html -import io -import os -import queue -import wave -from argparse import ArgumentParser -from functools import partial -from pathlib import Path - -import gradio as gr -import librosa -import numpy as np -import pyrootutils -import torch -from loguru import logger -from transformers import AutoTokenizer - -pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) - - -from fish_speech.i18n import i18n -from fish_speech.text.chn_text_norm.text import Text as ChnNormedText -from fish_speech.utils import autocast_exclude_mps, set_seed -from tools.api import decode_vq_tokens, encode_reference -from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text -from tools.llama.generate import ( - GenerateRequest, - GenerateResponse, - WrappedGenerateResponse, - launch_thread_safe_queue, -) -from tools.schema import ( - GLOBAL_NUM_SAMPLES, - ASRPackRequest, - ServeASRRequest, - ServeASRResponse, - ServeASRSegment, - ServeAudioPart, - ServeForwardMessage, - ServeMessage, - ServeReferenceAudio, - ServeRequest, - ServeResponse, - ServeStreamDelta, - ServeStreamResponse, - ServeTextPart, - ServeTimedASRResponse, - ServeTTSRequest, - ServeVQGANDecodeRequest, - ServeVQGANDecodeResponse, - ServeVQGANEncodeRequest, - ServeVQGANEncodeResponse, - ServeVQPart, -) -from tools.vqgan.inference import load_model as load_decoder_model - -# Make einx happy -os.environ["EINX_FILTER_TRACEBACK"] = "false" - - -HEADER_MD = f"""# Fish Speech - -{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")} - -{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")} - -{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")} - -{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")} -""" - -TEXTBOX_PLACEHOLDER = i18n("Put your text here.") -SPACE_IMPORTED = False - - -def build_html_error_message(error): - return f""" -
-    <div style="color: red; font-weight: bold;">
-        {html.escape(str(error))}
-    </div>
- """ - - -@torch.inference_mode() -def inference(req: ServeTTSRequest): - - idstr: str | None = req.reference_id - prompt_tokens, prompt_texts = [], [] - if idstr is not None: - ref_folder = Path("references") / idstr - ref_folder.mkdir(parents=True, exist_ok=True) - ref_audios = list_files( - ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False - ) - - if req.use_memory_cache == "never" or ( - req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 - ): - prompt_tokens = [ - encode_reference( - decoder_model=decoder_model, - reference_audio=audio_to_bytes(str(ref_audio)), - enable_reference_audio=True, - ) - for ref_audio in ref_audios - ] - prompt_texts = [ - read_ref_text(str(ref_audio.with_suffix(".lab"))) - for ref_audio in ref_audios - ] - else: - logger.info("Use same references") - - else: - # Parse reference audio aka prompt - refs = req.references - - if req.use_memory_cache == "never" or ( - req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 - ): - prompt_tokens = [ - encode_reference( - decoder_model=decoder_model, - reference_audio=ref.audio, - enable_reference_audio=True, - ) - for ref in refs - ] - prompt_texts = [ref.text for ref in refs] - else: - logger.info("Use same references") - - if req.seed is not None: - set_seed(req.seed) - logger.warning(f"set seed: {req.seed}") - - # LLAMA Inference - request = dict( - device=decoder_model.device, - max_new_tokens=req.max_new_tokens, - text=( - req.text - if not req.normalize - else ChnNormedText(raw_text=req.text).normalize() - ), - top_p=req.top_p, - repetition_penalty=req.repetition_penalty, - temperature=req.temperature, - compile=args.compile, - iterative_prompt=req.chunk_length > 0, - chunk_length=req.chunk_length, - max_length=4096, - prompt_tokens=prompt_tokens, - prompt_text=prompt_texts, - ) - - response_queue = queue.Queue() - llama_queue.put( - GenerateRequest( - request=request, - response_queue=response_queue, - ) - ) - - segments = [] - - while True: - result: WrappedGenerateResponse = response_queue.get() - if result.status == "error": - yield None, None, build_html_error_message(result.response) - break - - result: GenerateResponse = result.response - if result.action == "next": - break - - with autocast_exclude_mps( - device_type=decoder_model.device.type, dtype=args.precision - ): - fake_audios = decode_vq_tokens( - decoder_model=decoder_model, - codes=result.codes, - ) - - fake_audios = fake_audios.float().cpu().numpy() - segments.append(fake_audios) - - if len(segments) == 0: - return ( - None, - None, - build_html_error_message( - i18n("No audio generated, please check the input text.") - ), - ) - - # No matter streaming or not, we need to return the final audio - audio = np.concatenate(segments, axis=0) - yield None, (decoder_model.spec_transform.sample_rate, audio), None - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - -n_audios = 4 - -global_audio_list = [] -global_error_list = [] - - -def inference_wrapper( - text, - enable_reference_audio, - reference_audio, - reference_text, - max_new_tokens, - chunk_length, - top_p, - repetition_penalty, - temperature, - seed, - batch_infer_num, -): - audios = [] - errors = [] - - for _ in range(batch_infer_num): - result = inference( - text, - enable_reference_audio, - reference_audio, - reference_text, - max_new_tokens, - chunk_length, - top_p, - repetition_penalty, - temperature, - seed, - ) - - _, audio_data, error_message = next(result) - - audios.append( - gr.Audio(value=audio_data if audio_data else 
None, visible=True), - ) - errors.append( - gr.HTML(value=error_message if error_message else None, visible=True), - ) - - for _ in range(batch_infer_num, n_audios): - audios.append( - gr.Audio(value=None, visible=False), - ) - errors.append( - gr.HTML(value=None, visible=False), - ) - - return None, *audios, *errors - - -def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): - buffer = io.BytesIO() - - with wave.open(buffer, "wb") as wav_file: - wav_file.setnchannels(channels) - wav_file.setsampwidth(bit_depth // 8) - wav_file.setframerate(sample_rate) - - wav_header_bytes = buffer.getvalue() - buffer.close() - return wav_header_bytes - - -def normalize_text(user_input, use_normalization): - if use_normalization: - return ChnNormedText(raw_text=user_input).normalize() - else: - return user_input - - -def update_examples(): - examples_dir = Path("references") - examples_dir.mkdir(parents=True, exist_ok=True) - example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True) - return gr.Dropdown(choices=example_audios + [""]) - - -def build_app(): - with gr.Blocks(theme=gr.themes.Base()) as app: - gr.Markdown(HEADER_MD) - - # Use light theme by default - app.load( - None, - None, - js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}" - % args.theme, - ) - - # Inference - with gr.Row(): - with gr.Column(scale=3): - text = gr.Textbox( - label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10 - ) - refined_text = gr.Textbox( - label=i18n("Realtime Transform Text"), - placeholder=i18n( - "Normalization Result Preview (Currently Only Chinese)" - ), - lines=5, - interactive=False, - ) - - with gr.Row(): - normalize = gr.Checkbox( - label=i18n("Text Normalization"), - value=False, - ) - - with gr.Row(): - with gr.Column(): - with gr.Tab(label=i18n("Advanced Config")): - with gr.Row(): - chunk_length = gr.Slider( - label=i18n("Iterative Prompt Length, 0 means off"), - minimum=0, - maximum=300, - value=200, - step=8, - ) - - max_new_tokens = gr.Slider( - label=i18n( - "Maximum tokens per batch, 0 means no limit" - ), - minimum=0, - maximum=2048, - value=0, - step=8, - ) - - with gr.Row(): - top_p = gr.Slider( - label="Top-P", - minimum=0.6, - maximum=0.9, - value=0.7, - step=0.01, - ) - - repetition_penalty = gr.Slider( - label=i18n("Repetition Penalty"), - minimum=1, - maximum=1.5, - value=1.2, - step=0.01, - ) - - with gr.Row(): - temperature = gr.Slider( - label="Temperature", - minimum=0.6, - maximum=0.9, - value=0.7, - step=0.01, - ) - seed = gr.Number( - label="Seed", - info="0 means randomized inference, otherwise deterministic", - value=0, - ) - - with gr.Tab(label=i18n("Reference Audio")): - with gr.Row(): - gr.Markdown( - i18n( - "5 to 10 seconds of reference audio, useful for specifying speaker." 
- ) - ) - with gr.Row(): - reference_id = gr.Textbox( - label=i18n("Reference ID"), - placeholder="Leave empty to use uploaded references", - ) - - with gr.Row(): - use_memory_cache = gr.Radio( - label=i18n("Use Memory Cache"), - choices=["never", "on-demand", "always"], - value="on-demand", - ) - - with gr.Row(): - reference_audio = gr.Audio( - label=i18n("Reference Audio"), - type="filepath", - ) - with gr.Row(): - reference_text = gr.Textbox( - label=i18n("Reference Text"), - lines=1, - placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。", - value="", - ) - - with gr.Column(scale=3): - with gr.Row(): - error = gr.HTML( - label=i18n("Error Message"), - visible=True, - ) - with gr.Row(): - audio = gr.Audio( - label=i18n("Generated Audio"), - type="numpy", - interactive=False, - visible=True, - ) - - with gr.Row(): - with gr.Column(scale=3): - generate = gr.Button( - value="\U0001F3A7 " + i18n("Generate"), variant="primary" - ) - - text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text]) - - def inference_wrapper( - text, - normalize, - reference_id, - reference_audio, - reference_text, - max_new_tokens, - chunk_length, - top_p, - repetition_penalty, - temperature, - seed, - use_memory_cache, - ): - references = [] - if reference_audio: - # 将文件路径转换为字节 - with open(reference_audio, "rb") as audio_file: - audio_bytes = audio_file.read() - references = [ - ServeReferenceAudio(audio=audio_bytes, text=reference_text) - ] - - req = ServeTTSRequest( - text=text, - normalize=normalize, - reference_id=reference_id if reference_id else None, - references=references, - max_new_tokens=max_new_tokens, - chunk_length=chunk_length, - top_p=top_p, - repetition_penalty=repetition_penalty, - temperature=temperature, - seed=int(seed) if seed else None, - use_memory_cache=use_memory_cache, - ) - - for result in inference(req): - if result[2]: # Error message - return None, result[2] - elif result[1]: # Audio data - return result[1], None - - return None, i18n("No audio generated") - - # Submit - generate.click( - inference_wrapper, - [ - refined_text, - normalize, - reference_id, - reference_audio, - reference_text, - max_new_tokens, - chunk_length, - top_p, - repetition_penalty, - temperature, - seed, - use_memory_cache, - ], - [audio, error], - concurrency_limit=1, - ) - - return app - - -def parse_args(): - parser = ArgumentParser() - parser.add_argument( - "--llama-checkpoint-path", - type=Path, - default="checkpoints/fish-speech-1.5", - ) - parser.add_argument( - "--decoder-checkpoint-path", - type=Path, - default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", - ) - parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") - parser.add_argument("--device", type=str, default="cuda") - parser.add_argument("--half", action="store_true") - parser.add_argument("--compile", action="store_true") - parser.add_argument("--max-gradio-length", type=int, default=0) - parser.add_argument("--theme", type=str, default="light") - - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - args.precision = torch.half if args.half else torch.bfloat16 - - # Check if CUDA is available - if not torch.cuda.is_available(): - logger.info("CUDA is not available, running on CPU.") - args.device = "cpu" - - logger.info("Loading Llama model...") - llama_queue = launch_thread_safe_queue( - checkpoint_path=args.llama_checkpoint_path, - device=args.device, - precision=args.precision, - compile=args.compile, - ) - logger.info("Llama model loaded, 
loading VQ-GAN model...") - - decoder_model = load_decoder_model( - config_name=args.decoder_config_name, - checkpoint_path=args.decoder_checkpoint_path, - device=args.device, - ) - - logger.info("Decoder model loaded, warming up...") - - # Dry run to check if the model is loaded correctly and avoid the first-time latency - list( - inference( - ServeTTSRequest( - text="Hello world.", - references=[], - reference_id=None, - max_new_tokens=0, - chunk_length=200, - top_p=0.7, - repetition_penalty=1.5, - temperature=0.7, - emotion=None, - format="wav", - ) - ) - ) - - logger.info("Warming up done, launching the web UI...") - - app = build_app() - app.launch(show_api=True) diff --git a/tools/webui/__init__.py b/tools/webui/__init__.py new file mode 100644 index 00000000..78dbc343 --- /dev/null +++ b/tools/webui/__init__.py @@ -0,0 +1,173 @@ +from typing import Callable + +import gradio as gr + +from fish_speech.i18n import i18n +from tools.inference_engine.utils import normalize_text +from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER + + +def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks: + with gr.Blocks(theme=gr.themes.Base()) as app: + gr.Markdown(HEADER_MD) + + # Use light theme by default + app.load( + None, + None, + js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}" + % theme, + ) + + # Inference + with gr.Row(): + with gr.Column(scale=3): + text = gr.Textbox( + label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10 + ) + refined_text = gr.Textbox( + label=i18n("Realtime Transform Text"), + placeholder=i18n( + "Normalization Result Preview (Currently Only Chinese)" + ), + lines=5, + interactive=False, + ) + + with gr.Row(): + normalize = gr.Checkbox( + label=i18n("Text Normalization"), + value=False, + ) + + with gr.Row(): + with gr.Column(): + with gr.Tab(label=i18n("Advanced Config")): + with gr.Row(): + chunk_length = gr.Slider( + label=i18n("Iterative Prompt Length, 0 means off"), + minimum=0, + maximum=300, + value=200, + step=8, + ) + + max_new_tokens = gr.Slider( + label=i18n( + "Maximum tokens per batch, 0 means no limit" + ), + minimum=0, + maximum=2048, + value=0, + step=8, + ) + + with gr.Row(): + top_p = gr.Slider( + label="Top-P", + minimum=0.6, + maximum=0.9, + value=0.7, + step=0.01, + ) + + repetition_penalty = gr.Slider( + label=i18n("Repetition Penalty"), + minimum=1, + maximum=1.5, + value=1.2, + step=0.01, + ) + + with gr.Row(): + temperature = gr.Slider( + label="Temperature", + minimum=0.6, + maximum=0.9, + value=0.7, + step=0.01, + ) + seed = gr.Number( + label="Seed", + info="0 means randomized inference, otherwise deterministic", + value=0, + ) + + with gr.Tab(label=i18n("Reference Audio")): + with gr.Row(): + gr.Markdown( + i18n( + "5 to 10 seconds of reference audio, useful for specifying speaker." 
+ ) + ) + with gr.Row(): + reference_id = gr.Textbox( + label=i18n("Reference ID"), + placeholder="Leave empty to use uploaded references", + ) + + with gr.Row(): + use_memory_cache = gr.Radio( + label=i18n("Use Memory Cache"), + choices=["on", "off"], + value="on", + ) + + with gr.Row(): + reference_audio = gr.Audio( + label=i18n("Reference Audio"), + type="filepath", + ) + with gr.Row(): + reference_text = gr.Textbox( + label=i18n("Reference Text"), + lines=1, + placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。", + value="", + ) + + with gr.Column(scale=3): + with gr.Row(): + error = gr.HTML( + label=i18n("Error Message"), + visible=True, + ) + with gr.Row(): + audio = gr.Audio( + label=i18n("Generated Audio"), + type="numpy", + interactive=False, + visible=True, + ) + + with gr.Row(): + with gr.Column(scale=3): + generate = gr.Button( + value="\U0001F3A7 " + i18n("Generate"), + variant="primary", + ) + + text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text]) + + # Submit + generate.click( + inference_fct, + [ + refined_text, + normalize, + reference_id, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + seed, + use_memory_cache, + ], + [audio, error], + concurrency_limit=1, + ) + + return app diff --git a/tools/webui/inference.py b/tools/webui/inference.py new file mode 100644 index 00000000..ea3553be --- /dev/null +++ b/tools/webui/inference.py @@ -0,0 +1,91 @@ +import html +from functools import partial +from typing import Any, Callable + +from fish_speech.i18n import i18n +from tools.schema import ServeReferenceAudio, ServeTTSRequest + + +def inference_wrapper( + text, + normalize, + reference_id, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + seed, + use_memory_cache, + engine, +): + """ + Wrapper for the inference function. + Used in the Gradio interface. + """ + + if reference_audio: + references = get_reference_audio(reference_audio, reference_text) + else: + references = [] + + req = ServeTTSRequest( + text=text, + normalize=normalize, + reference_id=reference_id if reference_id else None, + references=references, + max_new_tokens=max_new_tokens, + chunk_length=chunk_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + seed=int(seed) if seed else None, + use_memory_cache=use_memory_cache, + ) + + for result in engine.inference(req): + match result.code: + case "final": + return result.audio, None + case "error": + return None, build_html_error_message(i18n(result.error)) + case _: + pass + + return None, i18n("No audio generated") + + +def get_reference_audio(reference_audio: str, reference_text: str) -> list: + """ + Get the reference audio bytes. + """ + + with open(reference_audio, "rb") as audio_file: + audio_bytes = audio_file.read() + + return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)] + + +def build_html_error_message(error: Any) -> str: + + error = error if isinstance(error, Exception) else Exception("Unknown error") + + return f""" +
+    <div style="color: red; font-weight: bold;">
+        {html.escape(str(error))}
+    </div>
+ """ + + +def get_inference_wrapper(engine) -> Callable: + """ + Get the inference function with the immutable arguments. + """ + + return partial( + inference_wrapper, + engine=engine, + ) diff --git a/tools/webui/variables.py b/tools/webui/variables.py new file mode 100644 index 00000000..db42d5d7 --- /dev/null +++ b/tools/webui/variables.py @@ -0,0 +1,14 @@ +from fish_speech.i18n import i18n + +HEADER_MD = f"""# Fish Speech + +{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")} + +{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")} + +{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")} + +{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")} +""" + +TEXTBOX_PLACEHOLDER = i18n("Put your text here.")