refactor(xtts): use tortoise conditioning encoder
eginhard committed Nov 22, 2024
1 parent add5abe commit fa12a7b
Showing 3 changed files with 4 additions and 46 deletions.
9 changes: 2 additions & 7 deletions TTS/tts/layers/tortoise/autoregressive.py
@@ -176,7 +176,6 @@ def __init__(
         embedding_dim,
         attn_blocks=6,
         num_attn_heads=4,
-        mean=False,
     ):
         super().__init__()
         attn = []
@@ -185,15 +184,11 @@ def __init__(
             attn.append(AttentionBlock(embedding_dim, num_attn_heads))
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
-        self.mean = mean

     def forward(self, x):
         h = self.init(x)
         h = self.attn(h)
-        if self.mean:
-            return h.mean(dim=2)
-        else:
-            return h[:, :, 0]
+        return h


 class LearnedPositionEmbeddings(nn.Module):
@@ -473,7 +468,7 @@ def get_conditioning(self, speech_conditioning_input):
         )
         conds = []
         for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
+            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])[:, :, 0])
         conds = torch.stack(conds, dim=1)
         conds = conds.mean(dim=1)
         return conds
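For reference, a minimal runnable sketch of the encoder's new contract: forward() returns the full sequence and the caller does the pooling, so the old mean=False path is reproduced by selecting the first frame. The module below is a toy stand-in for the tortoise ConditioningEncoder (the AttentionBlock stack is omitted); only the shapes are meant to match.

import torch
import torch.nn as nn

# Toy stand-in: a single 1x1 convolution instead of conv + attention stack.
class TinyConditioningEncoder(nn.Module):
    def __init__(self, spec_dim=80, embedding_dim=1024):
        super().__init__()
        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)

    def forward(self, x):        # x: (b, 80, s)
        return self.init(x)      # (b, 1024, s) -- full sequence, no pooling

enc = TinyConditioningEncoder()
mel = torch.randn(2, 80, 50)
h = enc(mel)                     # (2, 1024, 50)
cond = h[:, :, 0]                # (2, 1024): what get_conditioning now selects itself
cond_mean = h.mean(dim=2)        # (2, 1024): the removed mean=True behaviour, if ever needed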
16 changes: 2 additions & 14 deletions TTS/tts/layers/xtts/gpt.py
@@ -8,12 +8,12 @@
 from transformers import GPT2Config

 from TTS.tts.layers.tortoise.autoregressive import (
+    ConditioningEncoder,
     LearnedPositionEmbeddings,
     _prepare_attention_mask_for_generation,
     build_hf_gpt_transformer,
 )
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
-from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler


@@ -235,19 +235,6 @@ def get_logits(
         else:
             return first_logits

-    def get_conditioning(self, speech_conditioning_input):
-        speech_conditioning_input = (
-            speech_conditioning_input.unsqueeze(1)
-            if len(speech_conditioning_input.shape) == 3
-            else speech_conditioning_input
-        )
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        conds = conds.mean(dim=1)
-        return conds
-
     def get_prompts(self, prompt_codes):
         """
         Create a prompt from the mel codes. This is used to condition the model on the mel codes.
@@ -286,6 +273,7 @@ def get_style_emb(self, cond_input, return_latent=False):
         """
         cond_input: (b, 80, s) or (b, 1, 80, s)
         conds: (b, 1024, s)
+        output: (b, 1024, 32)
         """
         conds = None
         if not return_latent:
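The added docstring line records the shape flow in get_style_emb: the shared encoder produces (b, 1024, s) and a resampling step reduces the time axis to 32 latents. A hedged sketch of that flow follows; the toy cross-attention module stands in for PerceiverResampler and is illustrative only, not copied from the repository.

import torch
import torch.nn as nn

# Illustrative resampler: learned queries cross-attend over the encoder output
# to produce a fixed number of latents. Not the repository's PerceiverResampler.
class ToyResampler(nn.Module):
    def __init__(self, dim=1024, num_latents=32):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) / dim**0.5)

    def forward(self, conds):  # conds: (b, 1024, s)
        q = self.latents.unsqueeze(0).expand(conds.shape[0], -1, -1)  # (b, 32, 1024)
        attn = torch.softmax(q @ conds, dim=-1)                       # (b, 32, s)
        out = attn @ conds.transpose(1, 2)                            # (b, 32, 1024)
        return out.transpose(1, 2)                                    # (b, 1024, 32)

cond_mel = torch.randn(2, 80, 120)            # (b, 80, s), as in the docstring
encoder = nn.Conv1d(80, 1024, kernel_size=1)  # stand-in for ConditioningEncoder
conds = encoder(cond_mel)                     # (b, 1024, s)
style = ToyResampler()(conds)                 # (b, 1024, 32)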
25 changes: 0 additions & 25 deletions TTS/tts/layers/xtts/latent_encoder.py
@@ -93,28 +93,3 @@ def forward(self, x, mask=None, qk_bias=0):
         h = self.proj_out(h)
         xp = self.x_proj(x)
         return (xp + h).reshape(b, xp.shape[1], *spatial)
-
-
-class ConditioningEncoder(nn.Module):
-    def __init__(
-        self,
-        spec_dim,
-        embedding_dim,
-        attn_blocks=6,
-        num_attn_heads=4,
-    ):
-        super().__init__()
-        attn = []
-        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
-        for a in range(attn_blocks):
-            attn.append(AttentionBlock(embedding_dim, num_attn_heads))
-        self.attn = nn.Sequential(*attn)
-        self.dim = embedding_dim
-
-    def forward(self, x):
-        """
-        x: (b, 80, s)
-        """
-        h = self.init(x)
-        h = self.attn(h)
-        return h
