diff --git a/rvc/train/train.py b/rvc/train/train.py index 92a16a94..b51288bc 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -420,8 +420,8 @@ def run( net_g = net_g.cuda(device_id) net_d = net_d.cuda(device_id) else: - net_g.to(device) - net_d.to(device) + net_g = net_g.to(device) + net_d = net_d.to(device) if optimizer == "AdamW": optimizer = torch.optim.AdamW