Commit 00bc94d
Merge branch 'rocm_enablement_staging' into rocm_sparse_marlin
petrex authored Oct 29, 2024
2 parents 5c7d77b + 91d3c75 commit 00bc94d
Showing 2 changed files with 62 additions and 24 deletions.
28 changes: 10 additions & 18 deletions setup.py
@@ -108,7 +108,9 @@ def get_extensions():
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
     cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True))
 
 
+    extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
 
+    hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True))
 
     if not IS_ROCM and use_cuda:
@@ -119,24 +121,14 @@ def get_extensions():
         sources += hip_sources
 
     ## TODO: remove this condition and use what we have in CUDA once we fix the individual builds.
-    if not IS_ROCM:
-        ext_modules = [
-            extension(
-                "torchao._C",
-                sources,
-                extra_compile_args=extra_compile_args,
-                extra_link_args=extra_link_args,
-            )
-        ]
-    else:
-        ext_modules = [
-            extension(
-                "torchao._C",
-                sources,
-                extra_compile_args=extra_compile_args,
-                extra_link_args=extra_link_args,
-            )
-        ]
+    ext_modules = [
+        extension(
+            "torchao._C",
+            sources,
+            extra_compile_args=extra_compile_args,
+            extra_link_args=extra_link_args,
+        )
+    ]
 
     return ext_modules

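As a quick reference for the setup.py hunks above, here is a minimal, standalone Python sketch of the source-collection logic they touch. The extensions_dir value below is an assumed example path (the real setup.py derives it from the package root), and the counts it prints depend on the checkout.

# Minimal sketch of the source collection shown above. The extensions_dir
# value is an assumed example path, not taken from the diff.
import glob
import os

extensions_dir = os.path.join("torchao", "csrc")

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(
    glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

# On ROCm, only the sparse_marlin kernels are collected for hipification.
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
hip_sources = list(
    glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
)

print(f"{len(cuda_sources)} CUDA sources, {len(hip_sources)} HIP sources")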
58 changes: 52 additions & 6 deletions
@@ -1,4 +1,4 @@
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
+#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
 
 #include <ATen/ATen.h>
 #include <ATen/core/Tensor.h>
@@ -7,13 +7,24 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
 
+#if defined(USE_ROCM)
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_runtime.h>
+#endif
+
 template <typename U, typename V>
 constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
   const uint64_t blocks = a / b + (a % b != 0);
   return blocks;
 }
 
+#if defined(USE_ROCM)
+constexpr int32_t kWarpSize = 64;
+#else
+constexpr int32_t kWarpSize = 32;
+#endif
 
 //Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization
 //https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180
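The ROCm additions above matter because the AMD GPUs targeted by this port execute 64-lane wavefronts, while NVIDIA warps are 32 lanes wide, so any launch math built on kWarpSize changes accordingly. Below is a small Python rendering of the divUp helper and the two warp widths, purely as a worked arithmetic check; the 256-thread block size is an arbitrary example, not a value from this file.

# Python rendering of divUp and kWarpSize from the hunk above, used only to
# illustrate the arithmetic; 256 threads per block is an arbitrary example.
def div_up(a: int, b: int) -> int:
    # Ceiling division, matching the C++ divUp helper.
    return a // b + (1 if a % b != 0 else 0)

WARP_SIZE_ROCM = 64  # AMD wavefront width (kWarpSize under USE_ROCM)
WARP_SIZE_CUDA = 32  # NVIDIA warp width

threads_per_block = 256
print(div_up(threads_per_block, WARP_SIZE_ROCM))  # 4 wavefronts per block
print(div_up(threads_per_block, WARP_SIZE_CUDA))  # 8 warps per block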
@@ -30,38 +41,68 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
   uint32_t const source_i4s = source;
 
   // First, we extract the i4s and construct an intermediate fp16 number.
+#if !defined(USE_ROCM)
   static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+#endif
   static constexpr uint32_t MASK = 0x000f000f;
   static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
 
   // We don't have enough mantissa to remove as much shift overhead as FP16, so
   // we must loop. No shift needed for first item.
   uint32_t i4s = source_i4s;
 
+#if defined(USE_ROCM)
+  asm volatile("v_and_or_b32 %0, %1, %2, %3"
+               : "=v"(h[0])
+               : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
+#else
   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+#endif
 
 #pragma unroll
   for (int ii = 1; ii < kElements / 2; ++ii) {
     i4s >>= 4; // or is it 8?
     // (i4s & 0x000f000f) | 0x43004300
+#if defined(USE_ROCM)
+    asm volatile("v_and_or_b32 %0, %1, %2, %3"
+                 : "=v"(h[ii])
+                 : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
+#else
     asm volatile(
         "lop3.b32 %0, %1, %2, %3, %4;\n"
         : "=r"(h[ii])
        : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+#endif
   }
 
   // This is the BF16 {-136, -136} represented as an integer.
+#if defined(USE_ROCM)
+#if ROCM_VERSION >= 60200
+  auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
+  auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
+#else
+  auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308});
+  auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
+#endif
+#else
   static constexpr uint32_t BF16_BIAS = 0xC308C308;
   static constexpr uint32_t BF16_ONE = 0x3F803F80;
+#endif
 
   // Finally, we construct the output numbers.
 #pragma unroll
   for (int ii = 0; ii < kElements / 2; ++ii) {
     // Since this section is for Ampere+, we use bf16 fma to do the bias
     // subtraction
+#if defined(USE_ROCM)
+    result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
+#else
     asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
         : "=r"(h[ii])
         : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
+#endif
   }
 
   return result;
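For intuition about the constants in this hunk: OR-ing a 4-bit nibble x into the low mantissa bits of the bf16 pattern 0x4300 (which encodes 128.0) yields the bf16 value 128 + x, and the fma against BF16_ONE and BF16_BIAS (-136.0) then produces x - 8, the signed int4 value, in each 16-bit lane. The Python snippet below only re-checks that arithmetic with the struct module; it is an illustration, not part of the kernel.

# Check of the magic-number arithmetic used above: (x & 0xF) | 0x4300 is the
# bf16 encoding of 128 + x, and adding the -136.0 bias recovers x - 8.
import struct

def bf16_bits_to_float(bits: int) -> float:
    # A bf16 pattern is the top 16 bits of the corresponding float32.
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

BF16_BIAS = bf16_bits_to_float(0xC308)  # -136.0
for x in range(16):                     # every unsigned int4 nibble
    h = bf16_bits_to_float((x & 0xF) | 0x4300)  # 128 + x
    assert h + BF16_BIAS == x - 8
print("all 16 nibbles map to x - 8")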
@@ -123,11 +164,16 @@ __global__ void _dequantize_int4_kernel(
   // All b values within a 16x16 tile should fall within the same q group
   // Hence we load 1 scale and zero per loop
   int qgroup = ks[0] / groupSize;
-  const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);
-
-  // Vectorize scales and zeros
-  __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
-  __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
+  __nv_bfloat162 scale2 = {1.0f, 1.0f};
+  __nv_bfloat162 zero2 = {1.0f, 1.0f};
+
+  if (scales_and_zeros) {
+    const auto& sz = *scales_and_zeros;
+    const __nv_bfloat16* pSZ = reinterpret_cast<const __nv_bfloat16*>(&sz[qgroup][n0][0]);
+
+    scale2 = __bfloat162bfloat162(pSZ[0]);
+    zero2 = __bfloat162bfloat162(pSZ[1]);
+  }
 
 #pragma unroll
   for (int i = 0; i < 4; i++) {
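The final hunk makes scales_and_zeros optional: scale2 and zero2 now default to 1.0 and are overwritten only when the optional tensor is present. The Python analogue below sketches that guard under assumed names (load_scale_and_zero and the nested-sequence layout are illustrative, not taken from the kernel).

# Python analogue of the optional scales_and_zeros handling above. The
# function name and the nested-sequence layout are illustrative only.
from typing import Optional, Sequence

def load_scale_and_zero(
    scales_and_zeros: Optional[Sequence],  # indexed as [qgroup][n0][0..1]
    qgroup: int,
    n0: int,
) -> tuple:
    scale, zero = 1.0, 1.0  # defaults, as in the kernel
    if scales_and_zeros is not None:
        sz = scales_and_zeros[qgroup][n0]
        scale, zero = float(sz[0]), float(sz[1])
    return scale, zero

# Example: no quantization metadata supplied, so the defaults are returned.
print(load_scale_and_zero(None, qgroup=0, n0=0))  # (1.0, 1.0)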
