From aea9d81a34871d01d04b1563a1208d7070d307af Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Wed, 15 Jan 2025 15:09:16 -0800 Subject: [PATCH] lint refactor for better readibility --- setup.py | 57 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/setup.py b/setup.py index 0f64e6107e..d9b3c7e562 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,6 @@ def use_debug_mode(): CUDAExtension, ) - IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None) # Constant known variables used throughout this file @@ -258,38 +257,41 @@ def get_extensions(): ] ) + # Get base directory and source paths this_dir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(this_dir, "torchao", "csrc") - sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) - extensions_cuda_dir = os.path.join(extensions_dir, "cuda") - cuda_sources = list( - glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True) - ) - - extensions_hip_dir = os.path.join( - extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin" - ) - hip_sources = list( - glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True) - ) + # Collect C++ source files + sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) - if not IS_ROCM and use_cuda: - sources += cuda_sources - - # TOOD: Remove this and use what CUDA has once we fix all the builds. - if IS_ROCM and use_cuda: - # Add ROCm GPU architecture check - gpu_arch = torch.cuda.get_device_properties(0).name - if gpu_arch != "gfx942": - print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}") - print( - "Currently only gfx942 is supported. Skipping compilation of ROCm extensions" + # Collect CUDA source files if needed + if use_cuda: + if not IS_ROCM: + # Regular CUDA sources + extensions_cuda_dir = os.path.join(extensions_dir, "cuda") + cuda_sources = list( + glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True) + ) + sources += cuda_sources + else: + # ROCm sources + extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin") + hip_sources = list( + glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True) ) - return None - sources += hip_sources - if len(sources) == 0: + # Check ROCm GPU architecture compatibility + gpu_arch = torch.cuda.get_device_properties(0).name + if gpu_arch != "gfx942": + print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}") + print( + "Currently only gfx942 is supported. Skipping compilation of ROCm extensions" + ) + return None + sources += hip_sources + + # Return None if no sources found + if not sources: return None ext_modules = [] @@ -304,7 +306,6 @@ def get_extensions(): ) ) - if build_torchao_experimental: ext_modules.append( CMakeExtension(