From 3539e65d8e9d31d44c57b2c4a84ae1f372ade611 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 2 Dec 2024 22:50:33 +0100 Subject: [PATCH] refactor(synthesizer): set sample rate in loading methods --- TTS/utils/synthesizer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 73f596d167..a9b9feffc1 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -95,26 +95,20 @@ def __init__( if tts_checkpoint: self._load_tts(tts_checkpoint, tts_config_path, use_cuda) - self.output_sample_rate = self.tts_config.audio["sample_rate"] if vocoder_checkpoint: self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) - self.output_sample_rate = self.vocoder_config.audio["sample_rate"] if vc_checkpoint and model_dir is None: self._load_vc(vc_checkpoint, vc_config, use_cuda) - self.output_sample_rate = self.vc_config.audio["output_sample_rate"] if model_dir: if "fairseq" in model_dir: self._load_fairseq_from_dir(model_dir, use_cuda) - self.output_sample_rate = self.tts_config.audio["sample_rate"] elif "openvoice" in model_dir: self._load_openvoice_from_dir(Path(model_dir), use_cuda) - self.output_sample_rate = self.vc_config.audio["output_sample_rate"] else: self._load_tts_from_dir(model_dir, use_cuda) - self.output_sample_rate = self.tts_config.audio["output_sample_rate"] @staticmethod def _get_segmenter(lang: str): @@ -143,6 +137,7 @@ def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> N """ # pylint: disable=global-statement self.vc_config = load_config(vc_config_path) + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] self.vc_model = setup_vc_model(config=self.vc_config) self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint) if use_cuda: @@ -157,6 +152,7 @@ def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None: self.tts_model = Vits.init_from_config(self.tts_config) self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True) self.tts_config = self.tts_model.config + self.output_sample_rate = self.tts_config.audio["sample_rate"] if use_cuda: self.tts_model.cuda() @@ -170,6 +166,7 @@ def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None: self.vc_model = OpenVoice.init_from_config(self.vc_config) self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True) self.vc_config = self.vc_model.config + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] if use_cuda: self.vc_model.cuda() @@ -180,6 +177,7 @@ def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: """ config = load_config(os.path.join(model_dir, "config.json")) self.tts_config = config + self.output_sample_rate = self.tts_config.audio["output_sample_rate"] self.tts_model = setup_tts_model(config) self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True) if use_cuda: @@ -201,6 +199,7 @@ def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) - """ # pylint: disable=global-statement self.tts_config = load_config(tts_config_path) + self.output_sample_rate = self.tts_config.audio["sample_rate"] if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None: raise ValueError("Phonemizer is not defined in the TTS config.") @@ -238,6 +237,7 @@ def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> N use_cuda (bool): enable/disable CUDA use. """ self.vocoder_config = load_config(model_config) + self.output_sample_rate = self.vocoder_config.audio["sample_rate"] self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio) self.vocoder_model = setup_vocoder_model(self.vocoder_config) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)