diff --git a/rvc/train/train.py b/rvc/train/train.py index 7a748b2d..a407a1f1 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -654,7 +654,7 @@ def train_and_evaluate( # loss_disc, _, _ = discriminator_loss_scaled(y_d_hat_r, y_d_hat_g) loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g) # Discriminator backward and update - epoch_disc_sum += loss_disc + epoch_disc_sum += loss_disc.item() optim_d.zero_grad() scaler.scale(loss_disc).backward() scaler.unscale_(optim_d) @@ -689,7 +689,7 @@ def train_and_evaluate( "value": loss_gen_all, "epoch": epoch, } - epoch_gen_sum += loss_gen_all + epoch_gen_sum += loss_gen_all.item() optim_g.zero_grad() scaler.scale(loss_gen_all).backward() scaler.unscale_(optim_g)