Skip to content

Commit

Permalink
Remove double torch.save and load model
Browse files Browse the repository at this point in the history
  • Loading branch information
deiteris committed Jan 1, 2025
1 parent 253dafc commit e9f1deb
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 34 deletions.
34 changes: 9 additions & 25 deletions rvc/train/process/extract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def extract_model(
sr,
pitch_guidance,
name,
model_dir,
model_path,
epoch,
step,
version,
Expand All @@ -38,22 +38,11 @@ def extract_model(
vocoder,
):
try:
model_dir = os.path.dirname(model_path)
os.makedirs(model_dir, exist_ok=True)

model_dir_path = os.path.dirname(model_dir)
os.makedirs(model_dir_path, exist_ok=True)

if "best_epoch" in model_dir:
pth_file = f"{name}_{epoch}e_{step}s_best_epoch.pth"
else:
pth_file = f"{name}_{epoch}e_{step}s.pth"

pth_file_old_version_path = os.path.join(
model_dir_path, f"{pth_file}_old_version.pth"
)

model_dir_path = os.path.dirname(model_dir)
if os.path.exists(os.path.join(model_dir_path, "model_info.json")):
with open(os.path.join(model_dir_path, "model_info.json"), "r") as f:
if os.path.exists(os.path.join(model_dir, "model_info.json")):
with open(os.path.join(model_dir, "model_info.json"), "r") as f:
data = json.load(f)
dataset_length = data.get("total_dataset_duration", None)
embedder_model = data.get("embedder_model", None)
Expand Down Expand Up @@ -108,23 +97,18 @@ def extract_model(
opt["speakers_id"] = speakers_id
opt["vocoder"] = vocoder

torch.save(opt, os.path.join(model_dir_path, pth_file))

# Create a backwards-compatible checkpoint
model = torch.load(model_dir, map_location=torch.device("cpu"))
torch.save(
replace_keys_in_dict(
replace_keys_in_dict(
model, ".parametrizations.weight.original1", ".weight_v"
opt, ".parametrizations.weight.original1", ".weight_v"
),
".parametrizations.weight.original0",
".weight_g",
),
pth_file_old_version_path,
model_path,
)
os.remove(model_dir)
os.rename(pth_file_old_version_path, model_dir)
print(f"Saved model '{model_dir}' (epoch {epoch} and step {step})")

print(f"Saved model '{model_path}' (epoch {epoch} and step {step})")

except Exception as error:
print(f"An error occurred extracting the model: {error}")
2 changes: 1 addition & 1 deletion rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ def train_and_evaluate(
sr=sample_rate,
pitch_guidance=True,
name=model_name,
model_dir=m,
model_path=m,
epoch=epoch,
step=global_step,
version=version,
Expand Down
15 changes: 7 additions & 8 deletions rvc/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,19 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
}
torch.save(checkpoint_data, checkpoint_path)

# Create a backwards-compatible checkpoint
old_version_path = checkpoint_path.replace(".pth", "_old_version.pth")
checkpoint_data = replace_keys_in_dict(
torch.save(
replace_keys_in_dict(
checkpoint_data, ".parametrizations.weight.original1", ".weight_v"
replace_keys_in_dict(
checkpoint_data, ".parametrizations.weight.original1", ".weight_v"
),
".parametrizations.weight.original0",
".weight_g",
),
".parametrizations.weight.original0",
".weight_g",
checkpoint_path
)
torch.save(checkpoint_data, old_version_path)

os.replace(old_version_path, checkpoint_path)
print(f"Saved model '{checkpoint_path}' (epoch {iteration})")


Expand Down

0 comments on commit e9f1deb

Please sign in to comment.