improve encoders arch
blaisewf authored Dec 29, 2024
1 parent 7ae48d9 commit 1f47fbb
Showing 1 changed file with 35 additions and 51 deletions.
86 changes: 35 additions & 51 deletions rvc/lib/algorithm/encoders.py
@@ -33,40 +33,41 @@ def __init__(
        window_size: int = 10,
    ):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = torch.nn.Dropout(p_dropout)
-       self.attn_layers = torch.nn.ModuleList()
-       self.norm_layers_1 = torch.nn.ModuleList()
-       self.ffn_layers = torch.nn.ModuleList()
-       self.norm_layers_2 = torch.nn.ModuleList()
-       for i in range(self.n_layers):
-           self.attn_layers.append(
+
+       self.attn_layers = torch.nn.ModuleList(
+           [
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
-           )
-           self.norm_layers_1.append(LayerNorm(hidden_channels))
-           self.ffn_layers.append(
+               for _ in range(n_layers)
+           ]
+       )
+       self.norm_layers_1 = torch.nn.ModuleList(
+           [LayerNorm(hidden_channels) for _ in range(n_layers)]
+       )
+       self.ffn_layers = torch.nn.ModuleList(
+           [
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
-           )
-           self.norm_layers_2.append(LayerNorm(hidden_channels))
+               for _ in range(n_layers)
+           ]
+       )
+       self.norm_layers_2 = torch.nn.ModuleList(
+           [LayerNorm(hidden_channels) for _ in range(n_layers)]
+       )

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
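
Aside (illustrative, not part of the commit): the change in this hunk replaces the append-in-a-loop construction with list comprehensions passed directly to torch.nn.ModuleList. A minimal self-contained sketch of the two patterns, using a stand-in layer instead of the repository's MultiHeadAttention/FFN classes:

import torch

# Stand-in for MultiHeadAttention / FFN; any nn.Module works for this comparison.
def make_layer(hidden_channels: int) -> torch.nn.Module:
    return torch.nn.Linear(hidden_channels, hidden_channels)

n_layers, hidden_channels = 4, 8

# Old pattern: start empty and append inside a loop.
layers_loop = torch.nn.ModuleList()
for _ in range(n_layers):
    layers_loop.append(make_layer(hidden_channels))

# New pattern: build the list with a comprehension and hand it to ModuleList.
layers_comp = torch.nn.ModuleList(
    [make_layer(hidden_channels) for _ in range(n_layers)]
)

# Both register the same number of submodules; only the construction style differs.
assert len(layers_loop) == len(layers_comp) == n_layers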
@@ -112,39 +112,30 @@ def __init__(
        embedding_dim: int,
        f0: bool = True,
    ):
-       super(TextEncoder, self).__init__()
-       self.out_channels = out_channels
+       super().__init__()
        self.hidden_channels = hidden_channels
-       self.filter_channels = filter_channels
-       self.n_heads = n_heads
-       self.n_layers = n_layers
-       self.kernel_size = kernel_size
-       self.p_dropout = float(p_dropout)
+       self.out_channels = out_channels
        self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
        self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
-       if f0:
-           self.emb_pitch = torch.nn.Embedding(256, hidden_channels)
+       self.emb_pitch = torch.nn.Embedding(256, hidden_channels) if f0 else None

        self.encoder = Encoder(
-           hidden_channels,
-           filter_channels,
-           n_heads,
-           n_layers,
-           kernel_size,
-           float(p_dropout),
+           hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(
        self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor
    ):
-       if pitch is None:
-           x = self.emb_phone(phone)
-       else:
-           x = self.emb_phone(phone) + self.emb_pitch(pitch)
-       x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
+       x = self.emb_phone(phone)
+       if pitch is not None and self.emb_pitch:
+           x += self.emb_pitch(pitch)
+
+       x *= math.sqrt(self.hidden_channels)
        x = self.lrelu(x)
-       x = torch.transpose(x, 1, -1)  # [b, h, t]
-       x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
+       x = x.transpose(1, -1)  # [B, H, T]
+
+       x_mask = sequence_mask(lengths, x.size(2)).unsqueeze(1).to(x.dtype)
        x = self.encoder(x, x_mask)
        stats = self.proj(x) * x_mask
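
Aside (illustrative, not part of the commit): after this hunk, TextEncoder keeps emb_pitch as None when f0 is disabled and only adds the pitch embedding when both a pitch tensor and the embedding layer exist. A rough, self-contained sketch of that path with toy sizes (the free-standing names below are stand-ins, not the class itself):

import math
from typing import Optional

import torch

hidden_channels, embedding_dim, f0 = 8, 6, True
emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
emb_pitch = torch.nn.Embedding(256, hidden_channels) if f0 else None

def embed(phone: torch.Tensor, pitch: Optional[torch.Tensor]) -> torch.Tensor:
    # Phone features are always embedded; pitch is added only when available.
    x = emb_phone(phone)
    if pitch is not None and emb_pitch is not None:
        x = x + emb_pitch(pitch)
    return x * math.sqrt(hidden_channels)  # [b, t, h]

phone = torch.randn(1, 5, embedding_dim)
pitch = torch.randint(0, 256, (1, 5))
print(embed(phone, pitch).shape)  # torch.Size([1, 5, 8])
print(embed(phone, None).shape)   # pitch=None still works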

@@ -176,15 +168,8 @@ def __init__(
        n_layers: int,
        gin_channels: int = 0,
    ):
-       super(PosteriorEncoder, self).__init__()
-       self.in_channels = in_channels
+       super().__init__()
        self.out_channels = out_channels
-       self.hidden_channels = hidden_channels
-       self.kernel_size = kernel_size
-       self.dilation_rate = dilation_rate
-       self.n_layers = n_layers
-       self.gin_channels = gin_channels
-
        self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WaveNet(
            hidden_channels,
@@ -198,17 +183,16 @@ def forward(
    def forward(
        self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
    ):
-       x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+       x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)

        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)

        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)

-       logs_exp = torch.exp(logs)
-       z = m + torch.randn_like(m) * logs_exp
-       z = z * x_mask
+       z = m + torch.randn_like(m) * torch.exp(logs)
+       z *= x_mask

        return z, m, logs, x_mask
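
Aside (illustrative only): the rewritten tail of PosteriorEncoder.forward splits the projection into a mean m and log-std logs, samples z with the reparameterization trick, and masks it. A toy-shaped sketch of just that step:

import torch

out_channels, t = 4, 10
stats = torch.randn(1, out_channels * 2, t)  # stand-in for self.proj(x) * x_mask
x_mask = torch.ones(1, 1, t)

m, logs = torch.split(stats, out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
print(z.shape)  # torch.Size([1, 4, 10])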
