Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
petrex committed Jan 8, 2025
1 parent 8271d05 commit 662bfe7
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,27 @@ def read_version(file_path="version.txt"):

IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)


def get_extensions():
debug_mode = os.getenv("DEBUG", "0") == "1"
if debug_mode:
print("Compiling in debug mode")

if not torch.cuda.is_available():
print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions")
print(
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
)
if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions")
print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit")
print(
"CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
)
print(
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
)

use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None)
use_cuda = torch.cuda.is_available() and (
CUDA_HOME is not None or ROCM_HOME is not None
)
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = []
Expand Down Expand Up @@ -125,8 +134,12 @@ def get_extensions():
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

extensions_hip_dir = os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin")
hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True))
extensions_hip_dir = os.path.join(
extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin"
)
hip_sources = list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
)

if not IS_ROCM and use_cuda:
sources += cuda_sources
Expand Down

0 comments on commit 662bfe7

Please sign in to comment.