ROCm Sparse Marlin Kernels #1206

Open

petrex wants to merge 32 commits into main from rocm_sparse_marlin

Changes from 27 commits (32 commits total)

Commits
6d92e40
enable build for rocm for fp6_llm
lcskrishna Oct 16, 2024
14b3fce
Merge pull request #1 from lcskrishna/cl/rocm-enablement
petrex Oct 17, 2024
f1a22cf
enable tiled layout extension
lcskrishna Oct 23, 2024
0bef6ca
fix build error related to option
petrex Oct 23, 2024
893ae03
require rocm 6.2
petrex Oct 23, 2024
a0d3788
enable tensor tiled layout extension with successful compilation
lcskrishna Oct 24, 2024
e4e654d
enable successful build
lcskrishna Oct 24, 2024
3e2c6a1
clean-up
petrex Oct 29, 2024
c86880e
Merge pull request #3 from lcskrishna/csrikris_enable_tensor_tile
petrex Oct 29, 2024
91d3c75
fix potential memory access issue
petrex Oct 29, 2024
38b7d1c
fix __nv_bfloat162 init
petrex Nov 12, 2024
279f4b3
add comment for MI300x isa
petrex Nov 12, 2024
612ad14
Merge branch 'main' into rocm_enablement_staging
petrex Nov 18, 2024
bbf5a72
fix build for non-rocm
lcskrishna Jan 6, 2025
735570e
Merge pull request #4 from lcskrishna/rocm_enablement
petrex Jan 6, 2025
253c188
Merge branch 'main' into rocm_enablement_staging
petrex Jan 6, 2025
a2f1736
add sparse_marlin kernel to the build
petrex Oct 17, 2024
f817edf
drop .h from conversion
petrex Oct 17, 2024
c9bc1bc
cp_asyc4_pred_zfill() AMD implementation
petrex Oct 17, 2024
16feff4
implement matching mem utility with amd GCN isa
petrex Oct 18, 2024
0b21555
implement mma util with amd gcn isa
petrex Oct 18, 2024
f23b194
enable rocm path
petrex Oct 18, 2024
ecc3927
update copy from global to lds
lcskrishna Oct 22, 2024
a80730b
implement cvta_to_shared()
petrex Oct 23, 2024
d2c7ce4
consolidate code with cvta_to_shared()
petrex Oct 23, 2024
15974c7
Merge branch 'main' into rocm_sparse_marlin
petrex Jan 8, 2025
a4e8c30
lint
petrex Jan 8, 2025
c678cb0
add GPU arch check for MI300x
petrex Jan 9, 2025
08d1cfb
revert change in tensor_core_tile_layout.cu
petrex Jan 9, 2025
b96196b
Merge branch 'main' into rocm_sparse_marlin
petrex Jan 15, 2025
aea9d81
lint
petrex Jan 15, 2025
f18043d
Merge branch 'main' into rocm_sparse_marlin
petrex Feb 25, 2025
setup.py: 29 changes (24 additions, 5 deletions)
@@ -53,11 +53,14 @@ def read_version(file_path="version.txt"):
from torch.utils.cpp_extension import (
CUDA_HOME,
IS_WINDOWS,
ROCM_HOME,
BuildExtension,
CppExtension,
CUDAExtension,
)

IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)


def get_extensions():
debug_mode = os.getenv("DEBUG", "0") == "1"
@@ -68,13 +71,17 @@ def get_extensions():
print(
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
)
if CUDA_HOME is None and torch.cuda.is_available():
print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
print(
"CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
)
print(
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
)

use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
use_cuda = torch.cuda.is_available() and (
CUDA_HOME is not None or ROCM_HOME is not None
)
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = []
@@ -93,7 +100,8 @@ def get_extensions():

if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
if "nvcc" in extra_compile_args:
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])
else:
extra_compile_args["cxx"] = ["/O2" if not debug_mode else "/Od", "/permissive-"]
@@ -126,9 +134,20 @@ def get_extensions():
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

if use_cuda:
extensions_hip_dir = os.path.join(
extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin"
)
hip_sources = list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
)

if not IS_ROCM and use_cuda:
sources += cuda_sources

# TODO: Remove this and use what CUDA has once we fix all the builds.
if IS_ROCM and use_cuda:
sources += hip_sources

if len(sources) == 0:
return None

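As a quick sanity check, here is a minimal standalone sketch (not part of the diff above) that mirrors the backend-detection logic this change adds to setup.py; it only reports which build path would be taken on the current machine and does not build anything.

```python
# Standalone sketch mirroring the IS_ROCM / use_cuda logic added to setup.py above.
import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
use_cuda = torch.cuda.is_available() and (
    CUDA_HOME is not None or ROCM_HOME is not None
)

if IS_ROCM and use_cuda:
    print("ROCm build: the sparse_marlin hip_sources would be added to the build")
elif use_cuda:
    print("CUDA build: the regular cuda_sources would be added to the build")
else:
    print("CPU-only build: no CUDA/ROCm extensions would be compiled")
```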
torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu: 2 changes (1 addition, 1 deletion)
@@ -52,7 +52,7 @@ static constexpr int min_thread_n = 128;
static constexpr int tile_size = 16;
static constexpr int max_par = 64;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && !defined(USE_ROCM)

template <const int num_bits, // weight bits
const int threads, // number of threads in a threadblock
torchao/csrc/cuda/sparse_marlin/mem.h: 102 changes (93 additions, 9 deletions)
@@ -19,6 +19,28 @@
#include "base.h"

namespace torchao {

#ifdef USE_ROCM
Contributor:
Should we gate on a specific ROCm version like we do for CUDA?

@petrex (Collaborator, Author), Jan 9, 2025:
Good point! What we need is a GPU architecture check rather than a ROCm version check. I have added a GPU architecture check in setup.py, so the kernel will now only be built for the MI300X architecture (see the sketch after this thread).

Contributor:
Sounds good. I think setup.py was recently updated by #1490, so you may have to pull in the new changes.
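For readers following the exchange above, here is a minimal, hypothetical sketch of the kind of MI300X-only gate described in the reply. It is not the code added by this PR; it assumes PyTorch's ROCm build exposes gcnArchName on the device properties and that MI300X reports a gfx942 architecture string.

```python
# Hypothetical sketch of an MI300X-only build gate for setup.py (not the PR's code).
# Assumptions: torch is a ROCm build that exposes gcnArchName on device properties,
# and MI300X identifies itself as "gfx942".
import torch


def rocm_build_targets_mi300x() -> bool:
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    try:
        arch = torch.cuda.get_device_properties(0).gcnArchName
    except AttributeError:
        # Older PyTorch/ROCm builds may not expose gcnArchName; skip the kernel build.
        return False
    # gcnArchName can look like "gfx942:sramecc+:xnack-"; keep only the target name.
    return arch.split(":")[0] == "gfx942"


# Example use alongside the setup.py logic shown earlier:
# if IS_ROCM and use_cuda and rocm_build_targets_mi300x():
#     sources += hip_sources
```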

#include <hip/hip_runtime.h>

// Convert generic pointer to shared memory address for ROCm
template<typename T>
__device__ __forceinline__ uint32_t cvta_to_shared(const T* ptr) {
// First get the address as a size_t to handle all pointer sizes
size_t addr = reinterpret_cast<size_t>(ptr);

// Extract the lower 32 bits which represent the shared memory offset
// This is safe because shared memory addresses are always within 32-bit range
return static_cast<uint32_t>(addr & 0xFFFFFFFF);
}
#else
// For CUDA, use the native intrinsic
template<typename T>
__device__ __forceinline__ uint32_t cvta_to_shared(const T* ptr) {
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}
#endif

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
@@ -27,91 +49,144 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
const bool zfill = false) {
const int BYTES = 16;
int src_in_bytes = (zfill ? 0 : BYTES);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
#else
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
#endif
}

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
#else
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
#endif
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
#else
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
#endif
}

// Async copy fence.
__device__ inline void cp_async_fence() {
#ifdef USE_ROCM
__builtin_amdgcn_s_waitcnt(0);
#else
asm volatile("cp.async.commit_group;\n" ::);
#endif
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
#ifdef USE_ROCM
// For AMD GPUs, we use s_waitcnt
// This waits for all outstanding memory operations to complete
__builtin_amdgcn_s_waitcnt(0);
#else
// For NVIDIA GPUs, use the original instruction
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
#endif
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
asm volatile(
"ds_read_b128 %0, %1 offset:0\n"
"ds_read_b128 %2, %1 offset:16\n"
: "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
: "v"(smem));
#else
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
#endif
}

__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
asm volatile(
"ds_read_b64 %0, %2 offset:0\n"
: "=v"(a[0]), "=v"(a[1])
: "v"(smem));
#else
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
: "=r"(a[0]), "=r"(a[1])
: "r"(smem));
#endif
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
asm volatile(
"ds_read_b128 %0, %1 offset:0\n"
"ds_read_b128 %2, %1 offset:16\n"
: "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
: "v"(smem));
#else
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
#endif
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
do {
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
#ifdef USE_ROCM
asm volatile("flat_load_dword %0, %1 glc\n\t"
"s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t"
: "=v"(state)
: "v"(lock));
#else
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
#endif
} while (state != count);
}
__syncthreads();
}
@@ -127,10 +202,19 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
#ifdef USE_ROCM
asm volatile("s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t"
"s_memrealtime\n\t"
"s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t"
"flat_atomic_add_i32 %0, %1\n\t"
: "+v"(*lock)
: "v"(val));
#else
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
#endif
}
}
} // namespace torchao
} // namespace torchao