Skip to content

Commit

Permalink
implement cvta_to_shared()
Browse files Browse the repository at this point in the history
  • Loading branch information
petrex committed Jan 6, 2025
1 parent ecc3927 commit a80730b
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions torchao/csrc/cuda/sparse_marlin/mem.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,21 @@ namespace torchao {
#ifdef USE_ROCM
#include <hip/hip_runtime.h>

// Utility function for ROCm: the equivalent of cvta_to_shared.
// Convert generic pointer to shared memory address for ROCm
// cvta_to_shared (ROCm path): produce a 32-bit value from a pointer's address.
// NOTE(review): two signatures appear back to back below (T* and const T*) with
// a single closing brace; this looks like the removed and added lines of a diff
// rendered together -- confirm which version is live before compiling.
template<typename T>
__device__ __forceinline__ uint32_t cvta_to_shared(T* ptr) {
// Earlier form: narrow the pointer through a 64-bit integer down to 32 bits.
return (uint32_t)(uint64_t)(ptr);
__device__ __forceinline__ uint32_t cvta_to_shared(const T* ptr) {
// Reinterpret the pointer as a size_t so the address is held in an integer.
size_t addr = reinterpret_cast<size_t>(ptr);

// Keep only the lower 32 bits of the address and return them as uint32_t.
// NOTE(review): this assumes the low 32 bits of a generic pointer are the
// shared-memory offset on the target -- TODO confirm against the AMDGPU
// address-space documentation.
return static_cast<uint32_t>(addr & 0xFFFFFFFF);
}
#else
// CUDA path: delegate the generic-to-shared conversion to the built-in
// __cvta_generic_to_shared intrinsic and hand back a 32-bit result.
template <typename T>
__device__ __forceinline__ uint32_t cvta_to_shared(const T* ptr) {
  // Hold the intrinsic's result first, then narrow it to 32 bits explicitly.
  const auto shared_addr = __cvta_generic_to_shared(ptr);
  return static_cast<uint32_t>(shared_addr);
}
#endif

Expand Down

0 comments on commit a80730b

Please sign in to comment.