Skip to content

Commit

Permalink
Merge pull request #936 from deiteris/remove-double-save
Browse files Browse the repository at this point in the history
Remove double torch.save
  • Loading branch information
blaisewf authored Jan 2, 2025
2 parents 5ec7c14 + caf462b commit e6df9a7
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 50 deletions.
34 changes: 9 additions & 25 deletions rvc/train/process/extract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def extract_model(
sr,
pitch_guidance,
name,
model_dir,
model_path,
epoch,
step,
version,
Expand All @@ -38,22 +38,11 @@ def extract_model(
vocoder,
):
try:
model_dir = os.path.dirname(model_path)
os.makedirs(model_dir, exist_ok=True)

model_dir_path = os.path.dirname(model_dir)
os.makedirs(model_dir_path, exist_ok=True)

if "best_epoch" in model_dir:
pth_file = f"{name}_{epoch}e_{step}s_best_epoch.pth"
else:
pth_file = f"{name}_{epoch}e_{step}s.pth"

pth_file_old_version_path = os.path.join(
model_dir_path, f"{pth_file}_old_version.pth"
)

model_dir_path = os.path.dirname(model_dir)
if os.path.exists(os.path.join(model_dir_path, "model_info.json")):
with open(os.path.join(model_dir_path, "model_info.json"), "r") as f:
if os.path.exists(os.path.join(model_dir, "model_info.json")):
with open(os.path.join(model_dir, "model_info.json"), "r") as f:
data = json.load(f)
dataset_length = data.get("total_dataset_duration", None)
embedder_model = data.get("embedder_model", None)
Expand Down Expand Up @@ -108,23 +97,18 @@ def extract_model(
opt["speakers_id"] = speakers_id
opt["vocoder"] = vocoder

torch.save(opt, os.path.join(model_dir_path, pth_file))

# Create a backwards-compatible checkpoint
model = torch.load(model_dir, map_location=torch.device("cpu"))
torch.save(
replace_keys_in_dict(
replace_keys_in_dict(
model, ".parametrizations.weight.original1", ".weight_v"
opt, ".parametrizations.weight.original1", ".weight_v"
),
".parametrizations.weight.original0",
".weight_g",
),
pth_file_old_version_path,
model_path,
)
os.remove(model_dir)
os.rename(pth_file_old_version_path, model_dir)
print(f"Saved model '{model_dir}' (epoch {epoch} and step {step})")

print(f"Saved model '{model_path}' (epoch {epoch} and step {step})")

except Exception as error:
print(f"An error occurred extracting the model: {error}")
36 changes: 19 additions & 17 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,30 +955,32 @@ def train_and_evaluate(
)
)

# Clean-up old best epochs
for m in model_del:
os.remove(m)

if model_add:
ckpt = (
net_g.module.state_dict()
if hasattr(net_g, "module")
else net_g.state_dict()
)
for m in model_add:
if not os.path.exists(m):
extract_model(
ckpt=ckpt,
sr=sample_rate,
pitch_guidance=True,
name=model_name,
model_dir=m,
epoch=epoch,
step=global_step,
version=version,
hps=hps,
overtrain_info=overtrain_info,
vocoder=vocoder,
)
# Clean-up old best epochs
for m in model_del:
os.remove(m)
if os.path.exists(m):
print(f'{m} already exists. Overwriting.')
extract_model(
ckpt=ckpt,
sr=sample_rate,
pitch_guidance=True,
name=model_name,
model_path=m,
epoch=epoch,
step=global_step,
version=version,
hps=hps,
overtrain_info=overtrain_info,
vocoder=vocoder,
)

# Check completion
if epoch >= custom_total_epoch:
Expand Down
15 changes: 7 additions & 8 deletions rvc/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,19 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
}
torch.save(checkpoint_data, checkpoint_path)

# Create a backwards-compatible checkpoint
old_version_path = checkpoint_path.replace(".pth", "_old_version.pth")
checkpoint_data = replace_keys_in_dict(
torch.save(
replace_keys_in_dict(
checkpoint_data, ".parametrizations.weight.original1", ".weight_v"
replace_keys_in_dict(
checkpoint_data, ".parametrizations.weight.original1", ".weight_v"
),
".parametrizations.weight.original0",
".weight_g",
),
".parametrizations.weight.original0",
".weight_g",
checkpoint_path
)
torch.save(checkpoint_data, old_version_path)

os.replace(old_version_path, checkpoint_path)
print(f"Saved model '{checkpoint_path}' (epoch {iteration})")


Expand Down

0 comments on commit e6df9a7

Please sign in to comment.