perf: use f16 as split-k partial output data type #900

Merged Feb 27, 2025 (8 commits)
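This PR threads one change through the MLA decode kernels: split-k (split-KV) partial outputs are now stored in the 16-bit output dtype (DTypeO, i.e. f16/bf16) instead of f32, while the partial LSE values stay f32. Each partial row is a normalized softmax-weighted average whose scale is tracked separately by its LSE, so reduced precision on the partials costs little, and halving the element size halves both the split-k workspace and the global-memory traffic of the merge epilogue. A minimal standalone sketch of the merge this relies on, written in base 2 to match the kernels' ptx_log2 usage (an illustration, not flashinfer's actual merge routine):

```cuda
#include <cuda_fp16.h>

// Merge one f16 partial (part_o, part_lse) into a running f32 state.
// Both o vectors are normalized attention outputs; each lse is log2 of
// the softmax mass that produced its o.
__device__ void merge_partial(float* acc_o, float& acc_lse,
                              const __half* part_o, float part_lse,
                              int head_dim) {
  float m = fmaxf(acc_lse, part_lse);
  float w_acc = exp2f(acc_lse - m);   // relative weight of running state
  float w_new = exp2f(part_lse - m);  // relative weight of new partial
  float w_sum = w_acc + w_new;
  for (int i = 0; i < head_dim; ++i) {
    // upcast the f16 partial to f32 before doing any arithmetic
    acc_o[i] = (acc_o[i] * w_acc + __half2float(part_o[i]) * w_new) / w_sum;
  }
  acc_lse = m + log2f(w_sum);
}
```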
csrc/batch_mla_run.cu (2 changes: 1 addition & 1 deletion)
@@ -95,7 +95,7 @@ void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int
params.final_lse =
maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
params.partial_o =
GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset);
GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
params.partial_lse =
GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);

csrc/batch_mla_sm90_run.cu (2 changes: 1 addition & 1 deletion)
@@ -96,7 +96,7 @@ void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer,
params.final_lse =
maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
params.partial_o =
GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset);
GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
params.partial_lse =
GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);

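Both host entry points change the same way: the partial-output region is still carved out of the float workspace buffer by byte offset, but it is now viewed as DTypeO* rather than float*. GetPtrFromBaseOffset is a byte-offset pointer helper; a sketch of its likely shape (see flashinfer's headers for the real definition):

```cuda
// Sketch only: the template argument alone decides how the workspace
// bytes at base + offset are typed.
template <typename T>
__host__ __device__ T* GetPtrFromBaseOffset(void* base_ptr, int64_t byte_offset) {
  return reinterpret_cast<T*>(static_cast<char*>(base_ptr) + byte_offset);
}
```

With DTypeO = half or nv_bfloat16, the same offset now yields a 16-bit view of the partial-output region; the matching size change is in scheduler.cuh below.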
include/flashinfer/attention/mla.cuh (92 changes: 47 additions & 45 deletions)
@@ -620,9 +620,9 @@ __device__ void DevicePersistentMergeStates(
typename KTraits::IdType* merge_packed_offset_end,
typename KTraits::IdType* merge_partial_packed_offset_start,
typename KTraits::IdType* merge_partial_packed_offset_end,
typename KTraits::IdType* merge_partial_stride, float* partial_o, float* partial_lse,
typename KTraits::DTypeO* final_o, float* final_lse, const uint32_t o_stride_n,
const uint32_t o_stride_h, const uint_fastdiv& num_heads) {
typename KTraits::IdType* merge_partial_stride, typename KTraits::DTypeO* partial_o,
float* partial_lse, typename KTraits::DTypeO* final_o, float* final_lse,
const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv& num_heads) {
constexpr uint32_t VEC_SIZE = 4; // partial o has data type float
constexpr uint32_t NUM_THRS_PER_ROW = KTraits::HEAD_DIM_CKV / VEC_SIZE;
constexpr uint32_t ROWS_PER_ITERATION = (KTraits::NUM_THREADS) / NUM_THRS_PER_ROW;
@@ -645,8 +645,8 @@ __device__ void DevicePersistentMergeStates(
partial_packed_offset < partial_offset_end; partial_packed_offset += stride) {
vec_t<float, VEC_SIZE> o_partial;
float lse_partial;
o_partial.load(partial_o + partial_packed_offset * KTraits::HEAD_DIM_CKV +
(thread_id % NUM_THRS_PER_ROW) * VEC_SIZE);
o_partial.cast_load(partial_o + partial_packed_offset * KTraits::HEAD_DIM_CKV +
(thread_id % NUM_THRS_PER_ROW) * VEC_SIZE);
lse_partial = partial_lse[partial_packed_offset];
st.merge(o_partial, lse_partial, 1);
}
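DevicePersistentMergeStates keeps its accumulator in f32 (o_partial is a vec_t<float, 4>), so the only change on the read side is load → cast_load, which widens each 16-bit element while loading. A simplified stand-in for that behavior, assuming DTypeO = __half (flashinfer's real vec_t vectorizes this):

```cuda
#include <cuda_fp16.h>

// Load VEC_SIZE f16 elements and widen them to f32 lanes.
template <int VEC_SIZE>
__device__ void cast_load_f16_to_f32(float* dst, const __half* src) {
#pragma unroll
  for (int i = 0; i < VEC_SIZE; ++i) {
    dst[i] = __half2float(src[i]);
  }
}
```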
@@ -662,17 +662,38 @@
template <typename KTraits>
__device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_storage,
typename KTraits::DTypeO* final_o, float* final_lse,
float* partial_o, float* partial_lse, float (*o_frag)[8],
typename KTraits::DTypeQKAccum* m, float* d,
const uint32_t o_stride_n, const uint32_t o_stride_h,
const uint32_t q_len, const uint32_t packed_offset,
typename KTraits::DTypeO* partial_o, float* partial_lse,
float (*o_frag)[8], typename KTraits::DTypeQKAccum* m,
float* d, const uint32_t o_stride_n,
const uint32_t o_stride_h, const uint32_t q_len,
const uint32_t packed_offset,
const uint_fastdiv& num_heads) {
using DTypeO = typename KTraits::DTypeO;
constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
constexpr uint32_t HEAD_DIM_CKV = KTraits::HEAD_DIM_CKV;
constexpr uint32_t UPCAST_STRIDE_FINAL_O = KTraits::UPCAST_STRIDE_FINAL_O;
const uint32_t lane_idx = threadIdx.x, warpgroup_idx = threadIdx.z, warp_idx_in_wg = threadIdx.y;
smem_t<KTraits::SWIZZLE_MODE_O> o_smem(smem_storage->o_smem);
#pragma unroll
for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
uint32_t o_frag_f16[8 / 2];
vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_d]);
#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
warp_idx_in_wg * 16 + lane_idx % 16,
warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2 + lane_idx / 16);
o_smem.template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
#else
uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
warp_idx_in_wg * 16 + lane_idx / 4, warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2);
((uint32_t*)(o_smem.base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem.base + o_smem_offset_w + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
o_frag_f16[1];
((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
o_frag_f16[3];
#endif
}

if (partial_o != nullptr) {
// write to partial_o
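The register-to-shared staging ("step 0" in the removed comment further down) is hoisted above the partial/final branch here: the f32 accumulator fragments are converted to DTypeO and parked in shared memory once, whether the tile will end up in partial_o or final_o. The packing that vec_cast<DTypeO, float>::cast<8> plus stmatrix rely on looks roughly like this (a stand-in assuming DTypeO = __half):

```cuda
#include <cuda_fp16.h>

// Pack eight f32 lanes into four 32-bit registers of half2, the register
// layout that stmatrix.m8n8x4 (or the fallback 32-bit stores) expects.
__device__ void pack_f32x8_to_f16x8(const float in[8], uint32_t out[4]) {
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    __half2 h2 = __floats2half2_rn(in[2 * i], in[2 * i + 1]);
    out[i] = *reinterpret_cast<uint32_t*>(&h2);
  }
}
```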
@@ -685,24 +706,27 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
}
}

// step 1. smem to gmem
uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
warp_idx_in_wg * 16 + lane_idx / 8, warpgroup_idx * NUM_MMA_D_CKV + lane_idx % 8);
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
for (uint32_t j = 0; j < 4; ++j) {
uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) / num_heads;
DTypeO* o_partial_ptr =
partial_o +
((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
warpgroup_idx * (HEAD_DIM_CKV / 2) + (lane_idx % 8) * upcast_size<DTypeO>();
#pragma unroll
for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 8; ++mma_d) {
if (q_idx < q_len) {
*reinterpret_cast<float2*>(
partial_o +
((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
warpgroup_idx * (HEAD_DIM_CKV / 2) + mma_d * 16 + (lane_idx % 4) * 2) =
*reinterpret_cast<float2*>(&o_frag[mma_d][j * 2]);
*reinterpret_cast<float2*>(
partial_o +
((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
warpgroup_idx * (HEAD_DIM_CKV / 2) + mma_d * 16 + 8 + (lane_idx % 4) * 2) =
*reinterpret_cast<float2*>(&o_frag[mma_d][4 + j * 2]);
o_smem.template store_128b(o_smem_offset_w, o_partial_ptr);
}
o_partial_ptr += 8 * upcast_size<DTypeO>();
o_smem_offset_w = o_smem.template advance_offset_by_column<8>(o_smem_offset_w, mma_d);
}
o_smem_offset_w =
o_smem.template advance_offset_by_row<4, UPCAST_STRIDE_FINAL_O>(o_smem_offset_w) -
NUM_MMA_D_CKV;
}
} else {
// write to final_o
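This is where the f16 partials pay off. The old partial path issued two 8-byte float2 stores per fragment straight from registers; the new path reads the already-converted tile back from swizzled shared memory and writes it with 128-bit store_128b transactions, eight f16 elements at a time. The vectorization in isolation (generic CUDA, not flashinfer's smem_t API; both pointers must be 16-byte aligned):

```cuda
#include <cuda_fp16.h>

// One 16-byte store moves eight f16 elements; the same row in f32 would
// take twice the bytes and twice the transactions.
__device__ void store_8xf16(__half* __restrict__ dst,
                            const __half* __restrict__ src) {
  *reinterpret_cast<uint4*>(dst) = *reinterpret_cast<const uint4*>(src);
}
```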
@@ -718,28 +742,6 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
}
}

// step 0. rmem to smem
#pragma unroll
for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
uint32_t o_frag_f16[8 / 2];
vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_d]);
#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
warp_idx_in_wg * 16 + lane_idx % 16,
warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2 + lane_idx / 16);
o_smem.template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
#else
uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
warp_idx_in_wg * 16 + lane_idx / 4, warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2);
((uint32_t*)(o_smem.base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem.base + o_smem_offset_w + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
o_frag_f16[1];
((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1) +
8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] = o_frag_f16[3];
#endif
}

// step 1. smem to gmem
uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
warp_idx_in_wg * 16 + lane_idx / 8, warpgroup_idx * NUM_MMA_D_CKV + lane_idx % 8);
@@ -794,7 +796,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
DTypeKV* ckv = params.ckv;
DTypeKV* kpe = params.kpe;
IdType* kv_indices = params.kv_indices;
float* partial_o = params.partial_o;
DTypeO* partial_o = params.partial_o;
float* partial_lse = params.partial_lse;
DTypeO* final_o = params.final_o;
float* final_lse = params.final_lse;
include/flashinfer/attention/mla_hopper.cuh (99 changes: 52 additions & 47 deletions)
@@ -441,11 +441,14 @@ __device__ __forceinline__ void normalize_d_(typename KTraits::SharedStorage* sm
}

template <typename KTraits>
__device__ __forceinline__ void write_o(
typename KTraits::SharedStorage* smem_storage, const uint32_t stage_counter,
typename KTraits::DTypeO* final_o, float* final_lse, float* partial_o, float* partial_lse,
float(*o_frag), float* m, float* d, const uint32_t o_stride_n, const uint32_t o_stride_h,
const uint32_t q_len, const uint32_t packed_offset, const uint_fastdiv& num_heads) {
__device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_storage,
const uint32_t stage_counter,
typename KTraits::DTypeO* final_o, float* final_lse,
typename KTraits::DTypeO* partial_o, float* partial_lse,
float(*o_frag), float* m, float* d,
const uint32_t o_stride_n, const uint32_t o_stride_h,
const uint32_t q_len, const uint32_t packed_offset,
const uint_fastdiv& num_heads) {
using DTypeO = typename KTraits::DTypeO;
constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
constexpr uint32_t HEAD_DIM_CKV = KTraits::HEAD_DIM_CKV;
@@ -457,61 +460,52 @@
o_smem[0] = smem_storage->kv_o_smem[stage_counter % KTraits::NUM_STAGES].o;
o_smem[1] = smem_storage->kv_o_smem[(stage_counter + 1) % KTraits::NUM_STAGES].o;

// step 0. rmem to smem
#pragma unroll
for (uint32_t k = 0; k < HEAD_DIM_CKV / 32; ++k) {
uint32_t o_frag_f16[8 / 2];
vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, &o_frag[k * 8]);
uint32_t o_smem_offset_w = get_swizzle_offset<KTraits::SWIZZLE_MODE_O, UPCAST_STRIDE_FINAL_O>(
(warp_idx_in_wg % 2) * 16 + lane_idx % 16,
(warp_group_idx - 1) * NUM_MMA_D_CKV + k * 2 + lane_idx / 16);
o_smem[warp_idx_in_wg / 2].template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
}

if (partial_o != nullptr) {
// NOTE(Zihao): o_smem is not used if write to partial_o, and we can avoid the barrier
barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
// write to partial_o

#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
if (lane_idx % 4 == 0 && q_idx < q_len) {
partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
math::ptx_log2(d[j]) + float(m[j]);
for (uint32_t j = 0; j < 4; ++j) {
uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) / num_heads;
DTypeO* o_partial_ptr =
partial_o +
((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
(warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + (lane_idx % 8) * upcast_size<DTypeO>();
uint32_t o_smem_offset_w = get_swizzle_offset<KTraits::SWIZZLE_MODE_O, UPCAST_STRIDE_FINAL_O>(
(warp_idx_in_wg % 2) * 16 + 4 * j + lane_idx / 8,
(warp_group_idx - 1) * NUM_MMA_D_CKV + lane_idx % 8);
#pragma unroll
for (uint32_t k = 0; k < HEAD_DIM_CKV / 128; ++k) {
if (q_idx < q_len) {
o_smem[warp_idx_in_wg / 2].template store_128b(o_smem_offset_w, o_partial_ptr);
}
o_partial_ptr += 8 * upcast_size<DTypeO>();
o_smem_offset_w += 64;
}
}
barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);

#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
#pragma unroll
for (uint32_t k = 0; k < HEAD_DIM_CKV / 32; ++k) {
if (q_idx < q_len) {
*reinterpret_cast<float2*>(
partial_o +
((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
(warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + k * 16 + (lane_idx % 4) * 2) =
*reinterpret_cast<float2*>(&o_frag[k * 8 + j * 2]);
*reinterpret_cast<float2*>(
partial_o +
((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
(warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + k * 16 + 8 + (lane_idx % 4) * 2) =
*reinterpret_cast<float2*>(&o_frag[k * 8 + 4 + j * 2]);
}
if (lane_idx % 4 == 0 && q_idx < q_len) {
partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
math::ptx_log2(d[j]) + float(m[j]);
}
}
} else {
// write to final_o
if (final_lse) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
uint32_t q, r;
num_heads.divmod(packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4, q, r);
if (lane_idx % 4 == 0 && q < q_len) {
final_lse[q * num_heads + r] = math::ptx_log2(d[j]) + float(m[j]);
}
}
}

// step 0. rmem to smem
#pragma unroll
for (uint32_t k = 0; k < HEAD_DIM_CKV / 32; ++k) {
uint32_t o_frag_f16[8 / 2];
vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, &o_frag[k * 8]);
uint32_t o_smem_offset_w = get_swizzle_offset<KTraits::SWIZZLE_MODE_O, UPCAST_STRIDE_FINAL_O>(
(warp_idx_in_wg % 2) * 16 + lane_idx % 16,
(warp_group_idx - 1) * NUM_MMA_D_CKV + k * 2 + lane_idx / 16);
o_smem[warp_idx_in_wg / 2].template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
}

// step 1. smem to gmem
#pragma unroll
@@ -534,6 +528,17 @@
}
}
barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);

if (final_lse) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
uint32_t q, r;
num_heads.divmod(packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4, q, r);
if (lane_idx % 4 == 0 && q < q_len) {
final_lse[q * num_heads + r] = math::ptx_log2(d[j]) + float(m[j]);
}
}
}
}
}

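The Hopper variant gets the same restructuring, with a barrier-protocol consequence: staging into o_smem now happens before the partial/final branch, so the partial path consumes shared memory too (via store_128b), and the arrival at the kBarrierO named barrier moves from the top of the partial branch to after the last shared-memory read. The removed NOTE(Zihao) documented exactly the old assumption that the partial path never touched o_smem. The final_lse write in the else branch likewise moves to after the barrier; it only writes global memory, so its position relative to the smem barrier is flexible.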
@@ -586,7 +591,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
DTypeKV* ckv = params.ckv;
DTypeKV* kpe = params.kpe;
IdType* kv_indices = params.kv_indices;
float* partial_o = params.partial_o;
DTypeO* partial_o = params.partial_o;
float* partial_lse = params.partial_lse;
DTypeO* final_o = params.final_o;
float* final_lse = params.final_lse;
include/flashinfer/attention/mla_params.cuh (2 changes: 1 addition & 1 deletion)
@@ -32,7 +32,7 @@ struct MLAParams {
DTypeQ* q_pe;
DTypeKV* ckv;
DTypeKV* kpe;
float* partial_o;
DTypeO* partial_o;
float* partial_lse;
DTypeO* final_o;
float* final_lse;
include/flashinfer/attention/scheduler.cuh (19 changes: 17 additions & 2 deletions)
@@ -1104,7 +1104,21 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
total_kv_lens += effective_kv_len;
}
}
int kv_len_limit = ceil_div(std::max(ceil_div(total_kv_lens, num_clusters), 1L), 256L) * 256L;

auto f = [](int x) {
if (x <= 8) {
return 32;
} else if (x <= 16) {
return 64;
} else if (x <= 32) {
return 128;
} else if (x <= 64) {
return 192;
}
return ceil_div(x, 256) * 256;
};

int kv_len_limit = f(std::max(ceil_div(total_kv_lens, num_clusters), 1L));
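The scheduler change complements the cheaper partials: instead of rounding every per-cluster KV quota up to a multiple of 256, short workloads now get much smaller chunk limits (32/64/128/192), producing more splits, and therefore more split-k parallelism, exactly where the f16 partial buffers make the extra merge traffic affordable. For example (illustrative numbers), total_kv_lens = 3000 over num_clusters = 132 gives ceil_div(3000, 132) = 23, which now maps to a kv_len_limit of 128 rather than the old ceil_div(23, 256) * 256 = 256.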

// step 1. load-balancing scheduling algorithm
MinHeap cluster_cost_heap(num_clusters);
@@ -1295,9 +1309,10 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
cudaMemcpyHostToDevice, stream));

constexpr size_t sizeof_dtype_o = 2;
AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
plan_info.partial_o_offset = float_allocator.aligned_alloc_offset(
2 * num_clusters * cluster_tile_q * sizeof(float) * head_dim_o, 16, "mla_partial_o");
2 * num_clusters * cluster_tile_q * sizeof_dtype_o * head_dim_o, 16, "mla_partial_o");
plan_info.partial_lse_offset = float_allocator.aligned_alloc_offset(
2 * num_clusters * cluster_tile_q * sizeof(float), 16, "mla_partial_lse");
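With 2-byte partial elements, the partial_o allocation is half what the old sizeof(float) sizing requested for the same 2 * num_clusters * cluster_tile_q * head_dim_o element count, while partial_lse keeps its f32 sizing. Note that the hard-coded sizeof_dtype_o = 2 bakes in the assumption that DTypeO is a 16-bit type (f16/bf16), which matches the kernels above but would undersize the buffer for a hypothetical f32 output path.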
