From 0f98c6164b3cbe64888d002ea5bd5d868e1a8fba Mon Sep 17 00:00:00 2001 From: Blaise Date: Fri, 27 Dec 2024 16:27:13 +0100 Subject: [PATCH] proper way to import checkpoint --- rvc/lib/algorithm/discriminators.py | 12 ++++++------ rvc/lib/algorithm/generators/hifigan_mrf.py | 9 +++++---- rvc/lib/algorithm/generators/hifigan_nsf.py | 13 ++++++------- rvc/lib/algorithm/generators/refinegan.py | 8 ++++---- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/rvc/lib/algorithm/discriminators.py b/rvc/lib/algorithm/discriminators.py index 5c35d5c7..107ddcb1 100644 --- a/rvc/lib/algorithm/discriminators.py +++ b/rvc/lib/algorithm/discriminators.py @@ -1,6 +1,6 @@ import torch +from torch.utils.checkpoint import checkpoint from torch.nn.utils.parametrizations import spectral_norm, weight_norm -import torch.utils.checkpoint as checkpoint from rvc.lib.algorithm.commons import get_padding from rvc.lib.algorithm.residuals import LRELU_SLOPE @@ -53,7 +53,7 @@ def forward_discriminator(d, y, y_hat): y_d_g, fmap_g = d(y_hat) return y_d_r, fmap_r, y_d_g, fmap_g - y_d_r, fmap_r, y_d_g, fmap_g = checkpoint.checkpoint( + y_d_r, fmap_r, y_d_g, fmap_g = checkpoint( forward_discriminator, d, y, y_hat, use_reentrant=False ) else: @@ -97,8 +97,8 @@ def forward(self, x): fmap = [] for conv in self.convs: if self.training and self.checkpointing: - x = checkpoint.checkpoint(conv, x, use_reentrant=False) - x = checkpoint.checkpoint(self.lrelu, x, use_reentrant=False) + x = checkpoint(conv, x, use_reentrant=False) + x = checkpoint(self.lrelu, x, use_reentrant=False) else: x = self.lrelu(conv(x)) fmap.append(x) @@ -168,8 +168,8 @@ def forward(self, x): for conv in self.convs: if self.training and self.checkpointing: - x = checkpoint.checkpoint(conv, x, use_reentrant=False) - x = checkpoint.checkpoint(self.lrelu, x, use_reentrant=False) + x = checkpoint(conv, x, use_reentrant=False) + x = checkpoint(self.lrelu, x, use_reentrant=False) else: x = self.lrelu(conv(x)) fmap.append(x) diff --git a/rvc/lib/algorithm/generators/hifigan_mrf.py b/rvc/lib/algorithm/generators/hifigan_mrf.py index e3834ab8..2002673f 100644 --- a/rvc/lib/algorithm/generators/hifigan_mrf.py +++ b/rvc/lib/algorithm/generators/hifigan_mrf.py @@ -1,10 +1,11 @@ import math +from typing import Optional + import numpy as np import torch from torch.nn.utils import remove_weight_norm from torch.nn.utils.parametrizations import weight_norm -import torch.utils.checkpoint as checkpoint -from typing import Optional +from torch.utils.checkpoint import checkpoint LRELU_SLOPE = 0.1 @@ -351,7 +352,7 @@ def forward( x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) if self.training and self.checkpointing: - x = checkpoint.checkpoint(ups, x, use_reentrant=False) + x = checkpoint(ups, x, use_reentrant=False) else: x = ups(x) @@ -361,7 +362,7 @@ def mrf_sum(x, layers): return sum(layer(x) for layer in layers) / self.num_kernels if self.training and self.checkpointing: - x = checkpoint.checkpoint(mrf_sum, x, mrf, use_reentrant=False) + x = checkpoint(mrf_sum, x, mrf, use_reentrant=False) else: x = mrf_sum(x, mrf) diff --git a/rvc/lib/algorithm/generators/hifigan_nsf.py b/rvc/lib/algorithm/generators/hifigan_nsf.py index 1afdb91b..a6645cef 100644 --- a/rvc/lib/algorithm/generators/hifigan_nsf.py +++ b/rvc/lib/algorithm/generators/hifigan_nsf.py @@ -1,13 +1,14 @@ import math +from typing import Optional + import torch from torch.nn.utils import remove_weight_norm from torch.nn.utils.parametrizations import weight_norm -import torch.utils.checkpoint as checkpoint -from typing import Optional +from torch.utils.checkpoint import checkpoint +from rvc.lib.algorithm.commons import init_weights from rvc.lib.algorithm.generators.hifigan import SineGenerator from rvc.lib.algorithm.residuals import LRELU_SLOPE, ResBlock -from rvc.lib.algorithm.commons import init_weights class SourceModuleHnNSF(torch.nn.Module): @@ -185,7 +186,7 @@ def forward( # Apply upsampling layer if self.training and self.checkpointing: - x = checkpoint.checkpoint(ups, x, use_reentrant=False) + x = checkpoint(ups, x, use_reentrant=False) else: x = ups(x) @@ -200,9 +201,7 @@ def resblock_forward(x, blocks): # Checkpoint or regular computation for ResBlocks if self.training and self.checkpointing: - x = checkpoint.checkpoint( - resblock_forward, x, blocks, use_reentrant=False - ) + x = checkpoint(resblock_forward, x, blocks, use_reentrant=False) else: x = resblock_forward(x, blocks) diff --git a/rvc/lib/algorithm/generators/refinegan.py b/rvc/lib/algorithm/generators/refinegan.py index ca0b8254..e6296ff1 100644 --- a/rvc/lib/algorithm/generators/refinegan.py +++ b/rvc/lib/algorithm/generators/refinegan.py @@ -2,7 +2,7 @@ import torch from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations -import torch.utils.checkpoint as checkpoint +from torch.utils.checkpoint import checkpoint from rvc.lib.algorithm.commons import get_padding @@ -444,7 +444,7 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None): x = torch.nn.functional.leaky_relu(x, self.leaky_relu_slope, inplace=True) downs.append(x) if self.training and self.checkpointing: - x = checkpoint.checkpoint(block, x, use_reentrant=False) + x = checkpoint(block, x, use_reentrant=False) else: x = block(x) @@ -464,9 +464,9 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None): x = torch.nn.functional.leaky_relu(x, self.leaky_relu_slope, inplace=True) if self.training and self.checkpointing: - x = checkpoint.checkpoint(ups, x, use_reentrant=False) + x = checkpoint(ups, x, use_reentrant=False) x = torch.cat([x, down], dim=1) - x = checkpoint.checkpoint(res, x, use_reentrant=False) + x = checkpoint(res, x, use_reentrant=False) else: x = ups(x) x = torch.cat([x, down], dim=1)