v0.25.1 #201

Merged 16 commits on Dec 9, 2024
Changes from all commits
169 changes: 103 additions & 66 deletions TTS/api.py

Large diffs are not rendered by default.
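The TTS/api.py diff is collapsed above, but the refactored TTS/bin/synthesize.py below shows how the new API is meant to be driven. A minimal sketch of the same calls in user code, with all parameter names taken from the call sites in this PR (the model name itself is an assumption, not from the diff):

from TTS.api import TTS

# Load a pre-trained model by name and synthesize to a file; the model name here is illustrative.
api = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=True).to("cpu")
api.tts_to_file(text="Hello world.", file_path="hello.wav")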

130 changes: 40 additions & 90 deletions TTS/bin/synthesize.py
@@ -9,8 +9,6 @@
from argparse import RawTextHelpFormatter

# pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path

from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

logger = logging.getLogger(__name__)
@@ -253,11 +251,6 @@ def parse_args() -> argparse.Namespace:
action="store_true",
)
# aux args
parser.add_argument(
"--save_spectogram",
action="store_true",
help="Save raw spectogram for further (vocoder) processing in out_path.",
)
parser.add_argument(
"--reference_wav",
type=str,
@@ -317,20 +310,20 @@ def parse_args() -> argparse.Namespace:
return args


def main():
def main() -> None:
"""Entry point for `tts` command line interface."""
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
args = parse_args()

pipe_out = sys.stdout if args.pipe_out else None

with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout):
# Late-import to make things load faster
from TTS.api import TTS
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer

# load model manager
path = Path(__file__).parent / "../.models.json"
manager = ModelManager(path, progress_bar=args.progress_bar)
manager = ModelManager(models_file=TTS.get_models_file_path(), progress_bar=args.progress_bar)

tts_path = None
tts_config_path = None
@@ -344,12 +337,12 @@ def main():
vc_config_path = None
model_dir = None

# CASE1 #list : list pre-trained TTS models
# 1) List pre-trained TTS models
if args.list_models:
manager.list_models()
sys.exit()

# CASE2 #info : model info for pre-trained TTS models
# 2) Info about pre-trained TTS models (without loading a model)
if args.model_info_by_idx:
model_query = args.model_info_by_idx
manager.model_info_by_idx(model_query)
@@ -360,91 +353,50 @@ def main():
manager.model_info_by_full_name(model_query_full_name)
sys.exit()

# CASE3: load pre-trained model paths
if args.model_name is not None and not args.model_path:
model_path, config_path, model_item = manager.download_model(args.model_name)
# tts model
if model_item["model_type"] == "tts_models":
tts_path = model_path
tts_config_path = config_path
if args.vocoder_name is None and "default_vocoder" in model_item:
args.vocoder_name = model_item["default_vocoder"]

# voice conversion model
if model_item["model_type"] == "voice_conversion_models":
vc_path = model_path
vc_config_path = config_path

# tts model with multiple files to be loaded from the directory path
if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
model_dir = model_path
tts_path = None
tts_config_path = None
args.vocoder_name = None

# load vocoder
if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)

# CASE4: set custom model paths
if args.model_path is not None:
tts_path = args.model_path
tts_config_path = args.config_path
speakers_file_path = args.speakers_file_path
language_ids_file_path = args.language_ids_file_path

if args.vocoder_path is not None:
vocoder_path = args.vocoder_path
vocoder_config_path = args.vocoder_config_path

if args.encoder_path is not None:
encoder_path = args.encoder_path
encoder_config_path = args.encoder_config_path

# 3) Load a model for further info or TTS/VC
device = args.device
if args.use_cuda:
device = "cuda"

# load models
synthesizer = Synthesizer(
tts_checkpoint=tts_path,
tts_config_path=tts_config_path,
tts_speakers_file=speakers_file_path,
tts_languages_file=language_ids_file_path,
vocoder_checkpoint=vocoder_path,
vocoder_config=vocoder_config_path,
encoder_checkpoint=encoder_path,
encoder_config=encoder_config_path,
vc_checkpoint=vc_path,
vc_config=vc_config_path,
model_dir=model_dir,
voice_dir=args.voice_dir,
# A local model will take precedence if specified via model_path
model_name = args.model_name if args.model_path is None else None
api = TTS(
model_name=model_name,
model_path=args.model_path,
config_path=args.config_path,
vocoder_name=args.vocoder_name,
vocoder_path=args.vocoder_path,
vocoder_config_path=args.vocoder_config_path,
encoder_path=args.encoder_path,
encoder_config_path=args.encoder_config_path,
speakers_file_path=args.speakers_file_path,
language_ids_file_path=args.language_ids_file_path,
progress_bar=args.progress_bar,
).to(device)

# query speaker ids of a multi-speaker model.
if args.list_speaker_idxs:
if synthesizer.tts_model.speaker_manager is None:
if not api.is_multi_speaker:
logger.info("Model only has a single speaker.")
return
logger.info(
"Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
)
logger.info(list(synthesizer.tts_model.speaker_manager.name_to_id.keys()))
logger.info(api.speakers)
return

# query language ids of a multi-lingual model.
if args.list_language_idxs:
if synthesizer.tts_model.language_manager is None:
if not api.is_multi_lingual:
logger.info("Monolingual model.")
return
logger.info(
"Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
logger.info(synthesizer.tts_model.language_manager.name_to_id)
logger.info(api.languages)
return

# check the arguments against a multi-speaker model.
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
if api.is_multi_speaker and (not args.speaker_idx and not args.speaker_wav):
logger.error(
"Looks like you use a multi-speaker model. Define `--speaker_idx` to "
"select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
@@ -455,31 +407,29 @@ def main():
if args.text:
logger.info("Text: %s", args.text)

# kick it
if tts_path is not None:
wav = synthesizer.tts(
args.text,
speaker_name=args.speaker_idx,
language_name=args.language_idx,
if args.text is not None:
api.tts_to_file(
text=args.text,
speaker=args.speaker_idx,
language=args.language_idx,
speaker_wav=args.speaker_wav,
pipe_out=pipe_out,
file_path=args.out_path,
reference_wav=args.reference_wav,
style_wav=args.capacitron_style_wav,
style_text=args.capacitron_style_text,
reference_speaker_name=args.reference_speaker_idx,
voice_dir=args.voice_dir,
)
elif vc_path is not None:
wav = synthesizer.voice_conversion(
logger.info("Saved TTS output to %s", args.out_path)
elif args.source_wav is not None and args.target_wav is not None:
api.voice_conversion_to_file(
source_wav=args.source_wav,
target_wav=args.target_wav,
file_path=args.out_path,
pipe_out=pipe_out,
)
elif model_dir is not None:
wav = synthesizer.tts(
args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
)

# save the results
synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out)
logger.info("Saved output to %s", args.out_path)
logger.info("Saved VC output to %s", args.out_path)


if __name__ == "__main__":
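With this refactor the CLI is a thin wrapper around TTS.api.TTS, so the flags map directly onto the API calls above. A usage sketch under that assumption (the model name is illustrative; all flags appear in this diff):

# List pre-trained models, then synthesize with one of them
tts --list_models
tts --model_name "tts_models/en/ljspeech/tacotron2-DDC" --text "Hello world." --out_path out.wav

# Voice conversion between two reference clips
tts --source_wav source.wav --target_wav target.wav --out_path converted.wav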
3 changes: 2 additions & 1 deletion TTS/demos/xtts_ft_demo/utils/gpt_train.py
@@ -5,7 +5,8 @@

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager


17 changes: 11 additions & 6 deletions TTS/tts/layers/tortoise/arch_utils.py
@@ -70,11 +70,10 @@ def forward(self, qkv, mask=None, rel_pos=None):
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
bs * self.n_heads, weight.shape[-2], weight.shape[-1]
)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
if mask is not None:
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
weight = weight * mask
mask = mask.repeat(self.n_heads, 1, 1)
weight[mask.logical_not()] = -torch.inf
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)

return a.reshape(bs, -1, length)
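Note on the masking change above: the mask is now applied additively as -inf before the softmax instead of multiplicatively after it, so the attention weights renormalize over the unmasked positions rather than summing to less than one. A self-contained sketch of the difference (illustrative only, not code from this PR):

import torch

w = torch.randn(2, 4, 4)                      # attention logits (batch, queries, keys)
mask = torch.ones(2, 4, 4, dtype=torch.bool)
mask[..., -1] = False                         # hide the last key position

# Old behavior: softmax first, then zero out masked entries; rows no longer sum to 1.
old = torch.softmax(w, dim=-1) * mask

# New behavior: push masked logits to -inf, then softmax; rows renormalize to 1.
new = torch.softmax(w.masked_fill(~mask, -torch.inf), dim=-1)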
@@ -93,7 +92,9 @@ def __init__(
channels,
num_heads=1,
num_head_channels=-1,
*,
relative_pos_embeddings=False,
tortoise_norm=False,
):
super().__init__()
self.channels = channels
@@ -108,6 +109,7 @@ def __init__(
self.qkv = nn.Conv1d(channels, channels * 3, 1)
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.tortoise_norm = tortoise_norm

self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
if relative_pos_embeddings:
@@ -124,10 +126,13 @@ def __init__(
def forward(self, x, mask=None):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
x_norm = self.norm(x)
qkv = self.qkv(x_norm)
h = self.attention(qkv, mask, self.relative_pos_embeddings)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
if self.tortoise_norm:
return (x + h).reshape(b, c, *spatial)
return (x_norm + h).reshape(b, c, *spatial)


class Upsample(nn.Module):
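Note on tortoise_norm: the new keyword-only flag selects which tensor feeds the residual connection in AttentionBlock.forward. In pseudocode (my summary of the diff, shapes elided):

# tortoise_norm=True:  out = x + proj_out(attn(qkv(norm(x))))        # original Tortoise residual, from the unnormalized input
# tortoise_norm=False: out = norm(x) + proj_out(attn(qkv(norm(x))))  # this port's previous behavior, kept as the default

The Tortoise call sites in this PR (classifier.py, diffusion_decoder.py) pass tortoise_norm=True, while autoregressive.py forwards the flag from ConditioningEncoder.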
4 changes: 3 additions & 1 deletion TTS/tts/layers/tortoise/autoregressive.py
@@ -176,12 +176,14 @@ def __init__(
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
*,
tortoise_norm=False,
):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=tortoise_norm))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim

2 changes: 1 addition & 1 deletion TTS/tts/layers/tortoise/classifier.py
@@ -97,7 +97,7 @@ def __init__(
self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
attn = []
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=True))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim

21 changes: 13 additions & 8 deletions TTS/tts/layers/tortoise/diffusion_decoder.py
@@ -130,7 +130,7 @@ def __init__(self, model_channels, dropout, num_heads):
dims=1,
use_scale_shift_norm=True,
)
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True)

def forward(self, x, time_emb):
y = self.resblk(x, time_emb)
@@ -177,17 +177,17 @@ def __init__(
# transformer network.
self.code_embedding = nn.Embedding(in_tokens, model_channels)
self.code_converter = nn.Sequential(
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
)
self.code_norm = normalization(model_channels)
self.latent_conditioner = nn.Sequential(
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
)
self.contextual_embedder = nn.Sequential(
nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
@@ -196,26 +196,31 @@ def __init__(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
tortoise_norm=True,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
tortoise_norm=True,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
tortoise_norm=True,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
tortoise_norm=True,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
tortoise_norm=True,
),
)
self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))