diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py index c016c667..a474efaf 100644 --- a/fish_speech/webui/manage.py +++ b/fish_speech/webui/manage.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import html import json import os @@ -180,8 +181,6 @@ def change_infer( infer_decoder_config, "--llama-checkpoint-path", infer_llama_model, - "--tokenizer", - "checkpoints/fish-speech-1.2", ] + (["--compile"] if infer_compile == "Yes" else []), env=env, @@ -400,6 +399,12 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device: ) +def generate_folder_name(): + now = datetime.datetime.now() + folder_name = now.strftime("%Y%m%d_%H%M%S") + return folder_name + + def train_process( data_path: str, option: str, @@ -419,12 +424,6 @@ def train_process( llama_use_speaker, llama_use_lora, ): - import datetime - - def generate_folder_name(): - now = datetime.datetime.now() - folder_name = now.strftime("%Y%m%d_%H%M%S") - return folder_name backend = "nccl" if sys.platform == "linux" else "gloo" @@ -464,14 +463,9 @@ def generate_folder_name(): "16", ] ) - ckpt_path = ( - "text2semantic-sft-medium-v1.1-4k.pth" - if llama_base_config == "dual_ar_2_codebook_medium" - else "text2semantic-sft-large-v1.1-4k.pth" - ) + ckpt_path = "checkpoints/fish-speech-1.2/model.pth" lora_prefix = "lora_" if llama_use_lora else "" - llama_size = "large_" if ("large" in llama_base_config) else "medium_" - llama_name = lora_prefix + "text2semantic_" + llama_size + new_project + llama_name = lora_prefix + "text2semantic_" + new_project latest = next( iter( sorted( @@ -500,10 +494,7 @@ def generate_folder_name(): "--config-name", "text2semantic_finetune", f"project={project}", - f"ckpt_path=checkpoints/{ckpt_path}", f"trainer.strategy.process_group_backend={backend}", - f"model@model.model={llama_base_config}", - "tokenizer.pretrained_model_name_or_path=checkpoints", f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}", f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}", f"model.optimizer.lr={llama_lr}", @@ -514,8 +505,8 @@ def generate_folder_name(): f"trainer.precision={llama_precision}", f"trainer.val_check_interval={llama_check_interval}", f"trainer.accumulate_grad_batches={llama_grad_batches}", - f"train_dataset.use_speaker={llama_use_speaker}", - ] + ([f"+lora@model.lora_config=r_8_alpha_16"] if llama_use_lora else []) + f"train_dataset.interactive_prob={llama_use_speaker}", + ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else []) logger.info(train_cmd) subprocess.run(train_cmd) @@ -573,10 +564,7 @@ def list_decoder_models(): def list_llama_models(): - choices = [ - str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*") - ] - choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")] + choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*.pth")] if not choices: logger.warning("No LLaMA model found") return choices @@ -627,16 +615,12 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou merge_cmd = [ PYTHON, "tools/llama/merge_lora.py", - "--llama-config", - lora_llama_config, "--lora-config", "r_8_alpha_16", - "--llama-weight", - llama_weight, "--lora-weight", lora_weight, "--output", - llama_lora_output, + llama_lora_output + "_" + generate_folder_name(), ] logger.info(merge_cmd) subprocess.run(merge_cmd) @@ -759,6 +743,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou "Use LoRA can save GPU memory, but may reduce the quality of the model" ), value=True, + interactive=False, ) llama_ckpt = gr.Dropdown( label=i18n("Select LLAMA ckpt"), @@ -792,7 +777,6 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou llama_base_config = gr.Dropdown( label=i18n("Model Size"), choices=[ - "text2semantic_agent", "text2semantic_finetune", ], value="text2semantic_finetune", @@ -865,7 +849,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou maximum=1.0, step=0.05, value=init_llama_yml["train_dataset"][ - "use_speaker" + "interactive_prob" ], ) @@ -879,7 +863,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou choices=[ "checkpoints/fish-speech-1.2/model.pth", ], - value=init_llama_yml["ckpt_path"], + value="checkpoints/fish-speech-1.2/model.pth", allow_custom_value=True, interactive=True, ) @@ -902,10 +886,9 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou "Type the path or select from the dropdown" ), choices=[ - "text2semantic_agent", "text2semantic_finetune", ], - value="text2semantic_agent", + value="text2semantic_finetune", allow_custom_value=True, ) with gr.Row(equal_height=False): @@ -914,8 +897,8 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou info=i18n( "Type the path or select from the dropdown" ), - value="checkpoints/merged.ckpt", - choices=["checkpoints/merged.ckpt"], + value="checkpoints/merged", + choices=["checkpoints/merged"], allow_custom_value=True, interactive=True, )