# forked from Project-MONAI/MONAI — tf32.py (89 lines, 72 loc, 3.07 KB)
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import functools
import os
import warnings
__all__ = ["has_ampere_or_later", "detect_default_tf32"]
@functools.lru_cache(None)
def has_ampere_or_later() -> bool:
    """
    Check if there is any Ampere (compute capability >= 8.0) or later GPU.

    Returns:
        ``True`` if the CUDA runtime is 11.0+ and at least one device reports
        compute capability major version >= 8, or if ``pynvml`` is not
        installed (the user is then assumed to have such a GPU);
        ``False`` otherwise. The result is cached after the first call.
    """
    import torch
    from monai.utils.module import optional_import, version_geq

    # TF32 needs CUDA 11.0+; ``torch.version.cuda`` is None on CPU-only builds,
    # which the ``and`` short-circuit handles.
    if not (torch.version.cuda and version_geq(f"{torch.version.cuda}", "11.0")):
        return False
    pynvml, has_pynvml = optional_import("pynvml")
    if not has_pynvml:  # assuming that the user has Ampere and later GPU
        return True
    # Initialize NVML in its own guarded step. If ``nvmlInit()`` raised inside
    # the try/finally below, ``nvmlShutdown()`` in ``finally`` would be called
    # on an uninitialized library and itself raise, escaping this function
    # despite the broad ``except``.
    try:
        pynvml.nvmlInit()
    except BaseException:
        return False
    try:
        for i in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            major, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            if major >= 8:  # Ampere architecture starts at compute capability 8.0
                return True
    except BaseException:
        # Best-effort probe: any NVML query error is treated as "not detected".
        pass
    finally:
        pynvml.nvmlShutdown()
    return False
@functools.lru_cache(None)
def detect_default_tf32() -> bool:
    """
    Detect conditions under which TF32 mode may be enabled by default.

    Emits a warning for each condition found (PyTorch versions [1.7, 1.12)
    where ``allow_tf32`` defaults to True, and overriding environment
    variables). The result is cached after the first call.

    Returns:
        ``True`` if any TF32-enabling condition was detected, ``False``
        otherwise (including when detection itself fails outside debug mode).
    """
    flagged = False
    try:
        # No Ampere-class GPU means TF32 is not a concern at all.
        if not has_ampere_or_later():
            return False

        from monai.utils.module import pytorch_after

        # In PyTorch [1.7, 1.12) matmul TF32 was on by default.
        if pytorch_after(1, 7, 0) and not pytorch_after(1, 12, 0):
            warnings.warn(
                "torch.backends.cuda.matmul.allow_tf32 = True by default.\n"
                "  This value defaults to True when PyTorch version in [1.7, 1.11] and may affect precision.\n"
                "  See https://docs.monai.io/en/latest/precision_accelerating.html#precision-and-accelerating"
            )
            flagged = True

        # Environment variables that force TF32 on. TORCH_ALLOW_TF32_CUBLAS_OVERRIDE
        # is intentionally not checked, see issue #6907.
        for env_name, enabling_value in {"NVIDIA_TF32_OVERRIDE": "1"}.items():
            if os.environ.get(env_name) == enabling_value:
                warnings.warn(
                    f"Environment variable `{env_name} = {enabling_value}` is set.\n"
                    f"  This environment variable may enable TF32 mode accidentally and affect precision.\n"
                    f"  See https://docs.monai.io/en/latest/precision_accelerating.html#precision-and-accelerating"
                )
                flagged = True

        return flagged
    except BaseException:
        # Detection is best-effort: swallow failures unless debugging is on.
        from monai.utils.misc import MONAIEnvVars

        if MONAIEnvVars.debug():
            raise
        return False