diff --git a/modules/devices.py b/modules/devices.py index b49745bd3..f633400a4 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -24,6 +24,8 @@ def extract_device_id(args, name): # pylint: disable=redefined-outer-name def get_cuda_device_string(): if shared.cmd_opts.use_ipex: + if shared.cmd_opts.device_id is not None: + return f"xpu:{shared.cmd_opts.device_id}" return "xpu" else: if shared.cmd_opts.device_id is not None: @@ -33,7 +35,7 @@ def get_cuda_device_string(): def get_optimal_device_name(): if shared.cmd_opts.use_ipex: - return "xpu" + return get_cuda_device_string() elif cuda_ok and not shared.cmd_opts.use_directml: return get_cuda_device_string() if has_mps(): @@ -66,7 +68,7 @@ def torch_gc(force=False): collected = gc.collect() if shared.cmd_opts.use_ipex: try: - with torch.xpu.device("xpu"): + with torch.xpu.device(get_cuda_device_string()): torch.xpu.empty_cache() except: pass @@ -143,7 +145,10 @@ def set_cuda_params(): args = cmd_args.parser.parse_args() if args.use_ipex: - cpu = torch.device("xpu") #Use XPU instead of CPU. %20 Perf improvement on weak CPUs. + if args.device_id is not None: + cpu = torch.device(f"xpu:{args.device_id}") #Use XPU instead of CPU. %20 Perf improvement on weak CPUs. + else: + cpu = torch.device("xpu") else: cpu = torch.device("cpu") device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None