diff --git a/rvc/train/process/extract_model.py b/rvc/train/process/extract_model.py index ab536a15..5d3d1170 100644 --- a/rvc/train/process/extract_model.py +++ b/rvc/train/process/extract_model.py @@ -29,7 +29,7 @@ def extract_model( sr, pitch_guidance, name, - model_dir, + model_path, epoch, step, version, @@ -38,22 +38,11 @@ def extract_model( vocoder, ): try: + model_dir = os.path.dirname(model_path) + os.makedirs(model_dir, exist_ok=True) - model_dir_path = os.path.dirname(model_dir) - os.makedirs(model_dir_path, exist_ok=True) - - if "best_epoch" in model_dir: - pth_file = f"{name}_{epoch}e_{step}s_best_epoch.pth" - else: - pth_file = f"{name}_{epoch}e_{step}s.pth" - - pth_file_old_version_path = os.path.join( - model_dir_path, f"{pth_file}_old_version.pth" - ) - - model_dir_path = os.path.dirname(model_dir) - if os.path.exists(os.path.join(model_dir_path, "model_info.json")): - with open(os.path.join(model_dir_path, "model_info.json"), "r") as f: + if os.path.exists(os.path.join(model_dir, "model_info.json")): + with open(os.path.join(model_dir, "model_info.json"), "r") as f: data = json.load(f) dataset_length = data.get("total_dataset_duration", None) embedder_model = data.get("embedder_model", None) @@ -108,23 +97,18 @@ def extract_model( opt["speakers_id"] = speakers_id opt["vocoder"] = vocoder - torch.save(opt, os.path.join(model_dir_path, pth_file)) - - # Create a backwards-compatible checkpoint - model = torch.load(model_dir, map_location=torch.device("cpu")) torch.save( replace_keys_in_dict( replace_keys_in_dict( - model, ".parametrizations.weight.original1", ".weight_v" + opt, ".parametrizations.weight.original1", ".weight_v" ), ".parametrizations.weight.original0", ".weight_g", ), - pth_file_old_version_path, + model_path, ) - os.remove(model_dir) - os.rename(pth_file_old_version_path, model_dir) - print(f"Saved model '{model_dir}' (epoch {epoch} and step {step})") + + print(f"Saved model '{model_path}' (epoch {epoch} and step {step})") except Exception as error: print(f"An error occurred extracting the model: {error}") diff --git a/rvc/train/train.py b/rvc/train/train.py index 94f83f49..8e6365f6 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -960,7 +960,7 @@ def train_and_evaluate( sr=sample_rate, pitch_guidance=True, name=model_name, - model_dir=m, + model_path=m, epoch=epoch, step=global_step, version=version, diff --git a/rvc/train/utils.py b/rvc/train/utils.py index 0de664e8..06e72bdd 100644 --- a/rvc/train/utils.py +++ b/rvc/train/utils.py @@ -102,20 +102,19 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path) "optimizer": optimizer.state_dict(), "learning_rate": learning_rate, } - torch.save(checkpoint_data, checkpoint_path) # Create a backwards-compatible checkpoint - old_version_path = checkpoint_path.replace(".pth", "_old_version.pth") - checkpoint_data = replace_keys_in_dict( + torch.save( replace_keys_in_dict( - checkpoint_data, ".parametrizations.weight.original1", ".weight_v" + replace_keys_in_dict( + checkpoint_data, ".parametrizations.weight.original1", ".weight_v" + ), + ".parametrizations.weight.original0", + ".weight_g", ), - ".parametrizations.weight.original0", - ".weight_g", + checkpoint_path ) - torch.save(checkpoint_data, old_version_path) - os.replace(old_version_path, checkpoint_path) print(f"Saved model '{checkpoint_path}' (epoch {iteration})")