Skip to content

Commit

Permalink
refactor: apply safe automatic ruff lint fixes
Browse files Browse the repository at this point in the history
Not manually checked. Generated with:
uv run ruff check tests/ TTS/ notebooks/ recipes/ --fix
  • Loading branch information
eginhard committed Jan 10, 2025
1 parent e850e06 commit 63fd577
Show file tree
Hide file tree
Showing 137 changed files with 657 additions and 785 deletions.
37 changes: 18 additions & 19 deletions TTS/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import tempfile
import warnings
from pathlib import Path
from typing import Optional

from torch import nn

Expand All @@ -22,15 +21,15 @@ def __init__(
self,
model_name: str = "",
*,
model_path: Optional[str] = None,
config_path: Optional[str] = None,
vocoder_name: Optional[str] = None,
vocoder_path: Optional[str] = None,
vocoder_config_path: Optional[str] = None,
encoder_path: Optional[str] = None,
encoder_config_path: Optional[str] = None,
speakers_file_path: Optional[str] = None,
language_ids_file_path: Optional[str] = None,
model_path: str | None = None,
config_path: str | None = None,
vocoder_name: str | None = None,
vocoder_path: str | None = None,
vocoder_config_path: str | None = None,
encoder_path: str | None = None,
encoder_config_path: str | None = None,
speakers_file_path: str | None = None,
language_ids_file_path: str | None = None,
progress_bar: bool = True,
gpu: bool = False,
) -> None:
Expand Down Expand Up @@ -156,8 +155,8 @@ def list_models() -> list[str]:
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()

