Commit
* Fix button height
* Streaming support
* Convert to 1 channel
* Fix Conversion bug
* Fix target path
* Add checkpoint selection
* Fix gpup decorator
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f473a75 · commit 2711f5d · Showing 2 changed files with 94 additions and 18 deletions.
@@ -255,11 +255,11 @@ def show_selected(options):
 from pydub import AudioSegment
 
 
-def convert_to_mono_in_place(audio_path):
+def convert_to_mono_in_place(audio_path: Path):
     audio = AudioSegment.from_file(audio_path)
     if audio.channels > 1:
         mono_audio = audio.set_channels(1)
-        mono_audio.export(audio_path, format="mp3")
+        mono_audio.export(audio_path, format=audio_path.suffix[1:])
         logger.info(f"Convert {audio_path} successfully")
 
 
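For context, this helper down-mixes multi-channel audio in place and, after the change, derives the export format from the file suffix instead of hard-coding mp3. A minimal standalone sketch of the same pydub pattern (the example path is hypothetical; pydub needs ffmpeg available on PATH):

```python
from pathlib import Path

from pydub import AudioSegment  # pip install pydub; requires ffmpeg


def to_mono_in_place(audio_path: Path) -> None:
    """Down-mix a multi-channel file to mono in place, keeping its original container."""
    audio = AudioSegment.from_file(audio_path)
    if audio.channels > 1:
        # ".flac" -> "flac": export in the file's own format rather than forcing mp3
        audio.set_channels(1).export(audio_path, format=audio_path.suffix[1:])


# hypothetical example file
to_mono_in_place(Path("data/demo/sample.wav"))
```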
@@ -277,12 +277,11 @@ def list_copy(list_file_path, method):
             if target_wav_path.is_file():
                 continue
             target_wav_path.parent.mkdir(parents=True, exist_ok=True)
-            convert_to_mono_in_place(original_wav_path)
             if method == i18n("Copy"):
                 shutil.copy(original_wav_path, target_wav_path)
             else:
                 shutil.move(original_wav_path, target_wav_path.parent)
-
+            convert_to_mono_in_place(target_wav_path)
             original_lab_path = original_wav_path.with_suffix(".lab")
             target_lab_path = (
                 wav_root
@@ -312,8 +311,16 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
         tar_path = data_path / item_path.name
 
         if content["type"] == "folder" and item_path.is_dir():
+            if content["method"] == i18n("Copy"):
+                os.makedirs(tar_path, exist_ok=True)
+                shutil.copytree(
+                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
+                )
+            elif not tar_path.is_dir():
+                shutil.move(src=str(item_path), dst=str(tar_path))
+
             for suf in ["wav", "flac", "mp3"]:
-                for audio_path in item_path.glob(f"**/*.{suf}"):
+                for audio_path in tar_path.glob(f"**/*.{suf}"):
                     convert_to_mono_in_place(audio_path)
 
             cur_lang = content["label_lang"]
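This hunk now copies (or moves) the selected folder into the dataset root first and only then scans the target tree for audio to down-mix, so the originals stay untouched in copy mode. A rough standalone sketch of that order of operations (function and constant names are ours; `to_mono_in_place` is the helper sketched above):

```python
import shutil
from pathlib import Path

AUDIO_SUFFIXES = ("wav", "flac", "mp3")


def import_folder(src: Path, dataset_root: Path, copy: bool = True) -> Path:
    """Copy or move a raw-data folder into the dataset root, then normalise audio there."""
    target = dataset_root / src.name
    if copy:
        target.mkdir(parents=True, exist_ok=True)
        shutil.copytree(src=str(src), dst=str(target), dirs_exist_ok=True)
    elif not target.is_dir():
        shutil.move(src=str(src), dst=str(target))

    # Convert inside the target tree only, so the source is never modified when copying.
    for suf in AUDIO_SUFFIXES:
        for audio_path in target.glob(f"**/*.{suf}"):
            to_mono_in_place(audio_path)
    return target
```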
@@ -328,9 +335,9 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
                         "--device",
                         label_device,
                         "--audio-dir",
-                        item_path,
+                        tar_path,
                         "--save-dir",
-                        item_path,
+                        tar_path,
                         "--language",
                         cur_lang,
                     ],
@@ -339,14 +346,6 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
             except Exception:
                 print("Transcription error occurred")
 
-            if content["method"] == i18n("Copy"):
-                os.makedirs(tar_path, exist_ok=True)
-                shutil.copytree(
-                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
-                )
-            elif not tar_path.is_dir():
-                shutil.move(src=str(item_path), dst=str(tar_path))
-
         elif content["type"] == "file" and item_path.is_file():
             list_copy(item_path, content["method"])
 
@@ -359,6 +358,7 @@ def train_process(
     data_path: str,
     option: str,
     # vq-gan config
+    vqgan_ckpt,
     vqgan_lr,
     vqgan_maxsteps,
     vqgan_data_num_workers,
@@ -367,6 +367,7 @@ def train_process(
     vqgan_precision,
     vqgan_check_interval,
     # llama config
+    llama_ckpt,
     llama_base_config,
     llama_lr,
     llama_maxsteps,
@@ -400,12 +401,29 @@ def generate_folder_name():
                 str(data_pre_output.relative_to(cur_work_dir)),
             ]
         )
+        latest = list(
+            sorted(
+                [
+                    str(p.relative_to("results"))
+                    for p in Path("results").glob("vqgan_*/")
+                ],
+                reverse=True,
+            )
+        )[0]
+        project = (
+            ("vqgan_" + new_project)
+            if vqgan_ckpt == "new"
+            else latest
+            if vqgan_ckpt == "latest"
+            else vqgan_ckpt
+        )
+        logger.info(project)
         train_cmd = [
             PYTHON,
             "fish_speech/train.py",
             "--config-name",
             "vqgan_finetune",
-            f"project={'vqgan_' + new_project}",
+            f"project={project}",
             f"trainer.strategy.process_group_backend={backend}",
             f"model.optimizer.lr={vqgan_lr}",
             f"trainer.max_steps={vqgan_maxsteps}",
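The `latest`/`new`/explicit-name resolution added above (and repeated for the LLAMA branch below) can be read as one small function. The sketch that follows mirrors that logic under the same assumptions the diff makes: run folders live under `results/` and sort lexicographically so the newest run name comes last. The helper name `resolve_project` is ours, not the webui's:

```python
from pathlib import Path


def resolve_project(choice: str, new_project: str, prefix: str) -> str:
    """Map a checkpoint dropdown value to a training project name.

    choice      -- "new", "latest", or an explicit run folder name
    new_project -- freshly generated folder name for a new run
    prefix      -- "vqgan_" or "text2semantic_"
    """
    if choice == "new":
        return prefix + new_project
    if choice == "latest":
        runs = sorted(
            str(p.relative_to("results")) for p in Path("results").glob(f"{prefix}*/")
        )
        return runs[-1]  # lexicographically greatest == most recent run name
    return choice  # an explicit existing run folder was picked


# e.g. project = resolve_project(vqgan_ckpt, new_project, "vqgan_")
```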
@@ -454,12 +472,30 @@ def generate_folder_name():
             if llama_base_config == "dual_ar_2_codebook_medium"
             else "text2semantic-sft-large-v1-4k.pth"
         )
+
+        latest = list(
+            sorted(
+                [
+                    str(p.relative_to("results"))
+                    for p in Path("results").glob("text2sem*/")
+                ],
+                reverse=True,
+            )
+        )[0]
+        project = (
+            ("text2semantic_" + new_project)
+            if llama_ckpt == "new"
+            else latest
+            if llama_ckpt == "latest"
+            else llama_ckpt
+        )
+        logger.info(project)
         train_cmd = [
             PYTHON,
             "fish_speech/train.py",
             "--config-name",
             "text2semantic_finetune",
-            f"project={'text2semantic_' + new_project}",
+            f"project={project}",
             f"ckpt_path=checkpoints/{ckpt_path}",
             f"trainer.strategy.process_group_backend={backend}",
             f"[email protected]={llama_base_config}",
@@ -530,6 +566,18 @@ def fresh_vqgan_model():
     )
 
 
+def fresh_vqgan_ckpt():
+    return gr.Dropdown(
+        choices=["latest", "new"] + [str(p) for p in Path("results").glob("vqgan_*/")]
+    )
+
+
+def fresh_llama_ckpt():
+    return gr.Dropdown(
+        choices=["latest", "new"] + [str(p) for p in Path("results").glob("text2sem*/")]
+    )
+
+
 def fresh_llama_model():
     return gr.Dropdown(
         choices=[init_llama_yml["ckpt_path"]]
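The two new `fresh_*_ckpt` helpers follow Gradio's pattern of returning an updated component to repopulate a Dropdown's choices. A minimal sketch of that pattern in isolation (the function names here are illustrative, not part of the webui):

```python
from pathlib import Path

import gradio as gr


def list_vqgan_runs() -> list[str]:
    # "latest"/"new" plus every existing run folder under results/
    return ["latest", "new"] + [str(p) for p in Path("results").glob("vqgan_*/")]


def refresh_ckpt_dropdown():
    # Returning a Dropdown with fresh choices replaces the rendered component's options.
    return gr.Dropdown(choices=list_vqgan_runs())
```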
@@ -655,6 +703,14 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
                         )
                 with gr.Row():
                     with gr.Tab(label=i18n("VQGAN Configuration")):
+                        with gr.Row(equal_height=False):
+                            vqgan_ckpt = gr.Dropdown(
+                                label="Select VQGAN ckpt",
+                                choices=["latest", "new"]
+                                + [str(p) for p in Path("results").glob("vqgan_*/")],
+                                value="latest",
+                                interactive=True,
+                            )
                         with gr.Row(equal_height=False):
                             vqgan_lr_slider = gr.Slider(
                                 label=i18n("Initial Learning Rate"),
@@ -728,6 +784,13 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
                                 ),
                                 value=True,
                             )
+                            llama_ckpt = gr.Dropdown(
+                                label="Select LLAMA ckpt",
+                                choices=["latest", "new"]
+                                + [str(p) for p in Path("results").glob("text2sem*/")],
+                                value="latest",
+                                interactive=True,
+                            )
                         with gr.Row(equal_height=False):
                             llama_lr_slider = gr.Slider(
                                 label=i18n("Initial Learning Rate"),
@@ -1022,6 +1085,7 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
                 train_box,
                 model_type_radio,
                 # vq-gan config
+                vqgan_ckpt,
                 vqgan_lr_slider,
                 vqgan_maxsteps_slider,
                 vqgan_data_num_workers_slider,
@@ -1030,6 +1094,7 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
                 vqgan_precision_dropdown,
                 vqgan_check_interval_slider,
                 # llama config
+                llama_ckpt,
                 llama_base_config,
                 llama_lr_slider,
                 llama_maxsteps_slider,
@@ -1065,6 +1130,8 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
         fresh_btn.click(
             fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
         )
+        vqgan_ckpt.change(fn=fresh_vqgan_ckpt, inputs=[], outputs=[vqgan_ckpt])
+        llama_ckpt.change(fn=fresh_llama_ckpt, inputs=[], outputs=[llama_ckpt])
         llama_lora_merge_btn.click(
             fn=llama_lora_merge,
             inputs=[llama_weight, lora_weight, llama_lora_output],
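End to end, the new dropdowns refresh their own choices on change and feed the selected checkpoint into `train_process`. A condensed, hypothetical Blocks wiring that shows the same pattern (all handler and component names below are placeholders, not the webui's actual layout):

```python
import gradio as gr


def fresh_choices():
    # Stand-in for fresh_vqgan_ckpt / fresh_llama_ckpt: re-scan run folders on every change
    return gr.Dropdown(choices=["latest", "new", "vqgan_2024-01-01"])


def start_training(vqgan_ckpt: str) -> str:
    # Stand-in for train_process: just echo the resolved selection
    return f"training with vqgan checkpoint: {vqgan_ckpt}"


with gr.Blocks() as demo:
    ckpt = gr.Dropdown(label="Select VQGAN ckpt", choices=["latest", "new"], value="latest")
    out = gr.Textbox(label="status")
    train_btn = gr.Button("Start")

    # Re-populate the dropdown whenever its value changes, as the webui does.
    ckpt.change(fn=fresh_choices, inputs=[], outputs=[ckpt])
    train_btn.click(fn=start_training, inputs=[ckpt], outputs=[out])

demo.launch()
```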