Skip to content

Commit

Permalink
attempt to fix multi-gpu use as it was in v3.2.5
Browse files Browse the repository at this point in the history
  • Loading branch information
AznamirWoW committed Dec 28, 2024
1 parent fb5fb93 commit 39f6469
Showing 1 changed file with 38 additions and 22 deletions.
60 changes: 38 additions & 22 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,18 @@ def run(
sr=sample_rate,
vocoder=vocoder,
checkpointing=checkpointing,
).to(device)

)
net_d = MultiPeriodDiscriminator(
version, config.model.use_spectral_norm, checkpointing=checkpointing
).to(device)
)

if torch.cuda.is_available():
net_g = net_g.cuda(rank)
net_d = net_d.cuda(rank)
else:
net_g.to(device)
net_d.to(device)

optim_g = torch.optim.AdamW(
net_g.parameters(),
Expand Down Expand Up @@ -488,13 +495,22 @@ def run(
else:
for info in train_loader:
phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
reference = (
phone.to(device),
phone_lengths.to(device),
pitch.to(device),
pitchf.to(device),
sid.to(device),
)
if device.type == "cuda":
reference = (
phone.cuda(rank, non_blocking=True),
phone_lengths.cuda(rank, non_blocking=True),
pitch.cuda(rank, non_blocking=True),
pitchf.cuda(rank, non_blocking=True),
sid.cuda(rank, non_blocking=True),
)
else:
reference = (
phone.to(device),
phone_lengths.to(device),
pitch.to(device),
pitchf.to(device),
sid.to(device),
)
break

for epoch in range(epoch_str, total_epoch + 1):
Expand Down Expand Up @@ -615,12 +631,12 @@ def train_and_evaluate(
model_output
)
# slice of the original waveform to match a generate slice
wave = commons.slice_segments(
wave,
ids_slice * config.data.hop_length,
config.train.segment_size,
dim=3,
)
wave = commons.slice_segments(
wave,
ids_slice * config.data.hop_length,
config.train.segment_size,
dim=3,
)
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
with autocast(enabled=False):
# if vocoder == "HiFi-GAN":
Expand Down Expand Up @@ -735,12 +751,12 @@ def train_and_evaluate(
config.data.mel_fmax,
)
# used for tensorboard chart - slice/mel_org
y_mel = commons.slice_segments(
mel,
ids_slice,
config.train.segment_size // config.data.hop_length,
dim=3,
)
y_mel = commons.slice_segments(
mel,
ids_slice,
config.train.segment_size // config.data.hop_length,
dim=3,
)
# used for tensorboard chart - slice/mel_gen
with autocast(enabled=False):
y_hat_mel = mel_spectrogram_torch(
Expand Down

0 comments on commit 39f6469

Please sign in to comment.