diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 18dedd27..8cb459c8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,6 +20,6 @@ repos: - id: check-yaml - id: check-json - id: mixed-line-ending - args: ['--fix=lf'] + args: ["--fix=lf"] - id: check-added-large-files - args: ['--maxkb=5000'] + args: ["--maxkb=5000"] diff --git a/docs/en/finetune.md b/docs/en/finetune.md index 8b19a8df..bf04086b 100644 --- a/docs/en/finetune.md +++ b/docs/en/finetune.md @@ -39,7 +39,7 @@ You need to convert your dataset into the above format and place it under `data` Make sure you have downloaded the VQGAN weights. If not, run the following command: ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` You can then run the following command to extract semantic tokens: @@ -48,7 +48,7 @@ You can then run the following command to extract semantic tokens: python tools/vqgan/extract_vq.py data \ --num-workers 1 --batch-size 16 \ --config-name "firefly_gan_vq" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` !!! note @@ -92,7 +92,7 @@ After the command finishes executing, you should see the `quantized-dataset-ft.p Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command: ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` Finally, you can start the fine-tuning by running the following command: @@ -120,9 +120,9 @@ After training, you need to convert the LoRA weights to regular weights before p ```bash python tools/llama/merge_lora.py \ --lora-config r_8_alpha_16 \ - --base-weight checkpoints/fish-speech-1.4 \ + --base-weight checkpoints/fish-speech-1.5 \ --lora-weight results/$project/checkpoints/step_000000010.ckpt \ - --output checkpoints/fish-speech-1.4-yth-lora/ + --output checkpoints/fish-speech-1.5-yth-lora/ ``` !!! note You may also try other checkpoints. We suggest using the earliest checkpoint that meets your requirements, as they often perform better on out-of-distribution (OOD) data. diff --git a/docs/en/index.md b/docs/en/index.md index cb57b775..667e7fc0 100644 --- a/docs/en/index.md +++ b/docs/en/index.md @@ -179,7 +179,7 @@ pip install -e .[stable] Make sure you are in the terminal inside the docker container, then download the required `vqgan` and `llama` models from our huggingface repository. ```bash - huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 + huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 4. Configure environment variables and access WebUI diff --git a/docs/en/inference.md b/docs/en/inference.md index 316c8a9c..91f5ce4c 100644 --- a/docs/en/inference.md +++ b/docs/en/inference.md @@ -15,7 +15,7 @@ Inference support command line, HTTP API and web UI. Download the required `vqgan` and `llama` models from our Hugging Face repository. ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` ### 1. Generate prompt from voice: @@ -26,7 +26,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish- ```bash python tools/vqgan/inference.py \ -i "paimon.wav" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` You should get a `fake.npy` file. @@ -38,7 +38,7 @@ python tools/llama/generate.py \ --text "The text you want to convert" \ --prompt-text "Your reference text" \ --prompt-tokens "fake.npy" \ - --checkpoint-path "checkpoints/fish-speech-1.4" \ + --checkpoint-path "checkpoints/fish-speech-1.5" \ --num-samples 2 \ --compile ``` @@ -59,7 +59,7 @@ This command will create a `codes_N` file in the working directory, where N is a ```bash python tools/vqgan/inference.py \ -i "codes_0.npy" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` ## HTTP API Inference @@ -69,8 +69,8 @@ We provide a HTTP API for inference. You can use the following command to start ```bash python -m tools.api \ --listen 0.0.0.0:8080 \ - --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ - --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ --decoder-config-name firefly_gan_vq ``` @@ -120,8 +120,8 @@ You can start the WebUI using the following command: ```bash python -m tools.webui \ - --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ - --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ --decoder-config-name firefly_gan_vq ``` > If you want to speed up inference, you can add the `--compile` parameter. diff --git a/docs/ja/finetune.md b/docs/ja/finetune.md index 68db8cbd..cfc049b1 100644 --- a/docs/ja/finetune.md +++ b/docs/ja/finetune.md @@ -39,7 +39,7 @@ VQGANの重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。 ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 次に、次のコマンドを実行してセマンティックトークンを抽出できます。 @@ -48,7 +48,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish- python tools/vqgan/extract_vq.py data \ --num-workers 1 --batch-size 16 \ --config-name "firefly_gan_vq" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` !!! note @@ -92,7 +92,7 @@ python tools/llama/build_dataset.py \ 同様に、`LLAMA`の重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。 ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 最後に、次のコマンドを実行して微調整を開始できます。 @@ -120,9 +120,9 @@ python fish_speech/train.py --config-name text2semantic_finetune \ ```bash python tools/llama/merge_lora.py \ --lora-config r_8_alpha_16 \ - --base-weight checkpoints/fish-speech-1.4 \ + --base-weight checkpoints/fish-speech-1.5 \ --lora-weight results/$project/checkpoints/step_000000010.ckpt \ - --output checkpoints/fish-speech-1.4-yth-lora/ + --output checkpoints/fish-speech-1.5-yth-lora/ ``` !!! note 他のチェックポイントを試すこともできます。要件を満たす最も早いチェックポイントを使用することをお勧めします。これらは通常、分布外(OOD)データでより良いパフォーマンスを発揮します。 diff --git a/docs/ja/index.md b/docs/ja/index.md index 7c5ad2ec..5f81c5fc 100644 --- a/docs/ja/index.md +++ b/docs/ja/index.md @@ -178,7 +178,7 @@ pip install -e .[stable] Docker コンテナ内のターミナルにいることを確認し、huggingface リポジトリから必要な `vqgan` と `llama` モデルをダウンロードします。 ```bash - huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 + huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 4. 環境変数の設定と WebUI へのアクセス diff --git a/docs/ja/inference.md b/docs/ja/inference.md index c4e61450..29476d7c 100644 --- a/docs/ja/inference.md +++ b/docs/ja/inference.md @@ -15,7 +15,7 @@ 必要な`vqgan`および`llama`モデルを Hugging Face リポジトリからダウンロードします。 ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` ### 1. 音声からプロンプトを生成する: @@ -26,7 +26,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish- ```bash python tools/vqgan/inference.py \ -i "paimon.wav" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` `fake.npy`ファイルが生成されるはずです。 @@ -38,7 +38,7 @@ python tools/llama/generate.py \ --text "変換したいテキスト" \ --prompt-text "参照テキスト" \ --prompt-tokens "fake.npy" \ - --checkpoint-path "checkpoints/fish-speech-1.4" \ + --checkpoint-path "checkpoints/fish-speech-1.5" \ --num-samples 2 \ --compile ``` @@ -59,7 +59,7 @@ python tools/llama/generate.py \ ```bash python tools/vqgan/inference.py \ -i "codes_0.npy" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` ## HTTP API 推論 @@ -69,8 +69,8 @@ python tools/vqgan/inference.py \ ```bash python -m tools.api \ --listen 0.0.0.0:8080 \ - --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ - --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ --decoder-config-name firefly_gan_vq ``` @@ -99,8 +99,8 @@ python -m tools.post_api \ ```bash python -m tools.webui \ - --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ - --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ --decoder-config-name firefly_gan_vq ``` > 推論を高速化したい場合は、`--compile` パラメータを追加できます。 diff --git a/docs/ko/finetune.md b/docs/ko/finetune.md index a13d5e51..85cf11f1 100644 --- a/docs/ko/finetune.md +++ b/docs/ko/finetune.md @@ -38,7 +38,7 @@ VQGAN 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요: ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 이후 시맨틱 토큰을 추출하기 위해 아래 명령어를 실행하세요: @@ -47,7 +47,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish- python tools/vqgan/extract_vq.py data \ --num-workers 1 --batch-size 16 \ --config-name "firefly_gan_vq" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` !!! note @@ -91,7 +91,7 @@ python tools/llama/build_dataset.py \ 마찬가지로, `LLAMA` 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요: ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 마지막으로, 아래 명령어를 실행하여 파인튜닝을 시작할 수 있습니다: @@ -119,9 +119,9 @@ python fish_speech/train.py --config-name text2semantic_finetune \ ```bash python tools/llama/merge_lora.py \ --lora-config r_8_alpha_16 \ - --base-weight checkpoints/fish-speech-1.4 \ + --base-weight checkpoints/fish-speech-1.5 \ --lora-weight results/$project/checkpoints/step_000000010.ckpt \ - --output checkpoints/fish-speech-1.4-yth-lora/ + --output checkpoints/fish-speech-1.5-yth-lora/ ``` !!! note diff --git a/docs/ko/index.md b/docs/ko/index.md index 6af58535..d2d8dd91 100644 --- a/docs/ko/index.md +++ b/docs/ko/index.md @@ -179,7 +179,7 @@ pip install -e .[stable] Docker 컨테이너 내부의 터미널에서 아래 명령어를 사용하여 필요한 `vqgan` 및 `llama` 모델을 Huggingface 리포지토리에서 다운로드합니다. ```bash - huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 + huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 4. 환경 변수 설정 및 WebUI 접근 diff --git a/docs/ko/inference.md b/docs/ko/inference.md index 65b3ec58..201f8a89 100644 --- a/docs/ko/inference.md +++ b/docs/ko/inference.md @@ -15,7 +15,7 @@ 필요한 `vqgan` 및 `llama` 모델을 Hugging Face 리포지토리에서 다운로드하세요. ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` ### 1. 음성에서 프롬프트 생성: @@ -26,7 +26,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish- ```bash python tools/vqgan/inference.py \ -i "paimon.wav" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` 이 명령을 실행하면 `fake.npy` 파일을 얻게 됩니다. @@ -38,7 +38,7 @@ python tools/llama/generate.py \ --text "변환할 텍스트" \ --prompt-text "참고할 텍스트" \ --prompt-tokens "fake.npy" \ - --checkpoint-path "checkpoints/fish-speech-1.4" \ + --checkpoint-path "checkpoints/fish-speech-1.5" \ --num-samples 2 \ --compile ``` @@ -59,7 +59,7 @@ python tools/llama/generate.py \ ```bash python tools/vqgan/inference.py \ -i "codes_0.npy" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` ## HTTP API 추론 @@ -69,8 +69,8 @@ python tools/vqgan/inference.py \ ```bash python -m tools.api \ --listen 0.0.0.0:8080 \ - --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ - --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ --decoder-config-name firefly_gan_vq ``` @@ -118,8 +118,8 @@ python -m tools.post_api \ ```bash python -m tools.webui \ - --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ - --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ --decoder-config-name firefly_gan_vq ``` diff --git a/docs/pt/finetune.md b/docs/pt/finetune.md index f57d92c7..7e7eb5c8 100644 --- a/docs/pt/finetune.md +++ b/docs/pt/finetune.md @@ -39,7 +39,7 @@ Você precisa converter seu conjunto de dados para o formato acima e colocá-lo Certifique-se de ter baixado os pesos do VQGAN. Se não, execute o seguinte comando: ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` Em seguida, você pode executar o seguinte comando para extrair os tokens semânticos: @@ -48,7 +48,7 @@ Em seguida, você pode executar o seguinte comando para extrair os tokens semân python tools/vqgan/extract_vq.py data \ --num-workers 1 --batch-size 16 \ --config-name "firefly_gan_vq" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` !!! note @@ -92,7 +92,7 @@ Após executar o comando, você deverá ver o arquivo `quantized-dataset-ft.prot Da mesma forma, certifique-se de ter baixado os pesos do `LLAMA`. Se não, execute o seguinte comando: ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` E então, execute o seguinte comando para iniciar o ajuste fino: @@ -120,9 +120,9 @@ Após o treinamento, é preciso converter os pesos do LoRA em pesos regulares an ```bash python tools/llama/merge_lora.py \ --lora-config r_8_alpha_16 \ - --base-weight checkpoints/fish-speech-1.4 \ + --base-weight checkpoints/fish-speech-1.5 \ --lora-weight results/$project/checkpoints/step_000000010.ckpt \ - --output checkpoints/fish-speech-1.4-yth-lora/ + --output checkpoints/fish-speech-1.5-yth-lora/ ``` !!! note É possível também tentar outros checkpoints. Sugerimos usar o checkpoint que melhor atenda aos seus requisitos, pois eles geralmente têm um desempenho melhor em dados fora da distribuição (OOD). diff --git a/docs/pt/index.md b/docs/pt/index.md index 05e27ff7..60cab972 100644 --- a/docs/pt/index.md +++ b/docs/pt/index.md @@ -175,7 +175,7 @@ pip install -e .[stable] Certifique-se de estar no terminal do contêiner Docker e, em seguida, baixe os modelos necessários `vqgan` e `llama` do nosso repositório HuggingFace. ```bash - huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 + huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 4. Configure as variáveis de ambiente e acesse a WebUI diff --git a/docs/pt/inference.md b/docs/pt/inference.md index 8cbaa4ee..5223b625 100644 --- a/docs/pt/inference.md +++ b/docs/pt/inference.md @@ -15,7 +15,7 @@ Suporte para inferência por linha de comando, API HTTP e interface web (WebUI). Baixe os modelos `vqgan` e `llama` necessários do nosso repositório Hugging Face. ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` ### 1. Gerar prompt a partir da voz: @@ -26,7 +26,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish- ```bash python tools/vqgan/inference.py \ -i "paimon.wav" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` Você deverá obter um arquivo `fake.npy`. @@ -38,7 +38,7 @@ python tools/llama/generate.py \ --text "O texto que você deseja converter" \ --prompt-text "Seu texto de referência" \ --prompt-tokens "fake.npy" \ - --checkpoint-path "checkpoints/fish-speech-1.4" \ + --checkpoint-path "checkpoints/fish-speech-1.5" \ --num-samples 2 \ --compile ``` @@ -59,7 +59,7 @@ Este comando criará um arquivo `codes_N` no diretório de trabalho, onde N é u ```bash python tools/vqgan/inference.py \ -i "codes_0.npy" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` ## Inferência por API HTTP @@ -69,7 +69,7 @@ Fornecemos uma API HTTP para inferência. O seguinte comando pode ser usado para ```bash python -m tools.api \ --listen 0.0.0.0:8080 \ - --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ + --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ --decoder-config-name firefly_gan_vq ``` @@ -99,8 +99,8 @@ Para iniciar a WebUI de Inferência execute o seguinte comando: ```bash python -m tools.webui \ - --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ - --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ --decoder-config-name firefly_gan_vq ``` > Para acelerar a inferência, adicione o parâmetro `--compile`. diff --git a/docs/zh/finetune.md b/docs/zh/finetune.md index f7db80c9..1b65f8b1 100644 --- a/docs/zh/finetune.md +++ b/docs/zh/finetune.md @@ -37,13 +37,13 @@ 确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令: ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 对于中国大陆用户, 可使用 mirror 下载. ```bash -HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 随后可运行以下命令来提取语义 token: @@ -52,7 +52,7 @@ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech python tools/vqgan/extract_vq.py data \ --num-workers 1 --batch-size 16 \ --config-name "firefly_gan_vq" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` !!! note @@ -96,13 +96,13 @@ python tools/llama/build_dataset.py \ 同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令: ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 对于中国大陆用户, 可使用 mirror 下载. ```bash -HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 最后, 你可以运行以下命令来启动微调: @@ -130,9 +130,9 @@ python fish_speech/train.py --config-name text2semantic_finetune \ ```bash python tools/llama/merge_lora.py \ --lora-config r_8_alpha_16 \ - --base-weight checkpoints/fish-speech-1.4 \ + --base-weight checkpoints/fish-speech-1.5 \ --lora-weight results/$project/checkpoints/step_000000010.ckpt \ - --output checkpoints/fish-speech-1.4-yth-lora/ + --output checkpoints/fish-speech-1.5-yth-lora/ ``` !!! note diff --git a/docs/zh/index.md b/docs/zh/index.md index f108c0a6..d12835dc 100644 --- a/docs/zh/index.md +++ b/docs/zh/index.md @@ -176,13 +176,13 @@ pip install -e .[stable] 确保您在 docker 容器内的终端,然后再从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。 ```bash - huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 + huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 对于中国大陆用户,可以通过镜像站下载。 ```bash - HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 + HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 4. 配置环境变量,访问 WebUI diff --git a/docs/zh/inference.md b/docs/zh/inference.md index 0c679be0..6426d218 100644 --- a/docs/zh/inference.md +++ b/docs/zh/inference.md @@ -15,13 +15,13 @@ 从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。 ```bash -huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` 对于中国大陆用户,可使用 mirror 下载。 ```bash -HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4 +HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 ``` ### 1. 从语音生成 prompt: @@ -32,7 +32,7 @@ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech ```bash python tools/vqgan/inference.py \ -i "paimon.wav" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` 你应该能得到一个 `fake.npy` 文件. @@ -44,7 +44,7 @@ python tools/llama/generate.py \ --text "要转换的文本" \ --prompt-text "你的参考文本" \ --prompt-tokens "fake.npy" \ - --checkpoint-path "checkpoints/fish-speech-1.4" \ + --checkpoint-path "checkpoints/fish-speech-1.5" \ --num-samples 2 \ --compile ``` @@ -65,7 +65,7 @@ python tools/llama/generate.py \ ```bash python tools/vqgan/inference.py \ -i "codes_0.npy" \ - --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" ``` ## HTTP API 推理 @@ -75,8 +75,8 @@ python tools/vqgan/inference.py \ ```bash python -m tools.api \ --listen 0.0.0.0:8080 \ - --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ - --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ --decoder-config-name firefly_gan_vq ``` > 如果你想要加速推理,可以加上`--compile`参数。 @@ -128,8 +128,8 @@ python -m tools.post_api \ ```bash python -m tools.webui \ - --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ - --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ + --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ + --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ --decoder-config-name firefly_gan_vq ``` > 如果你想要加速推理,可以加上`--compile`参数。 diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py index 9bbc1cdb..20d8ab32 100644 --- a/fish_speech/conversation.py +++ b/fish_speech/conversation.py @@ -2,41 +2,10 @@ from typing import Literal import torch -from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast - -IM_START_TOKEN = "<|im_start|>" -IM_END_TOKEN = "<|im_end|>" -SEMANTIC_TOKEN = "<|semantic|>" -MEL_TOKEN = "<|mel|>" -PHONEME_START_TOKEN = "<|phoneme_start|>" -PHONEME_END_TOKEN = "<|phoneme_end|>" -ALL_SPECIAL_TOKENS = [ - IM_START_TOKEN, - IM_END_TOKEN, - SEMANTIC_TOKEN, - MEL_TOKEN, - PHONEME_START_TOKEN, - PHONEME_END_TOKEN, -] - -CODEBOOK_PAD_TOKEN_ID = 0 - - -class FishTokenizerConfig(PretrainedConfig): - share_codebook_embeddings: bool = True - codebook_size: int = 1024 - num_codebooks: int = 8 +from .tokenizer import MODALITY_TOKENS, FishTokenizer -class FishTokenizerFast(PreTrainedTokenizerFast): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True) - self.codebook_size = kwargs.pop("codebook_size", 1024) - self.num_codebooks = kwargs.pop("num_codebooks", 8) - - -AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast) +CODEBOOK_PAD_TOKEN_ID = 0 @dataclass(kw_only=True) @@ -54,77 +23,72 @@ class TextPart(BasePart): text: str -@dataclass(kw_only=True) -class MelPart(BasePart): - mels: torch.Tensor - - @dataclass(kw_only=True) class EncodedMessage: tokens: torch.Tensor labels: torch.Tensor + vq_mask_tokens: torch.Tensor | None = None + vq_mask_labels: torch.Tensor | None = None vq_parts: list[torch.Tensor] - mel_parts: list[torch.Tensor] vq_require_losses: torch.Tensor | None = None @dataclass(kw_only=True) class Message: role: Literal["system", "user", "assistant"] - parts: list[VQPart | TextPart | MelPart] = field(default_factory=list) + parts: list[VQPart | TextPart] = field(default_factory=list) add_im_start: bool = True add_im_end: bool = True cal_loss: bool = False + modality: Literal["text", "voice", "interleave"] | None = None # By default, ignore the loss of the auto-generated im_start token ignore_im_start_loss: bool = True def encode( self: "Message", - tokenizer: AutoTokenizer, + tokenizer: FishTokenizer, ) -> EncodedMessage: all_tokens = [] all_labels = [] # Multi-modal tokens vq_parts = [] - mel_parts = [] - - semantic_id, mel_id = tokenizer.convert_tokens_to_ids( - [SEMANTIC_TOKEN, MEL_TOKEN] - ) + vq_masks = [] parts = self.parts.copy() if self.add_im_start: - parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n")) + modality_token = MODALITY_TOKENS[self.modality] if self.modality else "" + parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{modality_token}")) if self.add_im_end: parts.append(TextPart(text="<|im_end|>")) for part in parts: if isinstance(part, TextPart): - tokens = tokenizer.encode( - part.text, - add_special_tokens=False, - truncation=False, - return_tensors="pt", - ).int()[0] + tokens = torch.tensor( + tokenizer.encode(part.text), + dtype=torch.int, + ) elif isinstance(part, VQPart): - tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id - codes = part.codes.clone() + 1 - - if getattr(tokenizer, "share_codebook_embeddings", True) is False: - for i in range(len(codes)): - codes[i] += tokenizer.codebook_size * i - - vq_parts.append(codes) - elif isinstance(part, MelPart): - tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id - mel_parts.append(part.mels) + curr_codes = part.codes.clone() + tokens = torch.tensor( + [ + tokenizer.semantic_id_to_token_id[i.item()] + for i in curr_codes[0].int() + ], + dtype=torch.int, + ) + vq_parts.append(curr_codes) else: raise ValueError(f"Unsupported part type: {type(part)}") all_tokens.append(tokens) + if isinstance(part, VQPart): + vq_masks.append(torch.ones_like(tokens, dtype=torch.bool)) + else: + vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool)) + if self.cal_loss: all_labels.append(tokens.clone()) else: @@ -132,7 +96,9 @@ def encode( tokens = torch.cat(all_tokens, dim=0) labels = torch.cat(all_labels, dim=0) - assert tokens.shape == labels.shape + vq_masks = torch.cat(vq_masks, dim=0) + + assert tokens.shape == labels.shape == vq_masks.shape if self.ignore_im_start_loss and self.add_im_start: labels[: len(all_tokens[0])] = -100 @@ -141,7 +107,8 @@ def encode( tokens=tokens, labels=labels, vq_parts=vq_parts, - mel_parts=mel_parts, + vq_mask_tokens=vq_masks, + vq_mask_labels=vq_masks, ) @@ -149,17 +116,23 @@ def encode( class Conversation: messages: list[Message] + def __init__(self: "Conversation", messages: list[Message] | None = None): + self.messages = messages or [] + def encode( self: "Conversation", - tokenizer: AutoTokenizer, + tokenizer: FishTokenizer, add_shift: bool = True, + ignore_loss_tokens: list[str] = [], ) -> EncodedMessage: # Build the input_ids and labels tokens = [] labels = [] vq_parts = [] - mel_parts = [] + vq_mask_tokens = [] + vq_mask_labels = [] vq_require_losses = [] + ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens] for message in self.messages: encoded = message.encode( @@ -168,16 +141,25 @@ def encode( tokens.append(encoded.tokens) labels.append(encoded.labels) vq_parts.extend(encoded.vq_parts) - mel_parts.extend(encoded.mel_parts) + vq_mask_tokens.append(encoded.vq_mask_tokens) + vq_mask_labels.append(encoded.vq_mask_labels) vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts)) tokens = torch.cat(tokens, dim=0) labels = torch.cat(labels, dim=0) + vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0) + vq_mask_labels = torch.cat(vq_mask_labels, dim=0) vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool) if add_shift: tokens = tokens[:-1] labels = labels[1:] + vq_mask_tokens = vq_mask_tokens[:-1] + vq_mask_labels = vq_mask_labels[1:] + + for i in ignore_loss_token_ids: + assert i != -100 and i is not None + labels[labels == i] = -100 assert tokens.dtype in [ torch.int, @@ -188,15 +170,18 @@ def encode( tokens=tokens, labels=labels, vq_parts=vq_parts, - mel_parts=mel_parts, + vq_mask_tokens=vq_mask_tokens, + vq_mask_labels=vq_mask_labels, vq_require_losses=vq_require_losses, ) def encode_for_inference( self: "Conversation", - tokenizer: AutoTokenizer, + tokenizer: FishTokenizer, num_codebooks: int, ) -> EncodedMessage: + # self.visualize(tokenizer) + encoded = self.encode(tokenizer, add_shift=False) tokens = encoded.tokens values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int) @@ -205,24 +190,47 @@ def encode_for_inference( if encoded.vq_parts is None or len(encoded.vq_parts) == 0: return values - semantic_id, mel_id = tokenizer.convert_tokens_to_ids( - [SEMANTIC_TOKEN, MEL_TOKEN] - ) vq_parts = encoded.vq_parts + vq_parts = [part.to(values.device) for part in vq_parts] vq_parts = torch.cat(vq_parts, dim=1) - values[1:, tokens == semantic_id] = vq_parts + values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id + values[1:, encoded.vq_mask_tokens] = vq_parts + return values - def visualize(self: "Conversation", tokenizer: AutoTokenizer): - encoded = self.encode(tokenizer, add_shift=False) + def visualize( + self: "Conversation", + tokenizer: FishTokenizer, + ignore_loss_tokens: list[str] = [], + ): + encoded = self.encode( + tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens + ) - print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="") - print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="") + # Colors for alternating tokens + colors = { + "blue": "\033[94m", # Light blue + "cyan": "\033[96m", # Cyan + "green": "\033[92m", # Light green + "dark_green": "\033[32m", # Dark green + } + blue_idx = 0 + green_idx = 0 + + def print_in_blue(x): + nonlocal blue_idx + color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"] + print(f"{color}{x}\033[0m", end="") + blue_idx += 1 + + def print_in_green(x): + nonlocal green_idx + color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"] + print(f"{color}{x}\033[0m", end="") + green_idx += 1 for tok, lab in zip(encoded.tokens, encoded.labels): - val = tokenizer.decode(tok, skip_special_tokens=False) - if val == "\n": - val = "\\n\n" + val = tokenizer.decode([tok]) if lab == -100: print_in_green(val) @@ -231,6 +239,9 @@ def visualize(self: "Conversation", tokenizer: AutoTokenizer): print() + def append(self: "Conversation", message: Message): + self.messages.append(message) + if __name__ == "__main__": message0 = Message( @@ -248,7 +259,7 @@ def visualize(self: "Conversation", tokenizer: AutoTokenizer): cal_loss=True, ) conversation = Conversation([message0, message1]) - tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct") + tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct") conversation.visualize(tokenizer) encoded = conversation.encode(tokenizer) diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py index 6ea15e59..1811f091 100644 --- a/fish_speech/models/text2semantic/llama.py +++ b/fish_speech/models/text2semantic/llama.py @@ -16,7 +16,7 @@ from torch.utils.checkpoint import checkpoint from transformers import AutoTokenizer -from fish_speech.conversation import SEMANTIC_TOKEN +from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer from fish_speech.utils import RankedLogger from .lora import LoraConfig, setup_lora @@ -61,6 +61,7 @@ class BaseModelArgs: # Dummy vars is_reward_model: bool = False share_codebook_embeddings: bool = True + scale_codebook_embeddings: bool = False def __post_init__(self): if self.n_local_heads == -1: @@ -164,13 +165,17 @@ class BaseTransformerForwardResult: class BaseTransformer(nn.Module): def __init__( - self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True + self, + config: BaseModelArgs, + tokenizer: FishTokenizer | AutoTokenizer, + init_weights: bool = True, ) -> None: super().__init__() self.config = config self.tokenizer = tokenizer - - self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN) + self.semantic_token_ids = [ + tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS + ] # Slow transformer self.embeddings = nn.Embedding( @@ -245,8 +250,10 @@ def embed(self, x: Tensor) -> Tensor: vocab_embeds = [self.embeddings(x[:, 0])] for i in range(self.config.num_codebooks): emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size) - emb[x[:, 0] != self.semantic_token_id] = 0 - vocab_embeds.append(emb) + semantic_token_ids_tensor = torch.tensor( + self.semantic_token_ids, device=x.device + ) + emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0 x = torch.stack(vocab_embeds, dim=3) x = x.sum(dim=3) @@ -294,20 +301,45 @@ def forward( def forward_generate( self, - x: Tensor, + inp: Tensor, input_pos: Optional[Tensor] = None, + vq_masks: Optional[Tensor] = None, # this is not used in fact return_all: bool = False, ) -> BaseTransformerForwardResult: # This is used for generation, optimized for torch compile - assert ( - self.max_seq_len != -1 and self.max_batch_size != -1 - ), "Please call setup_caches before forward_generate" + # assert ( + # self.max_seq_len != -1 and self.max_batch_size != -1 + # ), "Please call setup_caches before forward_generate" - x = self.embed(x) + embeds = [] + for i in range(self.config.num_codebooks): + if self.config.share_codebook_embeddings: + _tokens = inp[:, i + 1] + i * self.config.codebook_size + else: + _tokens = inp[:, i + 1] - mask = self.causal_mask[ - None, None, input_pos, : self.max_seq_len - ] # (B, N, Q, K) + emb = self.codebook_embeddings(_tokens) + embeds.append(emb) + + vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1) + # if self.config.use_codebook_mlp: + # vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks + # vq_embeds_sum = self.codebook_mlp(vq_embeds_sum) + + vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & ( + inp[:, 0] <= self.tokenizer.semantic_end_id + ) + + vq_embeds_sum[~vq_masks] = 0 + x = self.embeddings(inp[:, 0]) + vq_embeds_sum + + if input_pos is None: + input_pos = torch.arange(inp.shape[-1], device=x.device) + max_seq_len = inp.shape[-1] + else: + max_seq_len = self.max_seq_len + + mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K) freqs_cis = self.freqs_cis[input_pos] for layer in self.layers: @@ -320,7 +352,9 @@ def forward_generate( # We got slow_out here slow_out = self.norm(x) - if self.config.tie_word_embeddings: + if self.config.is_reward_model: + token_logits = self.score_output(slow_out) + elif self.config.tie_word_embeddings: token_logits = F.linear(slow_out, self.embeddings.weight) else: token_logits = self.output(slow_out) @@ -348,6 +382,7 @@ def from_pretrained( max_length: int | None = None, lora_config: LoraConfig | None = None, rope_base: int | None = None, + is_agent: bool = False, ) -> "BaseTransformer": config = BaseModelArgs.from_pretrained(str(path)) if max_length is not None: @@ -366,7 +401,12 @@ def from_pretrained( case _: raise ValueError(f"Unknown model type: {config.model_type}") - tokenizer = AutoTokenizer.from_pretrained(str(path)) + if is_agent: + tokenizer = AutoTokenizer.from_pretrained(str(path)) + else: + tokenizer_path = str(path) + "/tokenizer.tiktoken" + tokenizer = FishTokenizer(tokenizer_path) + log.info(f"Loading model from {path}, config: {config}") model = model_cls(config, tokenizer=tokenizer) @@ -452,7 +492,7 @@ def save_pretrained(self, path: str, drop_lora: bool = False): class NaiveTransformer(BaseTransformer): - def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: + def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None: super().__init__(config, init_weights=False, tokenizer=tokenizer) self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps) @@ -498,7 +538,7 @@ def forward_generate( class DualARTransformer(BaseTransformer): - def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: + def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None: super().__init__(config, init_weights=False, tokenizer=tokenizer) # Project to fast dim if needed @@ -654,9 +694,12 @@ def forward_generate_fast( return codebook_logits def forward_generate( - self, x: Tensor, input_pos: Optional[Tensor] = None + self, + x: Tensor, + input_pos: Optional[Tensor] = None, + vq_masks: Optional[Tensor] = None, ) -> TransformerForwardResult: - x = super().forward_generate(x, input_pos) + x = super().forward_generate(x, input_pos, vq_masks) x.hidden_states = self.fast_project_in(x.hidden_states) return x diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py index dbaf843d..fe1b331c 100644 --- a/fish_speech/text/clean.py +++ b/fish_speech/text/clean.py @@ -1,33 +1,8 @@ import re SYMBOLS_MAPPING = { - "\n": "", - "…": ".", - "“": "'", - "”": "'", "‘": "'", "’": "'", - "【": "", - "】": "", - "[": "", - "]": "", - "(": "", - ")": "", - "(": "", - ")": "", - "・": "", - "·": "", - "「": "'", - "」": "'", - "《": "'", - "》": "'", - "—": "", - "~": "", - "~": "", - ":": ",", - ";": ",", - ";": ",", - ":": ",", } REPLACE_SYMBOL_REGEX = re.compile( @@ -57,6 +32,6 @@ def clean_text(text): text = EMOJI_REGEX.sub(r"", text) # Remove continuous periods (...) and commas (,,,) - text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text) + text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text) return text diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py index d4bb9954..df079add 100644 --- a/fish_speech/text/spliter.py +++ b/fish_speech/text/spliter.py @@ -4,7 +4,7 @@ from fish_speech.text.clean import clean_text -def utf_8_len(text): +def utf_8_len(text: str): return len(text.encode("utf-8")) diff --git a/fish_speech/tokenizer.py b/fish_speech/tokenizer.py new file mode 100644 index 00000000..897cd060 --- /dev/null +++ b/fish_speech/tokenizer.py @@ -0,0 +1,152 @@ +import base64 +import json +import logging +from pathlib import Path + +import tiktoken + +logger = logging.getLogger(__name__) + +# This is a modified version of the default pattern from GPT-4o, that better handles punctuations. +FISH_TIKTOKEN_PATTERN = "|".join( + [ + r"(?i:'s|'t|'re|'ve|'m|'ll|'d)", + r"\p{P}", + r"[^\r\n\p{L}\p{N}]?\p{L}+", + r"\p{N}", + r" ?[^\s\p{L}\p{N}]+[\r\n]*", + r"\s*[\r\n]+", + r"\s+(\?!\S)", + r"\s+", + ] +) +TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +BOS_TOKEN = "<|begin_of_text|>" +EOS_TOKEN = "<|end_of_text|>" +PAD_TOKEN = "<|pad|>" +IM_START_TOKEN = "<|im_start|>" +IM_END_TOKEN = "<|im_end|>" + +MODALITY_TEXT_TOKEN = "<|text|>" +MODALITY_VOICE_TOKEN = "<|voice|>" +MODALITY_INTERLEAVE_TOKEN = "<|interleave|>" +MODALITY_TOKENS = { + "text": MODALITY_TEXT_TOKEN, + "voice": MODALITY_VOICE_TOKEN, + "interleave": MODALITY_INTERLEAVE_TOKEN, +} + +PLACEHOLDER_TOKEN = [""] * 4 +for i in range(4): + PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>" + +SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>" +SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)] + +# Warning: when you add a new special token, you should only add it to the end of the list. +ALL_SPECIAL_TOKENS = [ + BOS_TOKEN, + EOS_TOKEN, + PAD_TOKEN, + IM_START_TOKEN, + IM_END_TOKEN, + PLACEHOLDER_TOKEN[0], + PLACEHOLDER_TOKEN[1], + PLACEHOLDER_TOKEN[2], + PLACEHOLDER_TOKEN[3], + MODALITY_TEXT_TOKEN, + MODALITY_VOICE_TOKEN, + MODALITY_INTERLEAVE_TOKEN, + *SEMANTIC_TOKENS, +] + + +class FishTokenizer: + def __init__(self, model_path: str) -> None: + mergeable_ranks = self.load_tiktoken_bpe(model_path) + special_token_begin = len(mergeable_ranks) + self.all_special_tokens_with_ids = { + token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS) + } + self.semantic_id_to_token_id = { + i: self.all_special_tokens_with_ids[token] + for i, token in enumerate(SEMANTIC_TOKENS) + } + self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]] + self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]] + + self.tkt_model = tiktoken.core.Encoding( + name=Path(model_path).stem, + pat_str=FISH_TIKTOKEN_PATTERN, + mergeable_ranks=mergeable_ranks, + special_tokens=self.all_special_tokens_with_ids, + ) + + @staticmethod + def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]: + data = {} + for line in open(tiktoken_bpe_file).read().splitlines(): + if not line: + continue + token, rank = line.split() + data[base64.b64decode(token)] = int(rank) + return data + + def get_token_id(self, token: str) -> int: + return self.all_special_tokens_with_ids[token] + + def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]: + assert isinstance(s, str) + + subs = [] + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS): + subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS]) + + if allowed_special is True: + allowed_special = self.tkt_model.special_tokens_set + elif allowed_special is False: + allowed_special = set() + + return sum( + self.tkt_model.encode_batch( + subs, allowed_special=allowed_special, disallowed_special=set() + ), + start=[], + ) + + def decode(self, tokens: list[int]) -> str: + return self.tkt_model.decode(tokens) + + def save_pretrained(self, path: str): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + with open(path / "tokenizer.tiktoken", "w") as f: + for token, rank in self.tkt_model._mergeable_ranks.items(): + f.write(f"{base64.b64encode(token).decode()} {rank}\n") + + with open(path / "special_tokens.json", "w") as f: + json.dump( + self.all_special_tokens_with_ids, + f, + indent=2, + ensure_ascii=False, + ) + + @staticmethod + def from_pretrained(path: str): + return FishTokenizer(Path(path) / "tokenizer.tiktoken") + + +if __name__ == "__main__": + tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken") + tokenizer.save_pretrained("checkpoints/fish-speech-0.5B") + tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B") + + print( + [ + tokenizer.decode([i]) + for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}") + ] + ) diff --git a/pyproject.toml b/pyproject.toml index ad94489b..d894afcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,8 @@ dependencies = [ "opencc-python-reimplemented==0.1.7", "silero-vad", "ormsgpack", + "tiktoken>=0.8.0", + "pydantic==2.9.2", ] [project.optional-dependencies] diff --git a/tools/api.py b/tools/api.py index a5d47276..275021c3 100644 --- a/tools/api.py +++ b/tools/api.py @@ -1,4 +1,5 @@ import io +import json import os import queue import re @@ -32,7 +33,6 @@ ) from kui.asgi.routing import MultimethodRoutes from loguru import logger -from transformers import AutoTokenizer pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) import struct @@ -43,12 +43,14 @@ from funasr import AutoModel from silero_vad import get_speech_timestamps, load_silero_vad -from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN from fish_speech.models.text2semantic.llama import BaseModelArgs # from fish_speech.models.vqgan.lit_module import VQGAN from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture from fish_speech.text.chn_text_norm.text import Text as ChnNormedText + +# from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN +from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer from fish_speech.utils import autocast_exclude_mps, set_seed from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text from tools.llama.generate import ( @@ -381,14 +383,13 @@ def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]): def execute_request( input_queue: queue.Queue, - tokenizer: AutoTokenizer, + tokenizer: FishTokenizer, config: BaseModelArgs, request: ServeRequest, device: str = "cuda:0", ): - semantic_id, im_end_id = tokenizer.convert_tokens_to_ids( - [SEMANTIC_TOKEN, IM_END_TOKEN] - ) + + im_end_id = tokenizer.get_token_id(IM_END_TOKEN) messages = [] for message in request.messages: messages.append(message.to_conversation_message()) @@ -397,7 +398,13 @@ def execute_request( # assert messages[-1].role == "user", "The last message must be from the user" if messages[-1].role == "user": - messages.append(Message(role="assistant", parts=[], add_im_end=False)) + messages.append( + Message(role="assistant", parts=[], add_im_end=False, modality="voice") + ) + elif messages[-1].role == "raw": + messages[-1].add_im_start = False + messages[-1].add_im_end = False + messages[-1].modality = "voice" else: assert ( messages[-1].role == "assistant" @@ -405,6 +412,8 @@ def execute_request( messages[-1].add_im_end = False conv = Conversation(messages=messages) + + # conv.visualize(tokenizer) prompt = conv.encode_for_inference( tokenizer=tokenizer, num_codebooks=config.num_codebooks ).to(device) @@ -422,7 +431,6 @@ def execute_request( "prompt": prompt, "max_new_tokens": request.max_new_tokens, "im_end_id": im_end_id, - "semantic_id": semantic_id, "temperature": request.temperature, "top_p": request.top_p, "repetition_penalty": request.repetition_penalty, @@ -478,10 +486,13 @@ def send_reset_buffer(sample_id): ) continue - if tokens[0] == semantic_id and request.streaming: + is_semantic = ( + tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id + ) + if is_semantic and request.streaming: yield from send_reset_buffer(sample_id) # Streaming vq - _tokens = tokens[1:].clone() - 1 + _tokens = tokens[1:].clone() if config.share_codebook_embeddings is False: for i in range(len(_tokens)): @@ -494,13 +505,13 @@ def send_reset_buffer(sample_id): continue # Not streaming vq - if tokens[0] == semantic_id: + if is_semantic: yield from send_reset_buffer(sample_id) # None streaming vq if len(parts[sample_id]) == 0 or not isinstance( parts[sample_id][-1], ServeVQPart ): - _tokens = tokens[1:].clone() - 1 + _tokens = tokens[1:].clone() if config.share_codebook_embeddings is False: for i in range(len(_tokens)): @@ -509,14 +520,14 @@ def send_reset_buffer(sample_id): parts[sample_id].append(ServeVQPart(codes=_tokens.tolist())) else: for codebook_id, value in enumerate(tokens[1:, :]): - val = value.item() - 1 + val = value.item() if config.share_codebook_embeddings is False: val -= config.codebook_size * codebook_id parts[sample_id][-1].codes[codebook_id].append(val) continue - if tokens[0] != semantic_id: + if not is_semantic: # Stream text decode is not supported now decode_buffer[sample_id].append(tokens[0, 0]) @@ -776,7 +787,6 @@ async def api_health(): """ Health check """ - return JSONResponse({"status": "ok"}) @@ -871,11 +881,6 @@ def initialize_app(app: Kui): args = parse_args() # args same as ones in other processes args.precision = torch.half if args.half else torch.bfloat16 - # Check if CUDA is available - if not torch.cuda.is_available(): - logger.info("CUDA is not available, running on CPU.") - args.device = "cpu" - if args.load_asr_model: logger.info(f"Loading ASR model...") asr_model = load_asr_model(device=args.device) @@ -922,7 +927,7 @@ def initialize_app(app: Kui): max_new_tokens=0, chunk_length=200, top_p=0.7, - repetition_penalty=1.2, + repetition_penalty=1.5, temperature=0.7, emotion=None, format="wav", diff --git a/tools/llama/generate.py b/tools/llama/generate.py index e2beba97..f5979831 100644 --- a/tools/llama/generate.py +++ b/tools/llama/generate.py @@ -17,9 +17,16 @@ from tqdm import tqdm from transformers import AutoTokenizer -from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.conversation import ( + CODEBOOK_PAD_TOKEN_ID, + Conversation, + Message, + TextPart, + VQPart, +) from fish_speech.models.text2semantic.llama import BaseModelArgs from fish_speech.text import clean_text, split_text +from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" torch._inductor.config.coordinate_descent_tuning = True @@ -145,8 +152,8 @@ def decode_one_token_ar_agent( model: DualARTransformer, x: torch.Tensor, input_pos: torch.Tensor, + semantic_ids: list, previous_tokens: torch.Tensor = None, - semantic_id: int = 32003, **sampling_kwargs, ) -> torch.Tensor: # print(x, input_pos) @@ -190,19 +197,13 @@ def decode_one_token_ar_agent( codebooks.append(a) codebooks = torch.stack(codebooks, dim=1) + semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device) codebooks[:, 1:, :] = torch.masked_fill( - codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID + codebooks[:, 1:, :], + ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor), + CODEBOOK_PAD_TOKEN_ID, ) - # for i in range(codebooks.size(1) - 1): - # codebooks[:, i + 1, :] = torch.masked_fill( - # codebooks[:, i + 1, :], - # codebooks[:, :1, :] != semantic_id, - # CODEBOOK_PAD_TOKEN_ID + i * 1024, - # ) - - # print(codebooks) - return codebooks @@ -210,8 +211,8 @@ def decode_one_token_naive_agent( model: NaiveTransformer, x: torch.Tensor, input_pos: torch.Tensor, + semantic_ids: list, previous_tokens: torch.Tensor = None, - semantic_id: int = 32003, **sampling_kwargs, ) -> torch.Tensor: x = model.forward_generate(x, input_pos) @@ -236,8 +237,11 @@ def decode_one_token_naive_agent( ) codebooks = torch.stack(codebooks, dim=1) + semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device) codebooks[:, 1:, :] = torch.masked_fill( - codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID + codebooks[:, 1:, :], + ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor), + CODEBOOK_PAD_TOKEN_ID, ) return codebooks @@ -247,8 +251,8 @@ def decode_one_token_ar( model: DualARTransformer, x: torch.Tensor, input_pos: torch.Tensor, + semantic_ids: list, previous_tokens: torch.Tensor = None, - semantic_id: int = 0, **sampling_kwargs, ) -> torch.Tensor: x = model.forward_generate(x, input_pos) @@ -261,21 +265,32 @@ def decode_one_token_ar( codebooks = [ sample( x.logits, - previous_tokens=None, # Disable repetition penalty for the token codebook + previous_tokens=( + previous_tokens[0] if previous_tokens is not None else None + ), # Disable repetition penalty for the token codebook **sampling_kwargs_main, )[0] ] - x = x.hidden_states + hidden_states = x.hidden_states # Cleanup the cache for layer in model.fast_layers: layer.attention.kv_cache.k_cache.fill_(0) layer.attention.kv_cache.v_cache.fill_(0) - for codebook_idx in range(model.config.num_codebooks): - input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long) - logits = model.forward_generate_fast(x, input_pos) + input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long) + model.forward_generate_fast(hidden_states, input_pos) + a = codebooks[0] - model.tokenizer.semantic_begin_id + a[a < 0] = 0 + hidden_states = model.fast_embeddings(a) + codebooks.append(a) + + for codebook_idx in range(1, model.config.num_codebooks): + input_pos = torch.tensor( + [codebook_idx], device=hidden_states.device, dtype=torch.long + ) + logits = model.forward_generate_fast(hidden_states, input_pos) a = sample( logits, previous_tokens=( @@ -285,14 +300,16 @@ def decode_one_token_ar( ), **sampling_kwargs, )[0] - x = model.fast_embeddings(a) + hidden_states = model.fast_embeddings(a) codebooks.append(a) codebooks = torch.stack(codebooks, dim=0) - codebooks[1:, :] = torch.masked_fill( - codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID - ) + # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device) + # codebooks[1:, :] = torch.masked_fill( + # codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID + # ) + # print(codebooks) return codebooks @@ -337,9 +354,8 @@ def decode_n_tokens( cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, - im_end_id: int = 4, + semantic_ids: list, decode_one_token=decode_one_token_naive, - semantic_id: int = 0, **sampling_kwargs, ): previous_tokens = torch.zeros( @@ -368,7 +384,7 @@ def decode_n_tokens( x=cur_token, input_pos=input_pos, previous_tokens=window, - semantic_id=semantic_id, + semantic_ids=semantic_ids, **sampling_kwargs, ) @@ -378,7 +394,7 @@ def decode_n_tokens( model.config.num_codebooks + 1, -1 ) - if cur_token[0, 0, -1] == im_end_id: + if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN): break return previous_tokens[:, : i + 1] @@ -391,7 +407,6 @@ def generate( model: NaiveTransformer, prompt: torch.Tensor, max_new_tokens: int, - im_end_id: int = 4, decode_one_token=decode_one_token_naive, **sampling_kwargs, ) -> torch.Tensor: @@ -401,7 +416,10 @@ def generate( # create an empty tensor of the expected final shape and fill in the current tokens T = prompt.size(1) - semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>") + # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>") + semantic_ids = [ + model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024) + ] if max_new_tokens: if T + max_new_tokens > model.config.max_seq_len: @@ -435,7 +453,7 @@ def generate( model, prompt.view(1, codebook_dim, -1), input_pos, - semantic_id=semantic_id, + semantic_ids=semantic_ids, **sampling_kwargs, ) seq[:, T : T + 1] = next_token @@ -446,9 +464,8 @@ def generate( next_token.view(1, codebook_dim, -1), input_pos, max_new_tokens - 1, - im_end_id=im_end_id, decode_one_token=decode_one_token, - semantic_id=semantic_id, + semantic_ids=semantic_ids, **sampling_kwargs, ) # x = torch.cat(generated_tokens, dim=1) @@ -463,8 +480,8 @@ def decode_n_tokens_agent( cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, + semantic_ids: list, im_end_id: int = 4, - semantic_id: int = 32003, decode_one_token=decode_one_token_naive_agent, early_stop_threshold: float = 0.6, **sampling_kwargs, @@ -495,7 +512,7 @@ def decode_n_tokens_agent( x=cur_token, input_pos=input_pos, previous_tokens=window, - semantic_id=semantic_id, + semantic_ids=semantic_ids, **sampling_kwargs, ) @@ -529,8 +546,8 @@ def generate_agent( model: BaseTransformer, prompt: torch.Tensor, max_new_tokens: int, + semantic_ids: list, im_end_id: int = 4, - semantic_id: int = 32003, decode_one_token=decode_one_token_naive_agent, num_samples: int = 1, early_stop_threshold: float = 0.6, @@ -574,7 +591,7 @@ def generate_agent( model, prompt, input_pos, - semantic_id=semantic_id, + semantic_ids=semantic_ids, **sampling_kwargs, ).view(num_samples, codebook_dim, -1) yield next_token.cpu() @@ -587,7 +604,7 @@ def generate_agent( input_pos, max_new_tokens - 1, im_end_id=im_end_id, - semantic_id=semantic_id, + semantic_ids=semantic_ids, decode_one_token=decode_one_token, early_stop_threshold=early_stop_threshold, **sampling_kwargs, @@ -602,65 +619,63 @@ def encode_tokens( num_codebooks=4, ): string = clean_text(string) - string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n" - new_tokens = tokenizer.encode( - string, - add_special_tokens=False, - max_length=10**6, - truncation=False, + messages = [] + messages.append( + Message( + role="user", + parts=[TextPart(text=string)], + cal_loss=False, + ) ) - tokens = torch.tensor([new_tokens], dtype=torch.int, device=device) - # Codebooks - zeros = ( - torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device) - * CODEBOOK_PAD_TOKEN_ID - ) - prompt = torch.cat((tokens, zeros), dim=0) + if prompt_tokens is not None: + if prompt_tokens.ndim == 3: + assert ( + prompt_tokens.shape[0] == 1 + ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)" + prompt_tokens = prompt_tokens[0] - if prompt_tokens is None: - return prompt + assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor" - # Get prompt tokens - if prompt_tokens.ndim == 3: - assert ( - prompt_tokens.shape[0] == 1 - ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)" - prompt_tokens = prompt_tokens[0] + if prompt_tokens.shape[0] > num_codebooks: + logger.warning( + f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks" + ) + prompt_tokens = prompt_tokens[:num_codebooks] - assert prompt_tokens.ndim == 2 - data = prompt_tokens + 1 + vq_part = VQPart(codes=prompt_tokens.to(device)) - if prompt_tokens.shape[0] > num_codebooks: - logger.warning( - f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks" + messages.append( + Message( + role="assistant", + parts=[TextPart(text="<|voice|>"), vq_part], + cal_loss=False, + ) + ) + else: + messages.append( + Message( + role="assistant", + parts=[TextPart(text="<|voice|>")], + cal_loss=False, + add_im_end=False, + ) ) - data = data[:num_codebooks] - - # Add pad token for each codebook - data = torch.cat( - (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)), - dim=1, - ) - # Since 1.0, we use <|semantic|> - s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>") - end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") - main_token_ids = ( - torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id + conversation = Conversation(messages=messages) + # conversation.visualize(tokenizer) + encoded = conversation.encode_for_inference( + tokenizer=tokenizer, + num_codebooks=num_codebooks, ) - main_token_ids[0, -1] = end_token_id - - data = torch.cat((main_token_ids, data), dim=0) - prompt = torch.cat((prompt, data), dim=1) - return prompt + return encoded.to(device) def load_model(checkpoint_path, device, precision, compile=False, is_agent=False): model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained( - checkpoint_path, load_weights=True + checkpoint_path, load_weights=True, is_agent=is_agent ) model = model.to(device=device, dtype=precision) @@ -729,11 +744,26 @@ def generate_long( model_size = sum(p.numel() for p in model.parameters() if p.requires_grad) tokenizer = model.tokenizer - im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + im_end_id = tokenizer.get_token_id("<|im_end|>") encoded = [] texts = split_text(text, chunk_length) if iterative_prompt else [text] - encoded_prompts = [] + encoded_prompts = [ + Conversation( + messages=[ + Message( + role="system", + parts=[TextPart(text="Speak out the provided text.")], + cal_loss=False, + ) + ] + ) + .encode_for_inference( + tokenizer=tokenizer, + num_codebooks=model.config.num_codebooks, + ) + .to(device) + ] if use_prompt: for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)): @@ -812,7 +842,6 @@ def generate_long( model=model, prompt=cat_encoded, max_new_tokens=max_new_tokens, - im_end_id=im_end_id, decode_one_token=decode_one_token, temperature=temperature, top_p=top_p, @@ -842,12 +871,11 @@ def generate_long( ) # Put the generated tokens - # since there is and tokens, we remove last 2 tokens - codes = y[1:, prompt_length:-1].clone() - codes = codes - 1 + # since there is , we remove last token + codes = y[1:, prompt_length + 1 :].clone() assert (codes >= 0).all(), f"Negative code found" - decoded = y[:, prompt_length:-1].clone() + decoded = y[:, prompt_length:].clone() # But for global encoding, we should keep the token global_encoded.append(decoded) diff --git a/tools/schema.py b/tools/schema.py index 0698a009..c8813fc6 100644 --- a/tools/schema.py +++ b/tools/schema.py @@ -64,11 +64,14 @@ class ServeASRResponse(BaseModel): class ServeMessage(BaseModel): - role: Literal["system", "assistant", "user"] + role: Literal["system", "assistant", "user", "raw"] parts: list[ServeVQPart | ServeTextPart] def to_conversation_message(self): new_message = Message(role=self.role, parts=[]) + if self.role == "assistant": + new_message.modality = "voice" + for part in self.parts: if isinstance(part, ServeTextPart): new_message.parts.append(TextPart(text=part.text))