Skip to content

Commit

Permalink
refactor(synthesizer): set sample rate in loading methods
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Dec 2, 2024
1 parent 7d0416f commit 3539e65
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,26 +95,20 @@ def __init__(

if tts_checkpoint:
self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]

if vocoder_checkpoint:
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]

if vc_checkpoint and model_dir is None:
self._load_vc(vc_checkpoint, vc_config, use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]

if model_dir:
if "fairseq" in model_dir:
self._load_fairseq_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
elif "openvoice" in model_dir:
self._load_openvoice_from_dir(Path(model_dir), use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
else:
self._load_tts_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]

@staticmethod
def _get_segmenter(lang: str):
Expand Down Expand Up @@ -143,6 +137,7 @@ def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> N
"""
# pylint: disable=global-statement
self.vc_config = load_config(vc_config_path)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
self.vc_model = setup_vc_model(config=self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
if use_cuda:
Expand All @@ -157,6 +152,7 @@ def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None:
self.tts_model = Vits.init_from_config(self.tts_config)
self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
self.tts_config = self.tts_model.config
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if use_cuda:
self.tts_model.cuda()

Expand All @@ -170,6 +166,7 @@ def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None:
self.vc_model = OpenVoice.init_from_config(self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True)
self.vc_config = self.vc_model.config
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if use_cuda:
self.vc_model.cuda()

Expand All @@ -180,6 +177,7 @@ def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""
config = load_config(os.path.join(model_dir, "config.json"))
self.tts_config = config
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
self.tts_model = setup_tts_model(config)
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
if use_cuda:
Expand All @@ -201,6 +199,7 @@ def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -
"""
# pylint: disable=global-statement
self.tts_config = load_config(tts_config_path)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None:
raise ValueError("Phonemizer is not defined in the TTS config.")

Expand Down Expand Up @@ -238,6 +237,7 @@ def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> N
use_cuda (bool): enable/disable CUDA use.
"""
self.vocoder_config = load_config(model_config)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio)
self.vocoder_model = setup_vocoder_model(self.vocoder_config)
self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
Expand Down

0 comments on commit 3539e65

Please sign in to comment.