Skip to content

Commit

Permalink
update copy from global to lds
Browse files Browse the repository at this point in the history
  • Loading branch information
lcskrishna authored and petrex committed Jan 6, 2025
1 parent f23b194 commit ecc3927
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 23 deletions.
62 changes: 40 additions & 22 deletions torchao/csrc/cuda/sparse_marlin/mem.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@
#include "base.h"

namespace torchao {

#ifdef USE_ROCM
#include <hip/hip_runtime.h>

// utility function for ROCm for equivalent for cvta_to_shared.
template<typename T>
__device__ __forceinline__ uint32_t cvta_to_shared(T* ptr) {
return (uint32_t)(uint64_t)(ptr);
}
#endif

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
Expand All @@ -28,14 +39,16 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
const int BYTES = 16;
int src_in_bytes = (zfill ? 0 : BYTES);
#ifdef USE_ROCM
uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
//uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
//asm volatile(
// "{\n"
// " .reg .pred p;\n"
// " setp.ne.b32 p, %0, 0;\n"
// " @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent
// "}\n" ::"r"((int)pred),
// "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
uint32_t smem = cvta_to_shared(smem_ptr);
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
#else
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
Expand All @@ -52,14 +65,16 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
#ifdef USE_ROCM
uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr));
//uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
//asm volatile(
// "{\n"
// " .reg .pred p;\n"
// " setp.ne.b32 p, %0, 0;\n"
// " @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent
// "}\n" ::"r"((int)pred),
// "r"(smem), "l"(glob_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
#else
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
Expand All @@ -76,12 +91,15 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
#ifdef USE_ROCM
uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
asm volatile(
"{\n"
" ds_read_b128 %0, %1 offset:0;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr));
//uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
//asm volatile(
// "{\n"
// " ds_read_b128 %0, %1 offset:0;\n"
// "}\n" ::"r"(smem),
// "l"(glob_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);

#else
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
Expand Down
5 changes: 4 additions & 1 deletion torchao/csrc/cuda/sparse_marlin/mma.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

#pragma once
#include "base.h"

#ifndef USE_ROCM
#include <cudaTypedefs.h>
#endif

namespace torchao {

Expand Down Expand Up @@ -259,4 +262,4 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
#endif
}

} // namespace torchao
} // namespace torchao

0 comments on commit ecc3927

Please sign in to comment.