From 9f881ed57fe4dfc1a97c86ba17bc9675a74ebf11 Mon Sep 17 00:00:00 2001 From: syoka Date: Sat, 7 Dec 2024 15:50:04 +0800 Subject: [PATCH] #fix 1)enable mps device 2) fix bug in reference audio scenario (#714) * #fix 1)enable mps device 2)fix bug in reference audio scenario * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: xiaokai Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tools/inference_engine/reference_loader.py | 2 +- tools/run_webui.py | 5 +++++ tools/server/model_manager.py | 5 +++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tools/inference_engine/reference_loader.py b/tools/inference_engine/reference_loader.py index 91232eef..4b560393 100644 --- a/tools/inference_engine/reference_loader.py +++ b/tools/inference_engine/reference_loader.py @@ -50,7 +50,7 @@ def load_by_id( # If the references are not already loaded, encode them prompt_tokens = [ self.encode_reference( - decoder_model=self.decoder_model, + # decoder_model=self.decoder_model, reference_audio=audio_to_bytes(str(ref_audio)), enable_reference_audio=True, ) diff --git a/tools/run_webui.py b/tools/run_webui.py index 6b0ab490..5844b72c 100644 --- a/tools/run_webui.py +++ b/tools/run_webui.py @@ -45,6 +45,11 @@ def parse_args(): args = parse_args() args.precision = torch.half if args.half else torch.bfloat16 + # Check if MPS is available + if torch.backends.mps.is_available(): + args.device = "mps" + logger.info("mps is available, running on mps.") + # Check if CUDA is available if not torch.cuda.is_available(): logger.info("CUDA is not available, running on CPU.") diff --git a/tools/server/model_manager.py b/tools/server/model_manager.py index 549ad8d4..c3f0a896 100644 --- a/tools/server/model_manager.py +++ b/tools/server/model_manager.py @@ -34,6 +34,11 @@ def __init__( self.precision = torch.half if half else torch.bfloat16 + # Check if MPS is available 
+        if torch.backends.mps.is_available():
+            self.device = "mps"
+            logger.info("mps is available, running on mps.")
+
         # Check if CUDA is available
         if not torch.cuda.is_available():
             self.device = "cpu"