From ecc39275772cfd81046d86a7a749f43c48caa231 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Tue, 22 Oct 2024 13:30:45 +0000 Subject: [PATCH] update copy from global to lds --- torchao/csrc/cuda/sparse_marlin/mem.h | 62 +++++++++++++++++---------- torchao/csrc/cuda/sparse_marlin/mma.h | 5 ++- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 0a3f980f44..54f38fb358 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -19,6 +19,17 @@ #include "base.h" namespace torchao { + +#ifdef USE_ROCM +#include + +// utility function for ROCm for equivalent for cvta_to_shared. +template +__device__ __forceinline__ uint32_t cvta_to_shared(T* ptr) { + return (uint32_t)(uint64_t)(ptr); +} +#endif + // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. __device__ inline void cp_async4_pred_zfill(void* smem_ptr, @@ -28,14 +39,16 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, const int BYTES = 16; int src_in_bytes = (zfill ? 0 : BYTES); #ifdef USE_ROCM - uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); + //uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + //asm volatile( + // "{\n" + // " .reg .pred p;\n" + // " setp.ne.b32 p, %0, 0;\n" + // " @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent + // "}\n" ::"r"((int)pred), + // "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); + uint32_t smem = cvta_to_shared(smem_ptr); + __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( @@ -52,14 +65,16 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; #ifdef USE_ROCM - uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr)); + //uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + //asm volatile( + // "{\n" + // " .reg .pred p;\n" + // " setp.ne.b32 p, %0, 0;\n" + // " @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent + // "}\n" ::"r"((int)pred), + // "r"(smem), "l"(glob_ptr)); + uint32_t smem = cvta_to_shared(smem_ptr); + __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( @@ -76,12 +91,15 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; #ifdef USE_ROCM - uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); - asm volatile( - "{\n" - " ds_read_b128 %0, %1 offset:0;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr)); + //uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + //asm volatile( + // "{\n" + // " ds_read_b128 %0, %1 offset:0;\n" + // "}\n" ::"r"(smem), + // "l"(glob_ptr)); + uint32_t smem = cvta_to_shared(smem_ptr); + __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); + #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index b8da31870b..9e9a9be519 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -17,7 +17,10 @@ #pragma once #include "base.h" + +#ifndef USE_ROCM #include +#endif namespace torchao { @@ -259,4 +262,4 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, #endif } -} // namespace torchao \ No newline at end of file +} // namespace torchao