Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch remaining CLI tests to Python, separate integration tests #276

Draft
wants to merge 6 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,12 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.10", "3.12"]
subset: ["test_tts", "test_tts2", "test_vocoder", "test_xtts"]
shard: [0, 1, 2, 3, 4]
steps:
- uses: actions/checkout@v4
- name: Setup uv
uses: ./.github/actions/setup-uv
- name: Install Espeak
if: contains(fromJSON('["test_tts", "test_tts2", "test_xtts"]'), matrix.subset)
run: |
sudo apt-get update
sudo apt-get install espeak espeak-ng
Expand All @@ -87,18 +86,21 @@ jobs:
if [[ -n "${{ github.event.inputs.coqpit_branch }}" ]]; then
uv add git+https://github.com/idiap/coqui-ai-coqpit --branch ${{ github.event.inputs.coqpit_branch }}
fi
- name: Integration tests
- name: Integration tests for shard ${{ matrix.shard }}
run: |
uv run pytest tests/integration --collect-only --quiet | grep "::" > integration_tests.txt
total_shards=5
shard_tests=$(awk "NR % $total_shards == ${{ matrix.shard }}" integration_tests.txt)
resolution=highest
if [ "${{ matrix.python-version }}" == "3.10" ]; then
resolution=lowest-direct
fi
uv run --resolution=$resolution --extra server --extra languages make ${{ matrix.subset }}
uv run --resolution=$resolution --extra languages pytest -x -v --durations=0 $shard_tests
- name: Upload coverage data
uses: actions/upload-artifact@v4
with:
include-hidden-files: true
name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }}
name: coverage-data-integration-${{ matrix.shard }}-${{ matrix.python-version }}
path: .coverage.*
if-no-files-found: ignore
zoo:
Expand Down
6 changes: 0 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ test_vocoder: ## run vocoder tests.
test_tts: ## run tts tests.
coverage run -m pytest -x -v --durations=0 tests/tts_tests

test_tts2: ## run tts2 tests.
	coverage run -m pytest -x -v --durations=0 tests/tts_tests2

test_xtts: ## run xtts tests.
	coverage run -m pytest -x -v --durations=0 tests/xtts_tests

test_aux: ## run aux tests.
coverage run -m pytest -x -v --durations=0 tests/aux_tests

Expand Down
152 changes: 122 additions & 30 deletions TTS/bin/train_encoder.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
#!/usr/bin/env python3

# TODO: use Trainer

import logging
import os
import sys
import time
import traceback
import warnings
from dataclasses import dataclass, field

import torch
from torch.utils.data import DataLoader
from trainer.generic_utils import count_parameters, remove_experiment_folder
from trainer.io import copy_model_files, save_best_model, save_checkpoint
from trainer import TrainerArgs, TrainerConfig
from trainer.generic_utils import count_parameters, get_experiment_folder_path, get_git_branch
from trainer.io import copy_model_files, get_last_checkpoint, save_best_model, save_checkpoint
from trainer.logging import BaseDashboardLogger, ConsoleLogger, logger_factory
from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer

from TTS.config import load_config, register_config
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.utils.samplers import PerfectBatchSampler
Expand All @@ -33,7 +39,77 @@
print(" > Number of GPUs: ", num_gpus)


def setup_loader(ap: AudioProcessor, is_val: bool = False):
@dataclass
class TrainArgs(TrainerArgs):
    """Command-line arguments for speaker/emotion encoder training.

    Extends the base ``TrainerArgs`` with ``config_path`` so the encoder
    trainer can be pointed at a JSON/YAML model config file.
    """

    # Path to the model config file; when None, the config is built from CLI overrides instead.
    config_path: str | None = field(default=None, metadata={"help": "Path to the config file."})


def process_args(
    args, config: BaseEncoderConfig | None = None
) -> tuple[BaseEncoderConfig, str, str, ConsoleLogger, BaseDashboardLogger | None]:
    """Process parsed command-line arguments and initialize the config if not provided.

    Args:
        args (argparse.Namespace or dict like): Parsed input arguments. May also be a
            ``(args, overrides)`` tuple, in which case ``overrides`` holds the unparsed
            command-line items that are forwarded to the config parser.
        config (Coqpit): Model config. If None, it is generated from ``args``. Defaults to None.

    Returns:
        c (Coqpit): Config parameters.
        out_path (str): Path to save models and logging.
        audio_path (str): Path to save generated test audios.
        c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
            logging to the console.
        dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard
            logging; None on non-zero ranks.

    TODO:
        - Interactive config definition.
    """
    coqpit_overrides = None
    if isinstance(args, tuple):
        args, coqpit_overrides = args
    if args.continue_path:
        # continue a previous training from its output folder
        experiment_path = args.continue_path
        # NOTE: config_path must be rewritten *before* the config-init branch below,
        # so the run resumes with the config saved in the output folder.
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model
    # init config if not already defined
    if config is None:
        if args.config_path:
            # init from a file
            config = load_config(args.config_path)
        else:
            # init from console args
            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

            # Parse only the model name first, then instantiate the matching config class.
            config_base = BaseTrainingConfig()
            config_base.parse_known_args(coqpit_overrides)
            config = register_config(config_base.model)()
    # override values from command-line args
    config.parse_known_args(coqpit_overrides, relaxed_parser=True)
    experiment_path = args.continue_path
    if not experiment_path:
        # Fresh run: create a new timestamped experiment folder.
        experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
    audio_path = os.path.join(experiment_path, "test_audios")
    config.output_log_path = experiment_path
    # setup rank 0 process in distributed training
    dashboard_logger = None
    if args.rank == 0:
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # if model characters are not set in the config file
        # save the default set to the config file for future
        # compatibility.
        if config.has("characters") and config.characters is None:
            used_characters = parse_symbols()
            new_fields["characters"] = used_characters
        # Persist the (possibly augmented) config next to the checkpoints.
        copy_model_files(config, experiment_path, new_fields)
        # Dashboard logger (TensorBoard/W&B) only on rank 0 to avoid duplicate logging.
        dashboard_logger = logger_factory(config, experiment_path)
    c_logger = ConsoleLogger()
    return config, experiment_path, audio_path, c_logger, dashboard_logger


