diff --git a/config.py b/config.py index 48187f530..a59f738a3 100644 --- a/config.py +++ b/config.py @@ -3,7 +3,7 @@ from multiprocessing import cpu_count -def config_file_change_fp32(): +def use_fp32_config(): for config_file in ["32k.json", "40k.json", "48k.json"]: with open(f"configs/{config_file}", "r") as f: strr = f.read().replace("true", "false") @@ -58,6 +58,17 @@ def arg_parse() -> tuple: cmd_opts.noparallel, cmd_opts.noautoopen, ) + + # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. + # check `getattr` and try it for compatibility + @staticmethod + def has_mps() -> bool: + if not torch.backends.mps.is_available(): return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False def device_config(self) -> tuple: if torch.cuda.is_available(): @@ -70,9 +81,9 @@ def device_config(self) -> tuple: or "1070" in self.gpu_name or "1080" in self.gpu_name ): - print("16系/10系显卡和P40强制单精度") + print("16|10|P40 series, force to fp32") self.is_half = False - config_file_change_fp32() + use_fp32_config() else: self.gpu_name = None self.gpu_mem = int( @@ -87,16 +98,16 @@ def device_config(self) -> tuple: strr = f.read().replace("3.7", "3.0") with open("trainset_preprocess_pipeline_print.py", "w") as f: f.write(strr) - elif torch.backends.mps.is_available(): - print("没有发现支持的N卡, 使用MPS进行推理") + elif self.has_mps(): + print("No supported Nvidia GPU, use MPS instead") self.device = "mps" self.is_half = False - config_file_change_fp32() + use_fp32_config() else: - print("没有发现支持的N卡, 使用CPU进行推理") + print("No supported Nvidia GPU, use CPU instead") self.device = "cpu" self.is_half = False - config_file_change_fp32() + use_fp32_config() if self.n_cpu == 0: self.n_cpu = cpu_count()