Skip to content

Commit

Permalink
Add check for missing cuda and ipex params
Browse files Browse the repository at this point in the history
  • Loading branch information
Disty0 committed Jan 31, 2025
1 parent e193a92 commit 0397469
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions modules/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ def set_cudnn_params():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
if hasattr(torch.backends.cuda, "allow_fp16_bf16_reduction_math_sdp"): # only valid for torch >= 2.5
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except Exception as e:
log.warning(f'Torch matmul: {e}')
if torch.backends.cudnn.is_available():
Expand All @@ -395,7 +396,8 @@ def set_cudnn_params():
def override_ipex_math():
if backend == "ipex":
try:
torch.xpu.set_fp32_math_mode(mode=torch.xpu.FP32MathMode.TF32)
if hasattr(torch.xpu, "set_fp32_math_mode"): # not available with pure torch+xpu, requires ipex
torch.xpu.set_fp32_math_mode(mode=torch.xpu.FP32MathMode.TF32)
except Exception as e:
log.warning(f'Torch ipex: {e}')

Expand Down

0 comments on commit 0397469

Please sign in to comment.