[vits] Fix bug in remove_weight_norm
ShengqiangLi authored and committed, Jan 19, 2024
Parent: 54b7b3a · Commit: 4122b84
Showing 1 changed file with 4 additions and 5 deletions: wetts/vits/model/modules.py
@@ -1,7 +1,6 @@
 import torch
 from torch import nn
-from torch.nn.utils import weight_norm
-from torch.nn.utils.parametrize import remove_parametrizations
+from torch.nn.utils import weight_norm, remove_weight_norm
 
 from utils import commons
 
@@ -89,11 +88,11 @@ def forward(self, x, x_mask, g=None, **kwargs):
 
     def remove_weight_norm(self):
         if self.gin_channels != 0:
-            remove_parametrizations(self.cond_layer, "weight")
+            remove_weight_norm(self.cond_layer, "weight")
         for l in self.in_layers:
-            remove_parametrizations(l, "weight")
+            remove_weight_norm(l, "weight")
         for l in self.res_skip_layers:
-            remove_parametrizations(l, "weight")
+            remove_weight_norm(l, "weight")
 
 
 class Flip(nn.Module):
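For context (an editorial note, not part of the commit): the classic torch.nn.utils.weight_norm used by these layers registers a forward pre-hook and splits the weight into weight_g and weight_v; it is not a torch.nn.utils.parametrize parametrization, so remove_parametrizations raises a ValueError on such modules. The matching removal function is torch.nn.utils.remove_weight_norm, which is what the commit switches to. A minimal sketch of the pattern, assuming the classic (non-parametrized) weight-norm API:

    # Minimal sketch (not from the repo; assumes PyTorch's classic weight_norm API).
    import torch
    from torch import nn
    from torch.nn.utils import weight_norm, remove_weight_norm

    conv = weight_norm(nn.Conv1d(4, 4, kernel_size=3), name="weight")
    print(hasattr(conv, "weight_g"))    # True: hook stores magnitude (g) and direction (v)

    # remove_parametrizations(conv, "weight") would raise ValueError here,
    # because the classic hook is not a parametrize-style parametrization.
    remove_weight_norm(conv, "weight")  # recomputes and restores a plain .weight
    print(hasattr(conv, "weight_g"))    # False

Note also that inside the WN.remove_weight_norm method, the bare name remove_weight_norm resolves to the imported module-level function (methods are class attributes, not globals), so the recursive-looking call in the fixed code is safe.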
