Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
refactor for better readibility
  • Loading branch information
petrex committed Jan 15, 2025
1 parent b96196b commit aea9d81
Showing 1 changed file with 29 additions and 28 deletions.
57 changes: 29 additions & 28 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def use_debug_mode():
CUDAExtension,
)


IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)

# Constant known variables used throughout this file
Expand Down Expand Up @@ -258,38 +257,41 @@ def get_extensions():
]
)

# Get base directory and source paths
this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

extensions_hip_dir = os.path.join(
extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin"
)
hip_sources = list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
)
# Collect C++ source files
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

if not IS_ROCM and use_cuda:
sources += cuda_sources

# TOOD: Remove this and use what CUDA has once we fix all the builds.
if IS_ROCM and use_cuda:
# Add ROCm GPU architecture check
gpu_arch = torch.cuda.get_device_properties(0).name
if gpu_arch != "gfx942":
print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
print(
"Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
# Collect CUDA source files if needed
if use_cuda:
if not IS_ROCM:
# Regular CUDA sources
extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)
sources += cuda_sources
else:
# ROCm sources
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
hip_sources = list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
)
return None
sources += hip_sources

if len(sources) == 0:
# Check ROCm GPU architecture compatibility
gpu_arch = torch.cuda.get_device_properties(0).name
if gpu_arch != "gfx942":
print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
print(
"Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
)
return None
sources += hip_sources

# Return None if no sources found
if not sources:
return None

ext_modules = []
Expand All @@ -304,7 +306,6 @@ def get_extensions():
)
)


if build_torchao_experimental:
ext_modules.append(
CMakeExtension(
Expand Down

0 comments on commit aea9d81

Please sign in to comment.