Commit
consolidate code with cvta_to_shared()
petrex committed Jan 6, 2025
1 parent a80730b commit d2c7ce4
Showing 1 changed file with 9 additions and 39 deletions.
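
Note: the cvta_to_shared() helper this commit consolidates around is not itself part of the diff below. A minimal sketch of what such a helper plausibly looks like — an assumption based on the two per-backend conversions it replaces, not the repository's actual definition:

#include <cstdint>

// Hypothetical sketch of cvta_to_shared(); the real definition lives outside
// this diff. It folds the per-backend generic->shared address conversion,
// previously duplicated in every function below, into one helper.
__device__ inline uint32_t cvta_to_shared(const void* ptr) {
#ifdef USE_ROCM
  // ROCm: LDS addresses fit in 32 bits, so a plain integer cast is assumed.
  return static_cast<uint32_t>(reinterpret_cast<uintptr_t>(ptr));
#else
  // CUDA: convert a generic pointer into the shared-memory address space.
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
#endif
}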
48 changes: 9 additions & 39 deletions torchao/csrc/cuda/sparse_marlin/mem.h
@@ -49,19 +49,10 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
                                             const bool zfill = false) {
   const int BYTES = 16;
   int src_in_bytes = (zfill ? 0 : BYTES);
-#ifdef USE_ROCM
-  //uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
-  //asm volatile(
-  // "{\n"
-  // " .reg .pred p;\n"
-  // " setp.ne.b32 p, %0, 0;\n"
-  // " @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent
-  // "}\n" ::"r"((int)pred),
-  // "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
+  uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "{\n"
       " .reg .pred p;\n"
@@ -75,19 +66,10 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
 __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                       bool pred = true) {
   const int BYTES = 16;
-#ifdef USE_ROCM
-  //uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
-  //asm volatile(
-  // "{\n"
-  // " .reg .pred p;\n"
-  // " setp.ne.b32 p, %0, 0;\n"
-  // " @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent
-  // "}\n" ::"r"((int)pred),
-  // "r"(smem), "l"(glob_ptr));
+  uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "{\n"
       " .reg .pred p;\n"
@@ -101,18 +83,10 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
 // Asynchronous global->shared copy
 __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
   const int BYTES = 16;
-#ifdef USE_ROCM
-  //uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
-  //asm volatile(
-  // "{\n"
-  // " ds_read_b128 %0, %1 offset:0;\n"
-  // "}\n" ::"r"(smem),
-  // "l"(glob_ptr));
+  uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
+
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "{\n"
       " cp.async.cg.shared.global [%0], [%1], %2;\n"
@@ -146,17 +120,15 @@ __device__ inline void cp_async_wait() {
 // Instruction for loading a full 16x16 matrix fragment of operand A from shared
 // memory, directly in tensor core layout.
 __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
-#ifdef USE_ROCM
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
-  uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
+  uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   asm volatile(
       "ds_read_b128 %0, %1 offset:0\n"
       "ds_read_b128 %2, %1 offset:16\n"
       : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
       : "v"(smem));
 #else
-  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
                : "r"(smem));
@@ -165,14 +137,13 @@ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {

 __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
+  uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
   asm volatile(
       "ds_read_b64 %0, %2 offset:0\n"
       : "=v"(a[0]), "=v"(a[1])
       : "v"(smem));
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
                : "=r"(a[0]), "=r"(a[1])
                : "r"(smem));
@@ -183,15 +154,14 @@ __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
 // memory, directly in tensor core layout.
 __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
   asm volatile(
-      "ds_read_b128 %0, %4 offset:0\n"
-      "ds_read_b128 %2, %4 offset:16\n"
+      "ds_read_b128 %0, %1 offset:0\n"
+      "ds_read_b128 %2, %1 offset:16\n"
       : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
       : "v"(smem));
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
       : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
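For orientation, a minimal usage sketch of the async-copy path touched above — illustrative only: the kernel name and tile size are invented, and marlin's real pipeline also commits and waits on async-copy groups, which is elided here:

// Illustrative sketch, not code from this repository.
__global__ void stage_tile(const float4* __restrict__ in, float4* out) {
  __shared__ float4 tile[128];  // assumes a 128-thread block
  int i = threadIdx.x;
  // Each thread enqueues one 16-byte global->shared copy; the helper now
  // derives the shared-memory address via cvta_to_shared() on both backends.
  cp_async4(&tile[i], &in[i]);
  // A commit/fence plus cp_async_wait() would normally precede consumption;
  // those helpers' exact signatures are outside this excerpt.
  __syncthreads();
  out[i] = tile[i];
}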
