Skip to content

Commit

Permalink
test: call cli tests via main functions to get test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Dec 14, 2024
1 parent cac5e64 commit 7fe5e5d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 24 deletions.
26 changes: 14 additions & 12 deletions TTS/bin/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import sys
from argparse import RawTextHelpFormatter
from typing import Optional

# pylint: disable=redefined-outer-name, unused-argument
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
Expand Down Expand Up @@ -134,7 +135,7 @@
"""


def parse_args() -> argparse.Namespace:
def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
"""Parse arguments."""
parser = argparse.ArgumentParser(
description=description.replace(" ```\n", ""),
Expand Down Expand Up @@ -290,7 +291,7 @@ def parse_args() -> argparse.Namespace:
help="Voice dir for tortoise model",
)

args = parser.parse_args()
args = parser.parse_args(arg_list)

# print the description if either text or list_models is not set
check_args = [
Expand All @@ -309,10 +310,10 @@ def parse_args() -> argparse.Namespace:
return args


def main() -> None:
def main(arg_list: Optional[list[str]] = None) -> None:
"""Entry point for `tts` command line interface."""
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
args = parse_args()
args = parse_args(arg_list)

pipe_out = sys.stdout if args.pipe_out else None

Expand All @@ -339,18 +340,18 @@ def main() -> None:
# 1) List pre-trained TTS models
if args.list_models:
manager.list_models()
sys.exit()
sys.exit(0)

# 2) Info about pre-trained TTS models (without loading a model)
if args.model_info_by_idx:
model_query = args.model_info_by_idx
manager.model_info_by_idx(model_query)
sys.exit()
sys.exit(0)

if args.model_info_by_name:
model_query_full_name = args.model_info_by_name
manager.model_info_by_full_name(model_query_full_name)
sys.exit()
sys.exit(0)

# 3) Load a model for further info or TTS/VC
device = args.device
Expand All @@ -376,31 +377,31 @@ def main() -> None:
if args.list_speaker_idxs:
if not api.is_multi_speaker:
logger.info("Model only has a single speaker.")
return
sys.exit(0)
logger.info(
"Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
)
logger.info(api.speakers)
return
sys.exit(0)

# query langauge ids of a multi-lingual model.
if args.list_language_idxs:
if not api.is_multi_lingual:
logger.info("Monolingual model.")
return
sys.exit(0)
logger.info(
"Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
logger.info(api.languages)
return
sys.exit(0)

# check the arguments against a multi-speaker model.
if api.is_multi_speaker and (not args.speaker_idx and not args.speaker_wav):
logger.error(
"Looks like you use a multi-speaker model. Define `--speaker_idx` to "
"select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
)
return
sys.exit(1)

# RUN THE SYNTHESIS
if args.text:
Expand Down Expand Up @@ -429,6 +430,7 @@ def main() -> None:
pipe_out=pipe_out,
)
logger.info("Saved VC output to %s", args.out_path)
sys.exit(0)


if __name__ == "__main__":
Expand Down
8 changes: 8 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from typing import Callable, Optional

import pytest
from trainer.generic_utils import get_cuda

from TTS.config import BaseDatasetConfig
Expand Down Expand Up @@ -44,6 +46,12 @@ def run_cli(command):
assert exit_status == 0, f" [!] command `{command}` failed."


def run_main(main_func: Callable, args: Optional[list[str]] = None, expected_code: int = 0):
with pytest.raises(SystemExit) as exc_info:
main_func(args)
assert exc_info.value.code == expected_code


def get_test_data_config():
return BaseDatasetConfig(formatter="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv")

Expand Down
23 changes: 11 additions & 12 deletions tests/inference_tests/test_synthesize.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from tests import run_cli
from tests import run_main
from TTS.bin.synthesize import main


def test_synthesize(tmp_path):
"""Test synthesize.py with diffent arguments."""
output_path = tmp_path / "output.wav"
run_cli("tts --list_models")
output_path = str(tmp_path / "output.wav")

run_main(main, ["--list_models"])

# single speaker model
run_cli(f'tts --text "This is an example." --out_path "{output_path}"')
run_cli(
"tts --model_name tts_models/en/ljspeech/glow-tts " f'--text "This is an example." --out_path "{output_path}"'
)
run_cli(
"tts --model_name tts_models/en/ljspeech/glow-tts "
"--vocoder_name vocoder_models/en/ljspeech/multiband-melgan "
f'--text "This is an example." --out_path "{output_path}"'
)
args = ["--text", "This is an example.", "--out_path", output_path]
run_main(main, args)

args = [*args, "--model_name", "tts_models/en/ljspeech/glow-tts"]
run_main(main, args)
run_main(main, [*args, "--vocoder_name", "vocoder_models/en/ljspeech/multiband-melgan"])

0 comments on commit 7fe5e5d

Please sign in to comment.