def download_model_by_name(
self, model_name: str, vocoder_name: Optional[str] = None
) -> tuple[Optional[Path], Optional[Path], Optional[Path]]:
self, model_name: str, vocoder_name: str | None = None
) -> tuple[Path | None, Path | None, Path | None]:
model_path, config_path, model_item = self.manager.download_model(model_name)
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
# return model directory if there are multiple files
Expand All @@ -174,7 +173,7 @@ def download_model_by_name(
self.vocoder_config_path = vocoder_config_path
return model_path, config_path, None

def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
def load_model_by_name(self, model_name: str, vocoder_name: str | None = None, *, gpu: bool = False) -> None:
"""Load one of the 🐸TTS models by name.
Args:
Expand All @@ -196,7 +195,7 @@ def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False) -> None:
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
)

def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
def load_tts_model_by_name(self, model_name: str, vocoder_name: str | None = None, *, gpu: bool = False) -> None:
"""Load one of 🐸TTS models by name.
Args:
Expand Down Expand Up @@ -250,11 +249,11 @@ def load_tts_model_by_path(self, model_path: str, config_path: str, *, gpu: bool

def _check_arguments(
self,
speaker: Optional[str] = None,
language: Optional[str] = None,
speaker_wav: Optional[str] = None,
emotion: Optional[str] = None,
speed: Optional[float] = None,
speaker: str | None = None,
language: str | None = None,
speaker_wav: str | None = None,
emotion: str | None = None,
speed: float | None = None,
**kwargs,
) -> None:
"""Check if the arguments are valid for the model."""
Expand Down
6 changes: 2 additions & 4 deletions TTS/bin/compute_statistics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import glob
import logging
import os
import sys
from typing import Optional

import numpy as np
from tqdm import tqdm
Expand All @@ -18,7 +16,7 @@
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


def parse_args(arg_list: Optional[list[str]]) -> tuple[argparse.Namespace, list[str]]:
def parse_args(arg_list: list[str] | None) -> tuple[argparse.Namespace, list[str]]:
parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.")
parser.add_argument("out_path", type=str, help="save path (directory and filename).")
Expand All @@ -31,7 +29,7 @@ def parse_args(arg_list: Optional[list[str]]) -> tuple[argparse.Namespace, list[
return parser.parse_known_args(arg_list)


def main(arg_list: Optional[list[str]] = None):
def main(arg_list: list[str] | None = None):
"""Run preprocessing process."""
setup_logger("TTS", level=logging.INFO, stream=sys.stderr, formatter=ConsoleFormatter())
args, overrides = parse_args(arg_list)
Expand Down
5 changes: 2 additions & 3 deletions TTS/bin/extract_tts_spectrograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging
import sys
from pathlib import Path
from typing import Optional

import numpy as np
import torch
Expand All @@ -27,7 +26,7 @@
use_cuda = torch.cuda.is_available()


def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
def parse_args(arg_list: list[str] | None) -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
Expand Down Expand Up @@ -244,7 +243,7 @@ def extract_spectrograms(
f.write(f"{data[0] / data[1]}.npy\n")


def main(arg_list: Optional[list[str]] = None) -> None:
def main(arg_list: list[str] | None = None) -> None:
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
args = parse_args(arg_list)
config = load_config(args.config_path)
Expand Down
5 changes: 2 additions & 3 deletions TTS/bin/find_unique_phonemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import multiprocessing
import sys
from argparse import RawTextHelpFormatter
from typing import Optional

from tqdm.contrib.concurrent import process_map

Expand All @@ -21,7 +20,7 @@ def compute_phonemes(item: dict) -> set[str]:
return set(ph)


def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
def parse_args(arg_list: list[str] | None) -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="""Find all the unique characters or phonemes in a dataset.\n\n"""
"""
Expand All @@ -35,7 +34,7 @@ def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
return parser.parse_args(arg_list)


def main(arg_list: Optional[list[str]] = None) -> None:
def main(arg_list: list[str] | None = None) -> None:
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
global phonemizer
args = parse_args(arg_list)
Expand Down
5 changes: 2 additions & 3 deletions TTS/bin/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging
import sys
from argparse import RawTextHelpFormatter
from typing import Optional

# pylint: disable=redefined-outer-name, unused-argument
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
Expand Down Expand Up @@ -135,7 +134,7 @@
"""


def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
def parse_args(arg_list: list[str] | None) -> argparse.Namespace:
"""Parse arguments."""
parser = argparse.ArgumentParser(
description=description.replace(" ```\n", ""),
Expand Down Expand Up @@ -310,7 +309,7 @@ def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
return args


def main(arg_list: Optional[list[str]] = None) -> None:
def main(arg_list: list[str] | None = None) -> None:
"""Entry point for `tts` command line interface."""
args = parse_args(arg_list)
stream = sys.stderr if args.pipe_out else sys.stdout
Expand Down
17 changes: 6 additions & 11 deletions TTS/bin/train_encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import logging
import os
Expand Down Expand Up @@ -219,10 +218,8 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,

if global_step % c.print_step == 0:
print(
" | > Step:{} Loss:{:.5f} GradNorm:{:.5f} "
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr
),
f" | > Step:{global_step} Loss:{loss.item():.5f} GradNorm:{grad_norm:.5f} "
f"StepTime:{step_time:.2f} LoaderTime:{loader_time:.2f} AvGLoaderTime:{avg_loader_time:.2f} LR:{current_lr:.6f}",
flush=True,
)

Expand All @@ -236,10 +233,8 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,

print("")
print(
">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
"EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
epoch, tot_loss / len(data_loader), grad_norm, epoch_time, avg_loader_time
),
f">>> Epoch:{epoch} AvgLoss: {tot_loss / len(data_loader):.5f} GradNorm:{grad_norm:.5f} "
f"EpochTime:{epoch_time:.2f} AvGLoaderTime:{avg_loader_time:.2f} ",
flush=True,
)
# evaluation
Expand All @@ -249,7 +244,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
print("\n\n")
print("--> EVAL PERFORMANCE")
print(
" | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss),
f" | > Epoch:{epoch} AvgLoss: {eval_loss:.5f} ",
flush=True,
)
# save the best checkpoint
Expand Down Expand Up @@ -311,7 +306,7 @@ def main(args): # pylint: disable=redefined-outer-name
scheduler = None

num_params = count_parameters(model)
print("\n > Model has {} parameters".format(num_params), flush=True)
print(f"\n > Model has {num_params} parameters", flush=True)

if use_cuda:
model = model.cuda()
Expand Down
3 changes: 1 addition & 2 deletions TTS/bin/train_vocoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from trainer import Trainer, TrainerArgs

Expand All @@ -18,7 +17,7 @@ class TrainVocoderArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})


def main(arg_list: Optional[list[str]] = None):
def main(arg_list: list[str] | None = None):
"""Run `tts` model training directly by a `config.json` file."""
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())

Expand Down
4 changes: 2 additions & 2 deletions TTS/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def register_config(model_name: str) -> Coqpit:
return config_class


def _process_model_name(config_dict: Dict) -> str:
def _process_model_name(config_dict: dict) -> str:
"""Format the model name as expected. It is a band-aid for the old `vocoder` model names.
Args:
Expand All @@ -68,7 +68,7 @@ def _process_model_name(config_dict: Dict) -> str:
return model_name


def load_config(config_path: Union[str, os.PathLike[Any]]) -> Coqpit:
def load_config(config_path: str | os.PathLike[Any]) -> Coqpit:
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
to find the corresponding Config class. Then initialize the Config.
Expand Down
3 changes: 1 addition & 2 deletions TTS/config/shared_configs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import asdict, dataclass
from typing import List

from coqpit import Coqpit, check_argument
from trainer import TrainerConfig
Expand Down Expand Up @@ -227,7 +226,7 @@ class BaseDatasetConfig(Coqpit):
dataset_name: str = ""
path: str = ""
meta_file_train: str = ""
ignored_speakers: List[str] = None
ignored_speakers: list[str] = None
language: str = ""
phonemizer: str = ""
meta_file_val: str = ""
Expand Down
2 changes: 1 addition & 1 deletion TTS/demos/xtts_ft_demo/xtts_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def isatty(self):

def read_logs():
sys.stdout.flush()
with open(sys.stdout.log_file, "r") as f:
with open(sys.stdout.log_file) as f:
return f.read()


Expand Down
9 changes: 4 additions & 5 deletions TTS/encoder/configs/base_encoder_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import asdict, dataclass, field
from typing import Dict, List

from coqpit import MISSING

Expand All @@ -12,9 +11,9 @@ class BaseEncoderConfig(BaseTrainingConfig):

model: str = None
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
datasets: list[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# model params
model_params: Dict = field(
model_params: dict = field(
default_factory=lambda: {
"model_name": "lstm",
"input_dim": 80,
Expand All @@ -25,15 +24,15 @@ class BaseEncoderConfig(BaseTrainingConfig):
}
)

audio_augmentation: Dict = field(default_factory=lambda: {})
audio_augmentation: dict = field(default_factory=lambda: {})

# training params
epochs: int = 10000
loss: str = "angleproto"
grad_clip: float = 3.0
lr: float = 0.0001
optimizer: str = "radam"
optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
lr_decay: bool = False
warmup_steps: int = 4000

Expand Down
2 changes: 1 addition & 1 deletion TTS/encoder/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger = logging.getLogger(__name__)


class AugmentWAV(object):
class AugmentWAV:
def __init__(self, ap, augmentation_config):
self.ap = ap
self.use_additive_noise = False
Expand Down
3 changes: 1 addition & 2 deletions TTS/encoder/utils/prepare_voxceleb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo
# All rights reserved.
#
Expand Down Expand Up @@ -194,7 +193,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
for wav_file in files:
writer.writerow(wav_file)
logger.info("Successfully generated csv file {}".format(csv_file_path))
logger.info(f"Successfully generated csv file {csv_file_path}")


def processor(directory, subset, force_process):
Expand Down
4 changes: 2 additions & 2 deletions TTS/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from abc import abstractmethod
from typing import Any, Union
from typing import Any

import torch
from coqpit import Coqpit
Expand Down Expand Up @@ -48,7 +48,7 @@ def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict
def load_checkpoint(
self,
config: Coqpit,
checkpoint_path: Union[str, os.PathLike[Any]],
checkpoint_path: str | os.PathLike[Any],
eval: bool = False,
strict: bool = True,
cache: bool = False,
Expand Down
3 changes: 1 addition & 2 deletions TTS/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import sys
from pathlib import Path
from threading import Lock
from typing import Union
from urllib.parse import parse_qs

try:
Expand Down Expand Up @@ -134,7 +133,7 @@ def create_argparser() -> argparse.ArgumentParser:
app = Flask(__name__)


def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
def style_wav_uri_to_dict(style_wav: str) -> str | dict:
"""Transform an uri style_wav, in either a string (path to wav file to be use for style transfer)
or a dict (gst tokens/values to be use for styling)
Expand Down
Loading

0 comments on commit 63fd577

Please sign in to comment.