def setup_loader(c: TrainerConfig, ap: AudioProcessor, is_val: bool = False):
num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch

Expand Down Expand Up @@ -83,7 +159,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False):
return loader, classes, dataset.get_map_classid_to_classname()


def evaluation(model, criterion, data_loader, global_step):
def evaluation(c: BaseEncoderConfig, model, criterion, data_loader, global_step, dashboard_logger: BaseDashboardLogger):
eval_loss = 0
for _, data in enumerate(data_loader):
with torch.inference_mode():
Expand Down Expand Up @@ -127,7 +203,17 @@ def evaluation(model, criterion, data_loader, global_step):
return eval_avg_loss


def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
def train(
c: BaseEncoderConfig,
model,
optimizer,
scheduler,
criterion,
data_loader,
eval_data_loader,
global_step,
dashboard_logger: BaseDashboardLogger,
):
model.train()
best_loss = {"train_loss": None, "eval_loss": float("inf")}
avg_loader_time = 0
Expand Down Expand Up @@ -226,7 +312,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
if global_step % c.save_step == 0:
# save model
save_checkpoint(
c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
c, model, optimizer, None, global_step, epoch, c.output_log_path, criterion=criterion.state_dict()
)

end_time = time.time()
Expand All @@ -240,7 +326,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
# evaluation
if c.run_eval:
model.eval()
eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
eval_loss = evaluation(c, model, criterion, eval_data_loader, global_step, dashboard_logger)
print("\n\n")
print("--> EVAL PERFORMANCE")
print(
Expand All @@ -257,15 +343,21 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
None,
global_step,
epoch,
OUT_PATH,
c.output_log_path,
criterion=criterion.state_dict(),
)
model.train()

return best_loss, global_step


def main(args): # pylint: disable=redefined-outer-name
def main(arg_list: list[str] | None = None):
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())

train_config = TrainArgs()
parser = train_config.init_argparse(arg_prefix="")
args, overrides = parser.parse_known_args(arg_list)
c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args((args, overrides))
# pylint: disable=global-variable-undefined
global meta_data_train
global meta_data_eval
Expand All @@ -279,9 +371,9 @@ def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)

train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False)
train_data_loader, train_classes, map_classid_to_classname = setup_loader(c, ap, is_val=False)
if c.run_eval:
eval_data_loader, _, _ = setup_loader(ap, is_val=True)
eval_data_loader, _, _ = setup_loader(c, ap, is_val=True)
else:
eval_data_loader = None

Expand Down Expand Up @@ -313,23 +405,23 @@ def main(args): # pylint: disable=redefined-outer-name
criterion.cuda()

global_step = args.restore_step
_, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step)
_, global_step = train(
c, model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step, dashboard_logger
)
sys.exit(0)


if __name__ == "__main__":
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())

args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training()

try:
main(args)
except KeyboardInterrupt:
remove_experiment_folder(OUT_PATH)
try:
sys.exit(0)
except SystemExit:
os._exit(0) # pylint: disable=protected-access
except Exception: # pylint: disable=broad-except
remove_experiment_folder(OUT_PATH)
traceback.print_exc()
sys.exit(1)
main()
# try:
# main()
# except KeyboardInterrupt:
# remove_experiment_folder(OUT_PATH)
# try:
# sys.exit(0)
# except SystemExit:
# os._exit(0) # pylint: disable=protected-access
# except Exception: # pylint: disable=broad-except
# remove_experiment_folder(OUT_PATH)
# traceback.print_exc()
# sys.exit(1)
7 changes: 4 additions & 3 deletions TTS/bin/train_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ class TrainTTSArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})


def main():
def main(arg_list: list[str] | None = None):
"""Run `tts` model training directly by a `config.json` file."""
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())

# init trainer args
train_args = TrainTTSArgs()
parser = train_args.init_argparse(arg_prefix="")

# override trainer args from comman-line args
args, config_overrides = parser.parse_known_args()
# override trainer args from command-line args
args, config_overrides = parser.parse_known_args(arg_list)
train_args.parse_args(args)

# load config.json and register
Expand Down Expand Up @@ -70,6 +70,7 @@ def main():
parse_command_line_args=False,
)
trainer.fit()
sys.exit(0)


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion TTS/encoder/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from scipy import signal

from TTS.encoder.models.base_encoder import BaseEncoder
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder

Expand Down Expand Up @@ -120,7 +121,7 @@ def apply_one(self, audio):
return self.additive_noise(noise_type, audio)


def setup_encoder_model(config: "Coqpit"):
def setup_encoder_model(config: "Coqpit") -> BaseEncoder:
if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
config.model_params["input_dim"],
Expand All @@ -138,4 +139,7 @@ def setup_encoder_model(config: "Coqpit"):
use_torch_spec=config.model_params.get("use_torch_spec", False),
audio_config=config.audio,
)
else:
msg = f"Model not supported: {config.model_params['model_name']}"
raise ValueError(msg)
return model
Loading
Loading