From 629cac1960c9ffeb022904190e19770125c1f00f Mon Sep 17 00:00:00 2001
From: Blaise
Date: Fri, 27 Dec 2024 13:51:22 +0100
Subject: [PATCH] minor changes on save

---
 rvc/train/process/extract_model.py     | 117 +++++++++++++------------
 rvc/train/process/model_information.py |   9 +-
 rvc/train/train.py                     |  50 +++++------
 3 files changed, 91 insertions(+), 85 deletions(-)

diff --git a/rvc/train/process/extract_model.py b/rvc/train/process/extract_model.py
index 18f8b9be..4418795d 100644
--- a/rvc/train/process/extract_model.py
+++ b/rvc/train/process/extract_model.py
@@ -1,5 +1,4 @@
-import os
-import sys
+import os, sys
 import torch
 import hashlib
 import datetime
@@ -11,14 +10,15 @@
 
 
 def replace_keys_in_dict(d, old_key_part, new_key_part):
-    updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
+    if isinstance(d, OrderedDict):
+        updated_dict = OrderedDict()
+    else:
+        updated_dict = {}
     for key, value in d.items():
         new_key = key.replace(old_key_part, new_key_part)
-        updated_dict[new_key] = (
-            replace_keys_in_dict(value, old_key_part, new_key_part)
-            if isinstance(value, dict)
-            else value
-        )
+        if isinstance(value, dict):
+            value = replace_keys_in_dict(value, old_key_part, new_key_part)
+        updated_dict[new_key] = value
     return updated_dict
 
 
@@ -36,84 +36,93 @@ def extract_model(
     vocoder,
 ):
     try:
-        print(f"Saved model '{model_dir}' (epoch {epoch} and step {step})")
         model_dir_path = os.path.dirname(model_dir)
         os.makedirs(model_dir_path, exist_ok=True)
 
-        suffix = "_best_epoch" if "best_epoch" in model_dir else ""
-        pth_file = f"{name}_{epoch}e_{step}s{suffix}.pth"
+        if "best_epoch" in model_dir:
+            pth_file = f"{name}_{epoch}e_{step}s_best_epoch.pth"
+        else:
+            pth_file = f"{name}_{epoch}e_{step}s.pth"
+
         pth_file_old_version_path = os.path.join(
            model_dir_path, f"{pth_file}_old_version.pth"
         )
 
-        dataset_length, embedder_model, speakers_id = None, None, 1
+        model_dir_path = os.path.dirname(model_dir)
         if os.path.exists(os.path.join(model_dir_path, "model_info.json")):
             with open(os.path.join(model_dir_path, "model_info.json"), "r") as f:
                 data = json.load(f)
-                dataset_length = data.get("total_dataset_duration")
-                embedder_model = data.get("embedder_model")
+                dataset_lenght = data.get("total_dataset_duration", None)
+                embedder_model = data.get("embedder_model", None)
                 speakers_id = data.get("speakers_id", 1)
+        else:
+            dataset_lenght = None
 
         with open(os.path.join(now_dir, "assets", "config.json"), "r") as f:
             data = json.load(f)
-            model_author = data.get("model_author")
+            model_author = data.get("model_author", None)
 
         opt = OrderedDict(
             weight={
                 key: value.half() for key, value in ckpt.items() if "enc_q" not in key
-            },
-            config=[
-                hps.data.filter_length // 2 + 1,
-                32,
-                hps.model.inter_channels,
-                hps.model.hidden_channels,
-                hps.model.filter_channels,
-                hps.model.n_heads,
-                hps.model.n_layers,
-                hps.model.kernel_size,
-                hps.model.p_dropout,
-                hps.model.resblock,
-                hps.model.resblock_kernel_sizes,
-                hps.model.resblock_dilation_sizes,
-                hps.model.upsample_rates,
-                hps.model.upsample_initial_channel,
-                hps.model.upsample_kernel_sizes,
-                hps.model.spk_embed_dim,
-                hps.model.gin_channels,
-                hps.data.sample_rate,
-            ],
-            epoch=epoch,
-            step=step,
-            sr=sr,
-            f0=pitch_guidance,
-            version=version,
-            creation_date=datetime.datetime.now().isoformat(),
-            overtrain_info=overtrain_info,
-            dataset_length=dataset_length,
-            model_name=name,
-            author=model_author,
-            embedder_model=embedder_model,
-            speakers_id=speakers_id,
-            vocoder=vocoder,
+            }
         )
+        opt["config"] = [
+            hps.data.filter_length // 2 + 1,
+            32,
+            hps.model.inter_channels,
+            hps.model.hidden_channels,
+            hps.model.filter_channels,
+            hps.model.n_heads,
+            hps.model.n_layers,
+            hps.model.kernel_size,
+            hps.model.p_dropout,
+            hps.model.resblock,
+            hps.model.resblock_kernel_sizes,
+            hps.model.resblock_dilation_sizes,
+            hps.model.upsample_rates,
+            hps.model.upsample_initial_channel,
+            hps.model.upsample_kernel_sizes,
+            hps.model.spk_embed_dim,
+            hps.model.gin_channels,
+            hps.data.sample_rate,
+        ]
+
+        opt["epoch"] = epoch
+        opt["step"] = step
+        opt["sr"] = sr
+        opt["f0"] = pitch_guidance
+        opt["version"] = version
+        opt["creation_date"] = datetime.datetime.now().isoformat()
 
         hash_input = f"{name}-{epoch}-{step}-{sr}-{version}-{opt['config']}"
         opt["model_hash"] = hashlib.sha256(hash_input.encode()).hexdigest()
+        opt["overtrain_info"] = overtrain_info
+        opt["dataset_lenght"] = dataset_lenght
+        opt["model_name"] = name
+        opt["author"] = model_author
+        opt["embedder_model"] = embedder_model
+        opt["speakers_id"] = speakers_id
+        opt["vocoder"] = vocoder
 
         torch.save(opt, os.path.join(model_dir_path, pth_file))
 
+        # Create a backwards-compatible checkpoint
         model = torch.load(model_dir, map_location=torch.device("cpu"))
-        updated_model = replace_keys_in_dict(
+        torch.save(
             replace_keys_in_dict(
-                model, ".parametrizations.weight.original1", ".weight_v"
+                replace_keys_in_dict(
+                    model, ".parametrizations.weight.original1", ".weight_v"
+                ),
+                ".parametrizations.weight.original0",
+                ".weight_g",
             ),
-            ".parametrizations.weight.original0",
-            ".weight_g",
+            pth_file_old_version_path,
         )
-        torch.save(updated_model, pth_file_old_version_path)
         os.remove(model_dir)
         os.rename(pth_file_old_version_path, model_dir)
+        print(f"Saved model '{model_dir}' (epoch {epoch} and step {step})")
 
     except Exception as error:
         print(f"An error occurred extracting the model: {error}")
diff --git a/rvc/train/process/model_information.py b/rvc/train/process/model_information.py
index a61cb2e2..fdd1c580 100644
--- a/rvc/train/process/model_information.py
+++ b/rvc/train/process/model_information.py
@@ -23,7 +23,7 @@ def model_information(path):
     sr = model_data.get("sr", "None")
     f0 = model_data.get("f0", "None")
     dataset_length = model_data.get("dataset_length", "None")
-    version = model_data.get("version", "None")
+    vocoder = model_data.get("vocoder", "None")
     creation_date = model_data.get("creation_date", "None")
     model_hash = model_data.get("model_hash", None)
     overtrain_info = model_data.get("overtrain_info", "None")
@@ -31,8 +31,6 @@
     embedder_model = model_data.get("embedder_model", "None")
     speakers_id = model_data.get("speakers_id", 0)
 
-    pitch_guidance = "True" if f0 == 1 else "False"
-
     creation_date_str = prettify_date(creation_date) if creation_date else "None"
 
     return (
@@ -40,13 +38,12 @@
         f"Model Creator: {model_author}\n"
         f"Epochs: {epochs}\n"
         f"Steps: {steps}\n"
-        f"Model Architecture: {version}\n"
+        f"Vocoder: {vocoder}\n"
         f"Sampling Rate: {sr}\n"
-        f"Pitch Guidance: {pitch_guidance}\n"
         f"Dataset Length: {dataset_length}\n"
         f"Creation Date: {creation_date_str}\n"
-        f"Hash (ID): {model_hash}\n"
         f"Overtrain Info: {overtrain_info}\n"
         f"Embedder Model: {embedder_model}\n"
         f"Max Speakers ID: {speakers_id}"
+        f"Hash: {model_hash}\n"
     )
diff --git a/rvc/train/train.py b/rvc/train/train.py
index d84ce2eb..061fff59 100644
--- a/rvc/train/train.py
+++ b/rvc/train/train.py
@@ -881,31 +881,6 @@ def train_and_evaluate(
             )
         )
 
-        # Check completion
-        if epoch >= custom_total_epoch:
-            lowest_value_rounded = float(lowest_value["value"])
-            lowest_value_rounded = round(lowest_value_rounded, 3)
-            print(
-                f"Training has been successfully completed with {epoch} epoch, {global_step} steps and {round(loss_gen_all.item(), 3)} loss gen."
-            )
-            print(
-                f"Lowest generator loss: {lowest_value_rounded} at epoch {lowest_value['epoch']}, step {lowest_value['step']}"
-            )
-
-            pid_file_path = os.path.join(experiment_dir, "config.json")
-            with open(pid_file_path, "r") as pid_file:
-                pid_data = json.load(pid_file)
-            with open(pid_file_path, "w") as pid_file:
-                pid_data.pop("process_pids", None)
-                json.dump(pid_data, pid_file, indent=4)
-            # Final model
-            model_add.append(
-                os.path.join(
-                    experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth"
-                )
-            )
-            done = True
-
         # Print training progress
         lowest_value_rounded = float(lowest_value["value"])
         lowest_value_rounded = round(lowest_value_rounded, 3)
@@ -977,6 +952,31 @@
             for m in model_del:
                 os.remove(m)
 
+        # Check completion
+        if epoch >= custom_total_epoch:
+            lowest_value_rounded = float(lowest_value["value"])
+            lowest_value_rounded = round(lowest_value_rounded, 3)
+            print(
+                f"Training has been successfully completed with {epoch} epoch, {global_step} steps and {round(loss_gen_all.item(), 3)} loss gen."
+            )
+            print(
+                f"Lowest generator loss: {lowest_value_rounded} at epoch {lowest_value['epoch']}, step {lowest_value['step']}"
+            )
+
+            pid_file_path = os.path.join(experiment_dir, "config.json")
+            with open(pid_file_path, "r") as pid_file:
+                pid_data = json.load(pid_file)
+            with open(pid_file_path, "w") as pid_file:
+                pid_data.pop("process_pids", None)
+                json.dump(pid_data, pid_file, indent=4)
+            # Final model
+            model_add.append(
+                os.path.join(
+                    experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth"
+                )
+            )
+            done = True
+
         if done:
             os._exit(2333333)
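
Note: the sketch below is not part of the patch. It is a minimal, hypothetical example of reading back the metadata that extract_model() writes with torch.save(); the inspect_checkpoint helper and the example path are illustrative assumptions, not code from the repository. The patch stores the dataset duration under the key "dataset_lenght" in extract_model.py while model_information.py reads "dataset_length", so the sketch checks both spellings.

import torch

def inspect_checkpoint(path):
    # Load the checkpoint dict written by extract_model(): a half-precision
    # "weight" state dict plus plain-Python metadata fields.
    ckpt = torch.load(path, map_location=torch.device("cpu"))
    for key in (
        "model_name", "author", "epoch", "step", "sr", "f0", "version",
        "vocoder", "embedder_model", "speakers_id", "creation_date", "model_hash",
    ):
        print(f"{key}: {ckpt.get(key)}")
    # The duration key is spelled two ways across the code base; check both.
    print("dataset duration:", ckpt.get("dataset_length", ckpt.get("dataset_lenght")))
    print("weight tensors:", len(ckpt["weight"]))

if __name__ == "__main__":
    # Hypothetical path; saved files follow the f"{name}_{epoch}e_{step}s.pth" pattern.
    inspect_checkpoint("logs/MyModel/MyModel_100e_5000s.pth")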