diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py index ae28bcf4..05523d66 100644 --- a/fish_speech/webui/manage.py +++ b/fish_speech/webui/manage.py @@ -255,11 +255,11 @@ def show_selected(options): from pydub import AudioSegment -def convert_to_mono_in_place(audio_path): +def convert_to_mono_in_place(audio_path: Path): audio = AudioSegment.from_file(audio_path) if audio.channels > 1: mono_audio = audio.set_channels(1) - mono_audio.export(audio_path, format="mp3") + mono_audio.export(audio_path, format=audio_path.suffix[1:]) logger.info(f"Convert {audio_path} successfully") @@ -277,12 +277,11 @@ def list_copy(list_file_path, method): if target_wav_path.is_file(): continue target_wav_path.parent.mkdir(parents=True, exist_ok=True) - convert_to_mono_in_place(original_wav_path) if method == i18n("Copy"): shutil.copy(original_wav_path, target_wav_path) else: shutil.move(original_wav_path, target_wav_path.parent) - + convert_to_mono_in_place(target_wav_path) original_lab_path = original_wav_path.with_suffix(".lab") target_lab_path = ( wav_root @@ -312,8 +311,16 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device: tar_path = data_path / item_path.name if content["type"] == "folder" and item_path.is_dir(): + if content["method"] == i18n("Copy"): + os.makedirs(tar_path, exist_ok=True) + shutil.copytree( + src=str(item_path), dst=str(tar_path), dirs_exist_ok=True + ) + elif not tar_path.is_dir(): + shutil.move(src=str(item_path), dst=str(tar_path)) + for suf in ["wav", "flac", "mp3"]: - for audio_path in item_path.glob(f"**/*.{suf}"): + for audio_path in tar_path.glob(f"**/*.{suf}"): convert_to_mono_in_place(audio_path) cur_lang = content["label_lang"] @@ -328,9 +335,9 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device: "--device", label_device, "--audio-dir", - item_path, + tar_path, "--save-dir", - item_path, + tar_path, "--language", cur_lang, ], @@ -339,14 +346,6 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device: except Exception: print("Transcription error occurred") - if content["method"] == i18n("Copy"): - os.makedirs(tar_path, exist_ok=True) - shutil.copytree( - src=str(item_path), dst=str(tar_path), dirs_exist_ok=True - ) - elif not tar_path.is_dir(): - shutil.move(src=str(item_path), dst=str(tar_path)) - elif content["type"] == "file" and item_path.is_file(): list_copy(item_path, content["method"]) @@ -359,6 +358,7 @@ def train_process( data_path: str, option: str, # vq-gan config + vqgan_ckpt, vqgan_lr, vqgan_maxsteps, vqgan_data_num_workers, @@ -367,6 +367,7 @@ def train_process( vqgan_precision, vqgan_check_interval, # llama config + llama_ckpt, llama_base_config, llama_lr, llama_maxsteps, @@ -400,12 +401,29 @@ def generate_folder_name(): str(data_pre_output.relative_to(cur_work_dir)), ] ) + latest = list( + sorted( + [ + str(p.relative_to("results")) + for p in Path("results").glob("vqgan_*/") + ], + reverse=True, + ) + )[0] + project = ( + ("vqgan_" + new_project) + if vqgan_ckpt == "new" + else latest + if vqgan_ckpt == "latest" + else vqgan_ckpt + ) + logger.info(project) train_cmd = [ PYTHON, "fish_speech/train.py", "--config-name", "vqgan_finetune", - f"project={'vqgan_' + new_project}", + f"project={project}", f"trainer.strategy.process_group_backend={backend}", f"model.optimizer.lr={vqgan_lr}", f"trainer.max_steps={vqgan_maxsteps}", @@ -454,12 +472,30 @@ def generate_folder_name(): if llama_base_config == "dual_ar_2_codebook_medium" else "text2semantic-sft-large-v1-4k.pth" ) + + latest = list( + sorted( + [ + str(p.relative_to("results")) + for p in Path("results").glob("text2sem*/") + ], + reverse=True, + ) + )[0] + project = ( + ("text2semantic_" + new_project) + if llama_ckpt == "new" + else latest + if llama_ckpt == "latest" + else llama_ckpt + ) + logger.info(project) train_cmd = [ PYTHON, "fish_speech/train.py", "--config-name", "text2semantic_finetune", - f"project={'text2semantic_' + new_project}", + f"project={project}", f"ckpt_path=checkpoints/{ckpt_path}", f"trainer.strategy.process_group_backend={backend}", f"model@model.model={llama_base_config}", @@ -530,6 +566,18 @@ def fresh_vqgan_model(): ) +def fresh_vqgan_ckpt(): + return gr.Dropdown( + choices=["latest", "new"] + [str(p) for p in Path("results").glob("vqgan_*/")] + ) + + +def fresh_llama_ckpt(): + return gr.Dropdown( + choices=["latest", "new"] + [str(p) for p in Path("results").glob("text2sem*/")] + ) + + def fresh_llama_model(): return gr.Dropdown( choices=[init_llama_yml["ckpt_path"]] @@ -655,6 +703,14 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output): ) with gr.Row(): with gr.Tab(label=i18n("VQGAN Configuration")): + with gr.Row(equal_height=False): + vqgan_ckpt = gr.Dropdown( + label="Select VQGAN ckpt", + choices=["latest", "new"] + + [str(p) for p in Path("results").glob("vqgan_*/")], + value="latest", + interactive=True, + ) with gr.Row(equal_height=False): vqgan_lr_slider = gr.Slider( label=i18n("Initial Learning Rate"), @@ -728,6 +784,13 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output): ), value=True, ) + llama_ckpt = gr.Dropdown( + label="Select LLAMA ckpt", + choices=["latest", "new"] + + [str(p) for p in Path("results").glob("text2sem*/")], + value="latest", + interactive=True, + ) with gr.Row(equal_height=False): llama_lr_slider = gr.Slider( label=i18n("Initial Learning Rate"), @@ -1022,6 +1085,7 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output): train_box, model_type_radio, # vq-gan config + vqgan_ckpt, vqgan_lr_slider, vqgan_maxsteps_slider, vqgan_data_num_workers_slider, @@ -1030,6 +1094,7 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output): vqgan_precision_dropdown, vqgan_check_interval_slider, # llama config + llama_ckpt, llama_base_config, llama_lr_slider, llama_maxsteps_slider, @@ -1065,6 +1130,8 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output): fresh_btn.click( fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown] ) + vqgan_ckpt.change(fn=fresh_vqgan_ckpt, inputs=[], outputs=[vqgan_ckpt]) + llama_ckpt.change(fn=fresh_llama_ckpt, inputs=[], outputs=[llama_ckpt]) llama_lora_merge_btn.click( fn=llama_lora_merge, inputs=[llama_weight, lora_weight, llama_lora_output], diff --git a/tools/webui.py b/tools/webui.py index eda42ad3..2c66a3e7 100644 --- a/tools/webui.py +++ b/tools/webui.py @@ -5,7 +5,7 @@ import queue import wave from argparse import ArgumentParser -from functools import partial +from functools import partial, wraps from pathlib import Path import gradio as gr @@ -38,17 +38,21 @@ """ TEXTBOX_PLACEHOLDER = i18n("Put your text here.") +SPACE_IMPORTED = False try: import spaces GPU_DECORATOR = spaces.GPU + SPACE_IMPORTED = True except ImportError: def GPU_DECORATOR(func): + @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) + wrapper.original = func # ref return wrapper @@ -169,6 +173,11 @@ def inference( inference_stream = partial(inference, streaming=True) +if not SPACE_IMPORTED: + logger.info("‘spaces’ not imported, use original") + inference = inference.original + inference_stream = partial(inference, streaming=True) + def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): buffer = io.BytesIO()