-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add Windows Setup Help * Optimize documents/bootscripts for Windows User * Correct some description * Fix dependencies * fish 1.2 webui & api * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix spelling * Fix CUDA env * Update api usage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Adapt finetuning * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
74d7850
commit 3d6d1d7
Showing
1 changed file
with
19 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
import datetime | ||
import html | ||
import json | ||
import os | ||
|
@@ -180,8 +181,6 @@ def change_infer( | |
infer_decoder_config, | ||
"--llama-checkpoint-path", | ||
infer_llama_model, | ||
"--tokenizer", | ||
"checkpoints/fish-speech-1.2", | ||
] | ||
+ (["--compile"] if infer_compile == "Yes" else []), | ||
env=env, | ||
|
@@ -400,6 +399,12 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device: | |
) | ||
|
||
|
||
def generate_folder_name(): | ||
now = datetime.datetime.now() | ||
folder_name = now.strftime("%Y%m%d_%H%M%S") | ||
return folder_name | ||
|
||
|
||
def train_process( | ||
data_path: str, | ||
option: str, | ||
|
@@ -419,12 +424,6 @@ def train_process( | |
llama_use_speaker, | ||
llama_use_lora, | ||
): | ||
import datetime | ||
|
||
def generate_folder_name(): | ||
now = datetime.datetime.now() | ||
folder_name = now.strftime("%Y%m%d_%H%M%S") | ||
return folder_name | ||
|
||
backend = "nccl" if sys.platform == "linux" else "gloo" | ||
|
||
|
@@ -464,14 +463,9 @@ def generate_folder_name(): | |
"16", | ||
] | ||
) | ||
ckpt_path = ( | ||
"text2semantic-sft-medium-v1.1-4k.pth" | ||
if llama_base_config == "dual_ar_2_codebook_medium" | ||
else "text2semantic-sft-large-v1.1-4k.pth" | ||
) | ||
ckpt_path = "checkpoints/fish-speech-1.2/model.pth" | ||
lora_prefix = "lora_" if llama_use_lora else "" | ||
llama_size = "large_" if ("large" in llama_base_config) else "medium_" | ||
llama_name = lora_prefix + "text2semantic_" + llama_size + new_project | ||
llama_name = lora_prefix + "text2semantic_" + new_project | ||
latest = next( | ||
iter( | ||
sorted( | ||
|
@@ -500,10 +494,7 @@ def generate_folder_name(): | |
"--config-name", | ||
"text2semantic_finetune", | ||
f"project={project}", | ||
f"ckpt_path=checkpoints/{ckpt_path}", | ||
f"trainer.strategy.process_group_backend={backend}", | ||
f"[email protected]={llama_base_config}", | ||
"tokenizer.pretrained_model_name_or_path=checkpoints", | ||
f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}", | ||
f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}", | ||
f"model.optimizer.lr={llama_lr}", | ||
|
@@ -514,8 +505,8 @@ def generate_folder_name(): | |
f"trainer.precision={llama_precision}", | ||
f"trainer.val_check_interval={llama_check_interval}", | ||
f"trainer.accumulate_grad_batches={llama_grad_batches}", | ||
f"train_dataset.use_speaker={llama_use_speaker}", | ||
] + ([f"[email protected]_config=r_8_alpha_16"] if llama_use_lora else []) | ||
f"train_dataset.interactive_prob={llama_use_speaker}", | ||
] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else []) | ||
logger.info(train_cmd) | ||
subprocess.run(train_cmd) | ||
|
||
|
@@ -573,10 +564,7 @@ def list_decoder_models(): | |
|
||
|
||
def list_llama_models(): | ||
choices = [ | ||
str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*") | ||
] | ||
choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")] | ||
choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*.pth")] | ||
if not choices: | ||
logger.warning("No LLaMA model found") | ||
return choices | ||
|
@@ -627,16 +615,12 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou | |
merge_cmd = [ | ||
PYTHON, | ||
"tools/llama/merge_lora.py", | ||
"--llama-config", | ||
lora_llama_config, | ||
"--lora-config", | ||
"r_8_alpha_16", | ||
"--llama-weight", | ||
llama_weight, | ||
"--lora-weight", | ||
lora_weight, | ||
"--output", | ||
llama_lora_output, | ||
llama_lora_output + "_" + generate_folder_name(), | ||
] | ||
logger.info(merge_cmd) | ||
subprocess.run(merge_cmd) | ||
|
@@ -759,6 +743,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou | |
"Use LoRA can save GPU memory, but may reduce the quality of the model" | ||
), | ||
value=True, | ||
interactive=False, | ||
) | ||
llama_ckpt = gr.Dropdown( | ||
label=i18n("Select LLAMA ckpt"), | ||
|
@@ -792,7 +777,6 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou | |
llama_base_config = gr.Dropdown( | ||
label=i18n("Model Size"), | ||
choices=[ | ||
"text2semantic_agent", | ||
"text2semantic_finetune", | ||
], | ||
value="text2semantic_finetune", | ||
|
@@ -865,7 +849,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou | |
maximum=1.0, | ||
step=0.05, | ||
value=init_llama_yml["train_dataset"][ | ||
"use_speaker" | ||
"interactive_prob" | ||
], | ||
) | ||
|
||
|
@@ -879,7 +863,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou | |
choices=[ | ||
"checkpoints/fish-speech-1.2/model.pth", | ||
], | ||
value=init_llama_yml["ckpt_path"], | ||
value="checkpoints/fish-speech-1.2/model.pth", | ||
allow_custom_value=True, | ||
interactive=True, | ||
) | ||
|
@@ -902,10 +886,9 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou | |
"Type the path or select from the dropdown" | ||
), | ||
choices=[ | ||
"text2semantic_agent", | ||
"text2semantic_finetune", | ||
], | ||
value="text2semantic_agent", | ||
value="text2semantic_finetune", | ||
allow_custom_value=True, | ||
) | ||
with gr.Row(equal_height=False): | ||
|
@@ -914,8 +897,8 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou | |
info=i18n( | ||
"Type the path or select from the dropdown" | ||
), | ||
value="checkpoints/merged.ckpt", | ||
choices=["checkpoints/merged.ckpt"], | ||
value="checkpoints/merged", | ||
choices=["checkpoints/merged"], | ||
allow_custom_value=True, | ||
interactive=True, | ||
) | ||
|