proper way to import checkpoint
blaisewf committed Dec 27, 2024
1 parent 11db12a commit 0f98c61
Showing 4 changed files with 21 additions and 21 deletions.
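
The fix is the same in all four files. `import torch.utils.checkpoint as checkpoint` binds the *module*, so every call site has to go through the attribute as `checkpoint.checkpoint(...)`; `from torch.utils.checkpoint import checkpoint` binds the function itself, so call sites shorten to `checkpoint(...)`. A minimal sketch of the two import styles side by side (the module alias is renamed here to avoid shadowing; in the old code the alias was literally `checkpoint`):

import torch
import torch.utils.checkpoint as checkpoint_mod  # old style: binds the module
from torch.utils.checkpoint import checkpoint    # new style: binds the function

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# both spellings reach the same function; only the name binding differs
y_old = checkpoint_mod.checkpoint(layer, x, use_reentrant=False)
y_new = checkpoint(layer, x, use_reentrant=False)
assert torch.allclose(y_old, y_new)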
12 changes: 6 additions & 6 deletions rvc/lib/algorithm/discriminators.py
@@ -1,6 +1,6 @@
 import torch
+from torch.utils.checkpoint import checkpoint
 from torch.nn.utils.parametrizations import spectral_norm, weight_norm
-import torch.utils.checkpoint as checkpoint
 
 from rvc.lib.algorithm.commons import get_padding
 from rvc.lib.algorithm.residuals import LRELU_SLOPE
@@ -53,7 +53,7 @@ def forward_discriminator(d, y, y_hat):
                     y_d_g, fmap_g = d(y_hat)
                     return y_d_r, fmap_r, y_d_g, fmap_g
 
-                y_d_r, fmap_r, y_d_g, fmap_g = checkpoint.checkpoint(
+                y_d_r, fmap_r, y_d_g, fmap_g = checkpoint(
                     forward_discriminator, d, y, y_hat, use_reentrant=False
                 )
             else:
@@ -97,8 +97,8 @@ def forward(self, x):
         fmap = []
         for conv in self.convs:
             if self.training and self.checkpointing:
-                x = checkpoint.checkpoint(conv, x, use_reentrant=False)
-                x = checkpoint.checkpoint(self.lrelu, x, use_reentrant=False)
+                x = checkpoint(conv, x, use_reentrant=False)
+                x = checkpoint(self.lrelu, x, use_reentrant=False)
             else:
                 x = self.lrelu(conv(x))
             fmap.append(x)
@@ -168,8 +168,8 @@ def forward(self, x):
 
         for conv in self.convs:
             if self.training and self.checkpointing:
-                x = checkpoint.checkpoint(conv, x, use_reentrant=False)
-                x = checkpoint.checkpoint(self.lrelu, x, use_reentrant=False)
+                x = checkpoint(conv, x, use_reentrant=False)
+                x = checkpoint(self.lrelu, x, use_reentrant=False)
             else:
                 x = self.lrelu(conv(x))
             fmap.append(x)
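For context, the pattern being edited here wraps each conv and its activation in a checkpoint only when `self.training and self.checkpointing` holds, recomputing those activations in the backward pass instead of storing them. A self-contained sketch of that gating pattern (`TinyDiscriminator` is an illustrative stand-in, not the repo's class):

import torch
from torch.utils.checkpoint import checkpoint


class TinyDiscriminator(torch.nn.Module):
    # illustrative stand-in for the discriminators in this file
    def __init__(self, checkpointing: bool = False):
        super().__init__()
        self.checkpointing = checkpointing
        self.convs = torch.nn.ModuleList(
            [torch.nn.Conv1d(1, 4, 5, padding=2), torch.nn.Conv1d(4, 4, 5, padding=2)]
        )
        self.lrelu = torch.nn.LeakyReLU(0.1)

    def forward(self, x):
        fmap = []
        for conv in self.convs:
            if self.training and self.checkpointing:
                # recompute conv + activation during backward instead of caching
                x = checkpoint(conv, x, use_reentrant=False)
                x = checkpoint(self.lrelu, x, use_reentrant=False)
            else:
                x = self.lrelu(conv(x))
            fmap.append(x)
        return x, fmap


model = TinyDiscriminator(checkpointing=True).train()
out, fmap = model(torch.randn(2, 1, 64, requires_grad=True))
out.sum().backward()  # gradients flow through the recomputed segments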
9 changes: 5 additions & 4 deletions rvc/lib/algorithm/generators/hifigan_mrf.py
@@ -1,10 +1,11 @@
 import math
+from typing import Optional
+
 import numpy as np
 import torch
 from torch.nn.utils import remove_weight_norm
 from torch.nn.utils.parametrizations import weight_norm
-import torch.utils.checkpoint as checkpoint
-from typing import Optional
+from torch.utils.checkpoint import checkpoint
 
 LRELU_SLOPE = 0.1
 
@@ -351,7 +352,7 @@ def forward(
             x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
 
             if self.training and self.checkpointing:
-                x = checkpoint.checkpoint(ups, x, use_reentrant=False)
+                x = checkpoint(ups, x, use_reentrant=False)
             else:
                 x = ups(x)
 
@@ -361,7 +362,7 @@ def mrf_sum(x, layers):
                 return sum(layer(x) for layer in layers) / self.num_kernels
 
             if self.training and self.checkpointing:
-                x = checkpoint.checkpoint(mrf_sum, x, mrf, use_reentrant=False)
+                x = checkpoint(mrf_sum, x, mrf, use_reentrant=False)
             else:
                 x = mrf_sum(x, mrf)
 
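Worth noting in this hunk: `checkpoint` takes a callable plus its arguments, so the `mrf_sum` closure (which averages the parallel MRF residual branches) can be checkpointed as a unit. A sketch of that call shape, with made-up channel counts and kernel sizes:

import torch
from torch.utils.checkpoint import checkpoint

num_kernels = 3
mrf = torch.nn.ModuleList(
    [torch.nn.Conv1d(4, 4, k, padding=k // 2) for k in (3, 5, 7)]
)


def mrf_sum(x, layers):
    # average the parallel residual branches, as the generator does
    return sum(layer(x) for layer in layers) / num_kernels


x = torch.randn(2, 4, 32, requires_grad=True)

# the ModuleList passes through as a plain argument; only tensor
# inputs/outputs are dropped and recomputed during backward
y = checkpoint(mrf_sum, x, mrf, use_reentrant=False)
y.sum().backward()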
13 changes: 6 additions & 7 deletions rvc/lib/algorithm/generators/hifigan_nsf.py
@@ -1,13 +1,14 @@
 import math
+from typing import Optional
+
 import torch
 from torch.nn.utils import remove_weight_norm
 from torch.nn.utils.parametrizations import weight_norm
-import torch.utils.checkpoint as checkpoint
-from typing import Optional
+from torch.utils.checkpoint import checkpoint
 
+from rvc.lib.algorithm.commons import init_weights
 from rvc.lib.algorithm.generators.hifigan import SineGenerator
 from rvc.lib.algorithm.residuals import LRELU_SLOPE, ResBlock
-from rvc.lib.algorithm.commons import init_weights
 
 
 class SourceModuleHnNSF(torch.nn.Module):
@@ -185,7 +186,7 @@ def forward(
 
             # Apply upsampling layer
             if self.training and self.checkpointing:
-                x = checkpoint.checkpoint(ups, x, use_reentrant=False)
+                x = checkpoint(ups, x, use_reentrant=False)
             else:
                 x = ups(x)
 
@@ -200,9 +201,7 @@ def resblock_forward(x, blocks):
 
             # Checkpoint or regular computation for ResBlocks
             if self.training and self.checkpointing:
-                x = checkpoint.checkpoint(
-                    resblock_forward, x, blocks, use_reentrant=False
-                )
+                x = checkpoint(resblock_forward, x, blocks, use_reentrant=False)
             else:
                 x = resblock_forward(x, blocks)
 
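The collapse from a three-line call to one line is cosmetic, but the pattern invites a sanity check: checkpointed and plain execution must produce identical gradients, since checkpointing only trades activation memory for a second forward pass. A quick sketch demonstrating that equivalence (this `resblock_forward` is illustrative, not the repo's):

import torch
from torch.utils.checkpoint import checkpoint


def resblock_forward(x, blocks):
    for block in blocks:
        x = block(x)
    return x


blocks = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)])
x = torch.randn(8, 16, requires_grad=True)

# plain execution keeps every intermediate activation alive for backward
g_plain = torch.autograd.grad(resblock_forward(x, blocks).sum(), x)[0]

# checkpointed execution drops them and recomputes during backward
y = checkpoint(resblock_forward, x, blocks, use_reentrant=False)
g_ckpt = torch.autograd.grad(y.sum(), x)[0]

assert torch.allclose(g_plain, g_ckpt)  # same gradients, lower peak memory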
8 changes: 4 additions & 4 deletions rvc/lib/algorithm/generators/refinegan.py
@@ -2,7 +2,7 @@
 import torch
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
-import torch.utils.checkpoint as checkpoint
+from torch.utils.checkpoint import checkpoint
 
 from rvc.lib.algorithm.commons import get_padding
 
@@ -444,7 +444,7 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
             x = torch.nn.functional.leaky_relu(x, self.leaky_relu_slope, inplace=True)
             downs.append(x)
             if self.training and self.checkpointing:
-                x = checkpoint.checkpoint(block, x, use_reentrant=False)
+                x = checkpoint(block, x, use_reentrant=False)
             else:
                 x = block(x)
 
@@ -464,9 +464,9 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
             x = torch.nn.functional.leaky_relu(x, self.leaky_relu_slope, inplace=True)
 
             if self.training and self.checkpointing:
-                x = checkpoint.checkpoint(ups, x, use_reentrant=False)
+                x = checkpoint(ups, x, use_reentrant=False)
                 x = torch.cat([x, down], dim=1)
-                x = checkpoint.checkpoint(res, x, use_reentrant=False)
+                x = checkpoint(res, x, use_reentrant=False)
             else:
                 x = ups(x)
                 x = torch.cat([x, down], dim=1)
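In this decoder hunk only the module calls (`ups`, `res`) are checkpointed; the cheap `torch.cat` with the skip tensor runs normally between them. A minimal sketch of that skip-connection pattern, with invented layer sizes:

import torch
from torch.utils.checkpoint import checkpoint

ups = torch.nn.ConvTranspose1d(8, 4, 4, stride=2, padding=1)  # upsampling block
res = torch.nn.Conv1d(8, 8, 3, padding=1)                     # post-concat block

x = torch.randn(2, 8, 16, requires_grad=True)  # decoder feature
down = torch.randn(2, 4, 32)                   # matching encoder skip tensor

x = checkpoint(ups, x, use_reentrant=False)  # (2, 4, 32)
x = torch.cat([x, down], dim=1)              # cheap concat, not checkpointed
x = checkpoint(res, x, use_reentrant=False)  # (2, 8, 32)
x.sum().backward()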
