diff --git a/rvc/train/train.py b/rvc/train/train.py index 9f3acb7b..f3086415 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -492,13 +492,15 @@ def run( # get the first sample as reference for tensorboard evaluation for info in train_loader: phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info - reference = (phone.to(device), - phone_lengths.to(device), - pitch.to(device) if pitch_guidance else None, - pitchf.to(device) if pitch_guidance else None, - sid.to(device)) + reference = ( + phone.to(device), + phone_lengths.to(device), + pitch.to(device) if pitch_guidance else None, + pitchf.to(device) if pitch_guidance else None, + sid.to(device), + ) break - + for epoch in range(epoch_str, total_epoch + 1): if rank == 0: train_and_evaluate( @@ -514,7 +516,7 @@ def run( custom_save_every_weights, custom_total_epoch, device, - reference + reference, ) else: train_and_evaluate( @@ -791,10 +793,10 @@ def train_and_evaluate( ), "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), } - + with torch.no_grad(): o, *_ = net_g.infer(*reference) - audio_dict = {f"gen/audio_{global_step:07d}": o[0, :, : ]} + audio_dict = {f"gen/audio_{global_step:07d}": o[0, :, :]} summarize( writer=writer,