
Commit

unify resblocks
blaisewf authored Dec 7, 2024
1 parent a7a8c40 commit ee3a736
Showing 1 changed file with 58 additions and 28 deletions.
86 changes: 58 additions & 28 deletions rvc/lib/algorithm/residuals.py
@@ -1,5 +1,6 @@
-from typing import Optional, Tuple
 import torch
+from itertools import chain
+from typing import Optional, Tuple
 from torch.nn.utils import remove_weight_norm
 from torch.nn.utils.parametrizations import weight_norm

@@ -26,41 +27,70 @@ def apply_mask(tensor, mask):
     return tensor * mask if mask is not None else tensor
 
 
-class ResBlockBase(torch.nn.Module):
-    def __init__(self, channels: int, kernel_size: int, dilations: Tuple[int]):
-        super(ResBlockBase, self).__init__()
-        self.convs1 = torch.nn.ModuleList(
+class ResBlock(torch.nn.Module):
+    """
+    A residual block module that applies a series of 1D convolutional layers with residual connections.
+    """
+
+    def __init__(
+        self, channels: int, kernel_size: int = 3, dilations: Tuple[int] = (1, 3, 5)
+    ):
+        """
+        Initializes the ResBlock.
+        Args:
+            channels (int): Number of input and output channels for the convolution layers.
+            kernel_size (int): Size of the convolution kernel. Defaults to 3.
+            dilations (Tuple[int]): Tuple of dilation rates for the convolution layers in the first set.
+        """
+        super().__init__()
+        # Create convolutional layers with specified dilations and initialize weights
+        self.convs1 = self._create_convs(channels, kernel_size, dilations)
+        self.convs2 = self._create_convs(channels, kernel_size, [1] * len(dilations))
+
+    @staticmethod
+    def _create_convs(
+        channels: int, kernel_size: int, dilations: Tuple[int]
+    ):
+        """
+        Creates a list of 1D convolutional layers with specified dilations.
+        Args:
+            channels (int): Number of input and output channels for the convolution layers.
+            kernel_size (int): Size of the convolution kernel.
+            dilations (Tuple[int]): Tuple of dilation rates for each convolution layer.
+        """
+        layers = torch.nn.ModuleList(
             [create_conv1d_layer(channels, kernel_size, d) for d in dilations]
         )
-        self.convs1.apply(init_weights)
+        layers.apply(init_weights)
+        return layers
 
-        self.convs2 = torch.nn.ModuleList(
-            [create_conv1d_layer(channels, kernel_size, 1) for _ in dilations]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
-            xt = apply_mask(xt, x_mask)
-            xt = torch.nn.functional.leaky_relu(c1(xt), LRELU_SLOPE)
-            xt = apply_mask(xt, x_mask)
-            xt = c2(xt)
-            x = xt + x
+    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
+        """Forward pass.
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, channels, sequence_length).
+            x_mask (torch.Tensor, optional): Optional mask to apply to the input and output tensors.
+        """
+        for conv1, conv2 in zip(self.convs1, self.convs2):
+            x_residual = x
+            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
+            x = apply_mask(x, x_mask)
+            x = torch.nn.functional.leaky_relu(conv1(x), LRELU_SLOPE)
+            x = apply_mask(x, x_mask)
+            x = conv2(x)
+            x = x + x_residual
         return apply_mask(x, x_mask)
 
     def remove_weight_norm(self):
-        for conv in self.convs1 + self.convs2:
+        """
+        Removes weight normalization from all convolutional layers in the block.
+        """
+        for conv in chain(self.convs1, self.convs2):
             remove_weight_norm(conv)
 
 
-class ResBlock(ResBlockBase):
-    def __init__(
-        self, channels: int, kernel_size: int = 3, dilation: Tuple[int] = (1, 3, 5)
-    ):
-        super(ResBlock, self).__init__(channels, kernel_size, dilation)
-
-
 class Flip(torch.nn.Module):
     """Flip module for flow-based models.
@@ -115,7 +145,7 @@ def __init__(
         self.gin_channels = gin_channels
 
         self.flows = torch.nn.ModuleList()
-        for i in range(n_flows):
+        for _ in range(n_flows):
            self.flows.append(
                ResidualCouplingLayer(
                    channels,
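For readers trying the unified module in isolation, here is a minimal usage sketch. The sizes are hypothetical, chosen only for illustration; create_conv1d_layer, init_weights, and LRELU_SLOPE are the helpers defined in the same residuals.py shown above.

import torch
from rvc.lib.algorithm.residuals import ResBlock

# Hypothetical channel count and sequence length, for illustration only.
block = ResBlock(channels=192, kernel_size=3, dilations=(1, 3, 5))

x = torch.randn(4, 192, 100)    # (batch_size, channels, sequence_length)
x_mask = torch.ones(4, 1, 100)  # broadcasts over the channel dimension
out = block(x, x_mask)          # same shape as x

# Strip weight norm before export/inference; chain() walks convs1 and
# convs2 in one pass without building a temporary container.
block.remove_weight_norm()

The switch from self.convs1 + self.convs2 to chain(self.convs1, self.convs2) is behavior-preserving for this loop: it iterates both ModuleLists directly instead of first concatenating them into a new one.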
