From 7273179699db50222d5dac0024457374afd82514 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Wed, 26 Feb 2025 04:39:52 +0000
Subject: [PATCH 1/8] upd

---
 benchmarks/bench_deepseek_mla.py            |   8 +-
 csrc/batch_mla_run.cu                       |   2 +-
 csrc/batch_mla_sm90_run.cu                  |   2 +-
 include/flashinfer/attention/mla.cuh        | 100 +++++++++++---------
 include/flashinfer/attention/mla_hopper.cuh |  97 ++++++++++---------
 include/flashinfer/attention/mla_params.cuh |   2 +-
 include/flashinfer/attention/scheduler.cuh  |   5 +-
 tests/test_deepseek_mla.py                  |   1 +
 8 files changed, 115 insertions(+), 102 deletions(-)

diff --git a/benchmarks/bench_deepseek_mla.py b/benchmarks/bench_deepseek_mla.py
index 9fa14a536..bc6ce40b5 100644
--- a/benchmarks/bench_deepseek_mla.py
+++ b/benchmarks/bench_deepseek_mla.py
@@ -69,11 +69,13 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads, backend):
 
     io = sum([_.numel() * _.element_size() for _ in [q_nope, q_pe, ckv, kpe, o]])
 
+    print(ms)
     print(f"Config: batch_size={batch_size}, seq_len={seq_len}, num_heads={num_heads}")
     print(f"Memory bandwidth: {io * 1e-6 / ms:.2f} GB/s")
 
 
 if __name__ == "__main__":
-    for seq_len in [1024, 2048]:
-        for batch_size in [64, 128, 768]:
-            bench_deepseek_mla_decode(batch_size, seq_len, 64, "auto")
+    for seq_len in [16384]:
+        for batch_size in [64]:  # [64, 128, 768]:
+            for num_heads in [128]:
+                bench_deepseek_mla_decode(batch_size, seq_len, num_heads, "auto")

diff --git a/csrc/batch_mla_run.cu b/csrc/batch_mla_run.cu
index 9e624ac83..af44cb34a 100644
--- a/csrc/batch_mla_run.cu
+++ b/csrc/batch_mla_run.cu
@@ -95,7 +95,7 @@ void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int
   params.final_lse =
       maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
   params.partial_o =
-      GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset);
+      GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
   params.partial_lse =
       GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);
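For context on the hunk above: GetPtrFromBaseOffset hands out a typed view of a byte offset inside the shared float workspace, so changing its template argument from float to DTypeO re-types the partial-o region in place, without moving the allocation. A minimal sketch of a helper with those assumed semantics (name hypothetical, not the flashinfer source):

#include <cstdint>

// Reinterpret base + byte_offset as a T*. The planner hands out byte offsets
// (see aligned_alloc_offset later in this series); the runner picks the type.
template <typename T>
T* GetPtrFromBaseOffsetSketch(void* base, int64_t byte_offset) {
  return reinterpret_cast<T*>(static_cast<char*>(base) + byte_offset);
}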
diff --git a/csrc/batch_mla_sm90_run.cu b/csrc/batch_mla_sm90_run.cu
index 44fb19070..200d8bf6c 100644
--- a/csrc/batch_mla_sm90_run.cu
+++ b/csrc/batch_mla_sm90_run.cu
@@ -96,7 +96,7 @@ void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer,
   params.final_lse =
       maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
   params.partial_o =
-      GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset);
+      GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
   params.partial_lse =
       GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);

diff --git a/include/flashinfer/attention/mla.cuh b/include/flashinfer/attention/mla.cuh
index 171bdef7b..deb54d110 100644
--- a/include/flashinfer/attention/mla.cuh
+++ b/include/flashinfer/attention/mla.cuh
@@ -41,6 +41,7 @@ struct StandardAttention : AttentionVariantBase {
 
 template
 struct SharedStorageQKVO {
+  static constexpr uint32_t MERGE_SMEM_STAGES = 8;
   union {
     struct {
       alignas(16) DTypeQ q_smem_nope[CTA_TILE_Q * HEAD_DIM_CKV];
@@ -55,6 +56,10 @@ struct SharedStorageQKVO {
       };
     };
     alignas(16) DTypeO o_smem[CTA_TILE_Q * HEAD_DIM_CKV];
+    struct {
+      DTypeO merge_o_smem[MERGE_SMEM_STAGES][HEAD_DIM_CKV];
+      float merge_lse_smem[];
+    };
   };
 };
 
@@ -620,9 +625,9 @@ __device__ void DevicePersistentMergeStates(
     typename KTraits::IdType* merge_packed_offset_end,
     typename KTraits::IdType* merge_partial_packed_offset_start,
     typename KTraits::IdType* merge_partial_packed_offset_end,
-    typename KTraits::IdType* merge_partial_stride, float* partial_o, float* partial_lse,
-    typename KTraits::DTypeO* final_o, float* final_lse, const uint32_t o_stride_n,
-    const uint32_t o_stride_h, const uint_fastdiv& num_heads) {
+    typename KTraits::IdType* merge_partial_stride, typename KTraits::DTypeO* partial_o,
+    float* partial_lse, typename KTraits::DTypeO* final_o, float* final_lse,
+    const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv& num_heads) {
   constexpr uint32_t VEC_SIZE = 4;  // partial o has data type float
   constexpr uint32_t NUM_THRS_PER_ROW = KTraits::HEAD_DIM_CKV / VEC_SIZE;
   constexpr uint32_t ROWS_PER_ITERATION = (KTraits::NUM_THREADS) / NUM_THRS_PER_ROW;
@@ -640,17 +645,17 @@ __device__ void DevicePersistentMergeStates(
       uint32_t q, r;
       num_heads.divmod(final_packed_offset, q, r);
       state_t<VEC_SIZE> st;
-#pragma unroll 8
+#pragma unroll 4
       for (uint32_t partial_packed_offset = partial_offset_start + local_packed_offset;
            partial_packed_offset < partial_offset_end; partial_packed_offset += stride) {
         vec_t<float, VEC_SIZE> o_partial;
         float lse_partial;
-        o_partial.load(partial_o + partial_packed_offset * KTraits::HEAD_DIM_CKV +
-                       (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE);
+        o_partial.cast_load(partial_o + partial_packed_offset * KTraits::HEAD_DIM_CKV +
+                            (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE);
         lse_partial = partial_lse[partial_packed_offset];
         st.merge(o_partial, lse_partial, 1);
       }
-      st.normalize();
+      // st.normalize();
       st.o.cast_store(final_o + (q * o_stride_n + r * o_stride_h +
                                  (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE));
       if (final_lse) {
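The merge loop above is the standard log-sum-exp combine of per-chunk partial attention outputs; with this patch the partials arrive in half precision and cast_load upcasts them into float accumulators. A self-contained host-side sketch of the math, kept in base 2 to match math::ptx_log2 (illustrative, not the flashinfer state_t API):

#include <cmath>
#include <vector>

// Combine row-wise partial outputs o[i] (each normalized within its own KV
// chunk) using their log2-sum-exp weights lse[i].
void merge_partials(const std::vector<std::vector<float>>& o,  // [num_chunks][head_dim]
                    const std::vector<float>& lse,             // [num_chunks]
                    std::vector<float>& o_out, float& lse_out) {
  const size_t d = o[0].size();
  float m = -INFINITY, denom = 0.0f;
  o_out.assign(d, 0.0f);
  for (float l : lse) m = std::max(m, l);
  for (size_t i = 0; i < o.size(); ++i) {
    float w = std::exp2(lse[i] - m);  // rescale each chunk against the max
    denom += w;
    for (size_t j = 0; j < d; ++j) o_out[j] += w * o[i][j];
  }
  for (size_t j = 0; j < d; ++j) o_out[j] /= denom;  // the st.normalize() step
  lse_out = m + std::log2(denom);
}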
@@ -662,10 +667,11 @@
 template <typename KTraits>
 __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_storage,
                                         typename KTraits::DTypeO* final_o, float* final_lse,
-                                        float* partial_o, float* partial_lse, float (*o_frag)[8],
-                                        typename KTraits::DTypeQKAccum* m, float* d,
-                                        const uint32_t o_stride_n, const uint32_t o_stride_h,
-                                        const uint32_t q_len, const uint32_t packed_offset,
+                                        typename KTraits::DTypeO* partial_o, float* partial_lse,
+                                        float (*o_frag)[8], typename KTraits::DTypeQKAccum* m,
+                                        float* d, const uint32_t o_stride_n,
+                                        const uint32_t o_stride_h, const uint32_t q_len,
+                                        const uint32_t packed_offset,
                                         const uint_fastdiv& num_heads) {
   using DTypeO = typename KTraits::DTypeO;
   constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
@@ -673,6 +679,26 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
   constexpr uint32_t UPCAST_STRIDE_FINAL_O = KTraits::UPCAST_STRIDE_FINAL_O;
   const uint32_t lane_idx = threadIdx.x, warpgroup_idx = threadIdx.z, warp_idx_in_wg = threadIdx.y;
   smem_t o_smem(smem_storage->o_smem);
+#pragma unroll
+  for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
+    uint32_t o_frag_f16[8 / 2];
+    vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_d]);
+#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
+    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
+        warp_idx_in_wg * 16 + lane_idx % 16,
+        warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2 + lane_idx / 16);
+    o_smem.template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
+#else
+    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
+        warp_idx_in_wg * 16 + lane_idx / 4, warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2);
+    ((uint32_t*)(o_smem.base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
+    ((uint32_t*)(o_smem.base + o_smem_offset_w + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
+        o_frag_f16[1];
+    ((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
+    ((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
+        o_frag_f16[3];
+#endif
+  }
 
   if (partial_o != nullptr) {
     // write to partial_o
@@ -685,24 +711,26 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
       }
     }
 
+    // step 1. smem to gmem
+    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
+        warp_idx_in_wg * 16 + lane_idx / 8, warpgroup_idx * NUM_MMA_D_CKV + lane_idx % 8);
 #pragma unroll
-    for (uint32_t j = 0; j < 2; ++j) {
-      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
+    for (uint32_t j = 0; j < 4; ++j) {
+      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) / num_heads;
+      DTypeO* o_partial_ptr =
+          partial_o + (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
+          warpgroup_idx * (HEAD_DIM_CKV / 2) + (lane_idx % 8) * upcast_size<DTypeO>();
 #pragma unroll
-      for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
+      for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 8; ++mma_d) {
         if (q_idx < q_len) {
-          *reinterpret_cast<float2*>(
-              partial_o +
-              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
-              warpgroup_idx * (HEAD_DIM_CKV / 2) + mma_d * 16 + (lane_idx % 4) * 2) =
-              *reinterpret_cast<float2*>(&o_frag[mma_d][j * 2]);
-          *reinterpret_cast<float2*>(
-              partial_o +
-              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
-              warpgroup_idx * (HEAD_DIM_CKV / 2) + mma_d * 16 + 8 + (lane_idx % 4) * 2) =
-              *reinterpret_cast<float2*>(&o_frag[mma_d][4 + j * 2]);
+          o_smem.template store_128b(o_smem_offset_w, o_partial_ptr);
         }
+        o_partial_ptr += 8 * upcast_size<DTypeO>();
+        o_smem_offset_w = o_smem.template advance_offset_by_column<8>(o_smem_offset_w, mma_d);
       }
+      o_smem_offset_w =
+          o_smem.template advance_offset_by_row<4, UPCAST_STRIDE_FINAL_O>(o_smem_offset_w) -
+          NUM_MMA_D_CKV;
     }
   } else {
     // write to final_o
@@ -718,28 +746,6 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
       }
     }
 
-  // step 0. rmem to smem
-#pragma unroll
-  for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
-    uint32_t o_frag_f16[8 / 2];
-    vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_d]);
-#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
-    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
-        warp_idx_in_wg * 16 + lane_idx % 16,
-        warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2 + lane_idx / 16);
-    o_smem.template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
-#else
-    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
-        warp_idx_in_wg * 16 + lane_idx / 4, warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2);
-    ((uint32_t*)(o_smem.base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
-    ((uint32_t*)(o_smem.base + o_smem_offset_w + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
-        o_frag_f16[1];
-    ((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
-    ((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1) +
-                 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] = o_frag_f16[3];
-#endif
-  }
-
   // step 1. smem to gmem
   uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
       warp_idx_in_wg * 16 + lane_idx / 8, warpgroup_idx * NUM_MMA_D_CKV + lane_idx % 8);
@@ -794,7 +800,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
   DTypeKV* ckv = params.ckv;
   DTypeKV* kpe = params.kpe;
   IdType* kv_indices = params.kv_indices;
-  float* partial_o = params.partial_o;
+  DTypeO* partial_o = params.partial_o;
   float* partial_lse = params.partial_lse;
   DTypeO* final_o = params.final_o;
   float* final_lse = params.final_lse;

diff --git a/include/flashinfer/attention/mla_hopper.cuh b/include/flashinfer/attention/mla_hopper.cuh
index 79d65db8a..8d21d28bc 100644
--- a/include/flashinfer/attention/mla_hopper.cuh
+++ b/include/flashinfer/attention/mla_hopper.cuh
@@ -441,11 +441,14 @@ __device__ __forceinline__ void normalize_d_(typename KTraits::SharedStorage* sm
 
 template <typename KTraits>
-__device__ __forceinline__ void write_o(
-    typename KTraits::SharedStorage* smem_storage, const uint32_t stage_counter,
-    typename KTraits::DTypeO* final_o, float* final_lse, float* partial_o, float* partial_lse,
-    float(*o_frag), float* m, float* d, const uint32_t o_stride_n, const uint32_t o_stride_h,
-    const uint32_t q_len, const uint32_t packed_offset, const uint_fastdiv& num_heads) {
+__device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_storage,
+                                        const uint32_t stage_counter,
+                                        typename KTraits::DTypeO* final_o, float* final_lse,
+                                        typename KTraits::DTypeO* partial_o, float* partial_lse,
+                                        float(*o_frag), float* m, float* d,
+                                        const uint32_t o_stride_n, const uint32_t o_stride_h,
+                                        const uint32_t q_len, const uint32_t packed_offset,
+                                        const uint_fastdiv& num_heads) {
   using DTypeO = typename KTraits::DTypeO;
   constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
   constexpr uint32_t HEAD_DIM_CKV = KTraits::HEAD_DIM_CKV;
@@ -457,61 +460,51 @@ __device__ __forceinline__ void write_o(
   o_smem[0] = smem_storage->kv_o_smem[stage_counter % KTraits::NUM_STAGES].o;
   o_smem[1] = smem_storage->kv_o_smem[(stage_counter + 1) % KTraits::NUM_STAGES].o;
 
+  // step 0. rmem to smem
+#pragma unroll
+  for (uint32_t k = 0; k < HEAD_DIM_CKV / 32; ++k) {
+    uint32_t o_frag_f16[8 / 2];
+    vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, &o_frag[k * 8]);
+    uint32_t o_smem_offset_w = get_swizzle_offset(
+        (warp_idx_in_wg % 2) * 16 + lane_idx % 16,
+        (warp_group_idx - 1) * NUM_MMA_D_CKV + k * 2 + lane_idx / 16);
+    o_smem[warp_idx_in_wg / 2].template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
+  }
+
   if (partial_o != nullptr) {
     // NOTE(Zihao): o_smem is not used if write to partial_o, and we can avoid the barrier
-    barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
     // write to partial_o
+
 #pragma unroll
-    for (uint32_t j = 0; j < 2; ++j) {
-      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
-      if (lane_idx % 4 == 0 && q_idx < q_len) {
-        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
-            math::ptx_log2(d[j]) + float(m[j]);
+    for (uint32_t j = 0; j < 4; ++j) {
+      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) / num_heads;
+      DTypeO* o_final_ptr =
+          final_o + (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
+          (warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + (lane_idx % 8) * upcast_size<DTypeO>();
+      uint32_t o_smem_offset_w = get_swizzle_offset(
+          (warp_idx_in_wg % 2) * 16 + 4 * j + lane_idx / 8,
+          (warp_group_idx - 1) * NUM_MMA_D_CKV + lane_idx % 8);
+#pragma unroll
+      for (uint32_t k = 0; k < HEAD_DIM_CKV / 128; ++k) {
+        if (q_idx < q_len) {
+          o_smem[warp_idx_in_wg / 2].template store_128b(o_smem_offset_w, o_final_ptr);
+        }
+        o_final_ptr += 8 * upcast_size<DTypeO>();
+        o_smem_offset_w += 64;
       }
     }
+    barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
 
 #pragma unroll
     for (uint32_t j = 0; j < 2; ++j) {
       uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
-#pragma unroll
-      for (uint32_t k = 0; k < HEAD_DIM_CKV / 32; ++k) {
-        if (q_idx < q_len) {
-          *reinterpret_cast<float2*>(
-              partial_o +
-              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
-              (warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + k * 16 + (lane_idx % 4) * 2) =
-              *reinterpret_cast<float2*>(&o_frag[k * 8 + j * 2]);
-          *reinterpret_cast<float2*>(
-              partial_o +
-              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
-              (warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + k * 16 + 8 + (lane_idx % 4) * 2) =
-              *reinterpret_cast<float2*>(&o_frag[k * 8 + 4 + j * 2]);
-        }
+      if (lane_idx % 4 == 0 && q_idx < q_len) {
+        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
+            math::ptx_log2(d[j]) + float(m[j]);
       }
     }
   } else {
     // write to final_o
-    if (final_lse) {
-#pragma unroll
-      for (uint32_t j = 0; j < 2; ++j) {
-        uint32_t q, r;
-        num_heads.divmod(packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4, q, r);
-        if (lane_idx % 4 == 0 && q < q_len) {
-          final_lse[q * num_heads + r] = math::ptx_log2(d[j]) + float(m[j]);
-        }
-      }
-    }
-
-    // step 0. rmem to smem
-#pragma unroll
-    for (uint32_t k = 0; k < HEAD_DIM_CKV / 32; ++k) {
-      uint32_t o_frag_f16[8 / 2];
-      vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, &o_frag[k * 8]);
-      uint32_t o_smem_offset_w = get_swizzle_offset(
-          (warp_idx_in_wg % 2) * 16 + lane_idx % 16,
-          (warp_group_idx - 1) * NUM_MMA_D_CKV + k * 2 + lane_idx / 16);
-      o_smem[warp_idx_in_wg / 2].template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
-    }
 
     // step 1. smem to gmem
 #pragma unroll
@@ -534,6 +527,16 @@ __device__ __forceinline__ void write_o(
       }
     }
     barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
+    if (final_lse) {
+#pragma unroll
+      for (uint32_t j = 0; j < 2; ++j) {
+        uint32_t q, r;
+        num_heads.divmod(packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4, q, r);
+        if (lane_idx % 4 == 0 && q < q_len) {
+          final_lse[q * num_heads + r] = math::ptx_log2(d[j]) + float(m[j]);
+        }
+      }
+    }
   }
 }
 
@@ -586,7 +589,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
   DTypeKV* ckv = params.ckv;
   DTypeKV* kpe = params.kpe;
   IdType* kv_indices = params.kv_indices;
-  float* partial_o = params.partial_o;
+  DTypeO* partial_o = params.partial_o;
   float* partial_lse = params.partial_lse;
   DTypeO* final_o = params.final_o;
   float* final_lse = params.final_lse;

diff --git a/include/flashinfer/attention/mla_params.cuh b/include/flashinfer/attention/mla_params.cuh
index 9b63e3401..cad1acff7 100644
--- a/include/flashinfer/attention/mla_params.cuh
+++ b/include/flashinfer/attention/mla_params.cuh
@@ -32,7 +32,7 @@ struct MLAParams {
   DTypeQ* q_pe;
   DTypeKV* ckv;
   DTypeKV* kpe;
-  float* partial_o;
+  DTypeO* partial_o;
   float* partial_lse;
   DTypeO* final_o;
   float* final_lse;

diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh
index 30046f411..eaa6ea49d 100644
--- a/include/flashinfer/attention/scheduler.cuh
+++ b/include/flashinfer/attention/scheduler.cuh
@@ -1104,7 +1104,7 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
       total_kv_lens += effective_kv_len;
     }
   }
-  int kv_len_limit = ceil_div(std::max(ceil_div(total_kv_lens, num_clusters), 1L), 256L) * 256L;
+  int kv_len_limit = ceil_div(std::max(ceil_div(total_kv_lens, num_clusters), 1L), 512L) * 512L;
 
   // step 1. load-balancing scheduling algorithm
   MinHeap cluster_cost_heap(num_clusters);
@@ -1295,9 +1295,10 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
   FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
                                        cudaMemcpyHostToDevice, stream));
 
+  constexpr size_t sizeof_dtype_o = 2;
   AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
   plan_info.partial_o_offset = float_allocator.aligned_alloc_offset(
-      2 * num_clusters * cluster_tile_q * sizeof(float) * head_dim_o, 16, "mla_partial_o");
+      2 * num_clusters * cluster_tile_q * sizeof_dtype_o * head_dim_o, 16, "mla_partial_o");
   plan_info.partial_lse_offset = float_allocator.aligned_alloc_offset(
       2 * num_clusters * cluster_tile_q * sizeof(float), 16, "mla_partial_lse");
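A quick sizing check on the hunk above (hypothetical shape, not from the patch): note that sizeof_dtype_o = 2 hard-codes a half-precision (fp16/bf16) output type, halving the partial-o workspace relative to float:

#include <cstdio>

int main() {
  long num_clusters = 132, cluster_tile_q = 64, head_dim_o = 512;
  long elems = 2 * num_clusters * cluster_tile_q * head_dim_o;  // leading 2 mirrors MLAPlan
  std::printf("partial_o, 2-byte DTypeO: %ld bytes\n", elems * 2);  // 17301504 (~16.5 MiB)
  std::printf("partial_o, 4-byte float : %ld bytes\n", elems * 4);  // 34603008 (~33 MiB)
  return 0;
}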
diff --git a/tests/test_deepseek_mla.py b/tests/test_deepseek_mla.py
index e93aa1798..d9d491a5c 100644
--- a/tests/test_deepseek_mla.py
+++ b/tests/test_deepseek_mla.py
@@ -435,6 +435,7 @@ def test_batch_mla_page_attention(
     q = torch.cat([q_nope, q_pe], dim=-1)
     o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
     lse_ref = lse_ref.flatten(0, 1)
+    print(o, o_ref)
     torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
     if kv_len != 0:
         torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)

From 15e8eb229691db76b53941595d159f79c0e125e4 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 27 Feb 2025 01:44:30 +0000
Subject: [PATCH 2/8] upd

---
 include/flashinfer/attention/mla.cuh        |  9 ++-------
 include/flashinfer/attention/mla_hopper.cuh | 14 ++++++++------
 include/flashinfer/attention/scheduler.cuh  |  2 +-
 3 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/include/flashinfer/attention/mla.cuh b/include/flashinfer/attention/mla.cuh
index deb54d110..6a63ef414 100644
--- a/include/flashinfer/attention/mla.cuh
+++ b/include/flashinfer/attention/mla.cuh
@@ -41,7 +41,6 @@ struct StandardAttention : AttentionVariantBase {
 
 template
 struct SharedStorageQKVO {
-  static constexpr uint32_t MERGE_SMEM_STAGES = 8;
   union {
     struct {
       alignas(16) DTypeQ q_smem_nope[CTA_TILE_Q * HEAD_DIM_CKV];
@@ -56,10 +55,6 @@ struct SharedStorageQKVO {
       };
     };
     alignas(16) DTypeO o_smem[CTA_TILE_Q * HEAD_DIM_CKV];
-    struct {
-      DTypeO merge_o_smem[MERGE_SMEM_STAGES][HEAD_DIM_CKV];
-      float merge_lse_smem[];
-    };
   };
 };
 
@@ -645,7 +640,7 @@ __device__ void DevicePersistentMergeStates(
       uint32_t q, r;
      num_heads.divmod(final_packed_offset, q, r);
       state_t<VEC_SIZE> st;
-#pragma unroll 4
+#pragma unroll 8
       for (uint32_t partial_packed_offset = partial_offset_start + local_packed_offset;
            partial_packed_offset < partial_offset_end; partial_packed_offset += stride) {
         vec_t<float, VEC_SIZE> o_partial;
@@ -655,7 +650,7 @@ __device__ void DevicePersistentMergeStates(
         lse_partial = partial_lse[partial_packed_offset];
         st.merge(o_partial, lse_partial, 1);
       }
-      // st.normalize();
+      st.normalize();
       st.o.cast_store(final_o + (q * o_stride_n + r * o_stride_h +
                                  (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE));
       if (final_lse) {
diff --git a/include/flashinfer/attention/mla_hopper.cuh b/include/flashinfer/attention/mla_hopper.cuh
index 8d21d28bc..de445225a 100644
--- a/include/flashinfer/attention/mla_hopper.cuh
+++ b/include/flashinfer/attention/mla_hopper.cuh
@@ -478,8 +478,8 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
 #pragma unroll
     for (uint32_t j = 0; j < 4; ++j) {
       uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) / num_heads;
-      DTypeO* o_final_ptr =
-          final_o + (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
+      DTypeO* o_partial_ptr =
+          partial_o + (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
           (warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + (lane_idx % 8) * upcast_size<DTypeO>();
       uint32_t o_smem_offset_w = get_swizzle_offset(
           (warp_idx_in_wg % 2) * 16 + 4 * j + lane_idx / 8,
@@ -487,13 +487,12 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
 #pragma unroll
       for (uint32_t k = 0; k < HEAD_DIM_CKV / 128; ++k) {
         if (q_idx < q_len) {
-          o_smem[warp_idx_in_wg / 2].template store_128b(o_smem_offset_w, o_final_ptr);
+          o_smem[warp_idx_in_wg / 2].template store_128b(o_smem_offset_w, o_partial_ptr);
         }
-        o_final_ptr += 8 * upcast_size<DTypeO>();
+        o_partial_ptr += 8 * upcast_size<DTypeO>();
         o_smem_offset_w += 64;
       }
     }
-    barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
 
 #pragma unroll
     for (uint32_t j = 0; j < 2; ++j) {
@@ -503,6 +502,8 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
           math::ptx_log2(d[j]) + float(m[j]);
       }
     }
+
+    barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
   } else {
     // write to final_o
 
@@ -526,7 +527,6 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
       o_smem_offset_w += 64;
     }
   }
-  barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
   if (final_lse) {
 #pragma unroll
     for (uint32_t j = 0; j < 2; ++j) {
@@ -537,6 +537,8 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
       }
     }
   }
+
+  barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
 }
 }

diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh
index eaa6ea49d..e739eb080 100644
--- a/include/flashinfer/attention/scheduler.cuh
+++ b/include/flashinfer/attention/scheduler.cuh
@@ -1104,7 +1104,7 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
       total_kv_lens += effective_kv_len;
     }
   }
-  int kv_len_limit = ceil_div(std::max(ceil_div(total_kv_lens, num_clusters), 1L), 512L) * 512L;
+  int kv_len_limit = ceil_div(std::max(ceil_div(total_kv_lens, num_clusters), 1L), 256L) * 256L;
 
   // step 1. load-balancing scheduling algorithm
   MinHeap cluster_cost_heap(num_clusters);

From 0e9baa60feab8abb707af8d43059e9922fd358f5 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 27 Feb 2025 02:10:44 +0000
Subject: [PATCH 3/8] upd

---
 include/flashinfer/attention/scheduler.cuh | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh
index e739eb080..ae943a719 100644
--- a/include/flashinfer/attention/scheduler.cuh
+++ b/include/flashinfer/attention/scheduler.cuh
@@ -1104,7 +1104,21 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
       total_kv_lens += effective_kv_len;
     }
   }
-  int kv_len_limit = ceil_div(std::max(ceil_div(total_kv_lens, num_clusters), 1L), 256L) * 256L;
+
+  auto f = [](int x) {
+    if (x <= 8) {
+      return 32;
+    } else if (x <= 16) {
+      return 64;
+    } else if (x <= 32) {
+      return 128;
+    } else if (x <= 64) {
+      return 192;
+    }
+    return ceil_div(x, 256) * 256;
+  };
+
+  int kv_len_limit = f(std::max(ceil_div(total_kv_lens, num_clusters), 1L));
 
   // step 1. load-balancing scheduling algorithm
   MinHeap cluster_cost_heap(num_clusters);
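For intuition about the new heuristic (sample values mine, not from the patch): f maps the average per-cluster KV length to a chunk granularity that grows with the workload, so a small average now yields a 32-token kv_len_limit instead of being rounded up to 256, while long averages keep the old 256-aligned behavior:

#include <cstdio>

// Host-side mirror of the f() lambda added to MLAPlan above.
int f(int x) {
  if (x <= 8) return 32;
  if (x <= 16) return 64;
  if (x <= 32) return 128;
  if (x <= 64) return 192;
  return (x + 255) / 256 * 256;  // ceil_div(x, 256) * 256
}

int main() {
  for (int x : {4, 12, 24, 48, 100, 300, 1000})
    std::printf("avg kv len %4d -> kv_len_limit %4d\n", x, f(x));
  // 4->32, 12->64, 24->128, 48->192, 100->256, 300->512, 1000->1024
  return 0;
}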
From 2c6f336d8d35d15446593ee7e9681554d4ed3199 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 27 Feb 2025 02:14:52 +0000
Subject: [PATCH 4/8] revert changes to tests and benchmarks

---
 benchmarks/bench_deepseek_mla.py | 8 +++-----
 tests/test_deepseek_mla.py       | 1 -
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/benchmarks/bench_deepseek_mla.py b/benchmarks/bench_deepseek_mla.py
index bc6ce40b5..9fa14a536 100644
--- a/benchmarks/bench_deepseek_mla.py
+++ b/benchmarks/bench_deepseek_mla.py
@@ -69,13 +69,11 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads, backend):
 
     io = sum([_.numel() * _.element_size() for _ in [q_nope, q_pe, ckv, kpe, o]])
 
-    print(ms)
     print(f"Config: batch_size={batch_size}, seq_len={seq_len}, num_heads={num_heads}")
     print(f"Memory bandwidth: {io * 1e-6 / ms:.2f} GB/s")
 
 
 if __name__ == "__main__":
-    for seq_len in [16384]:
-        for batch_size in [64]:  # [64, 128, 768]:
-            for num_heads in [128]:
-                bench_deepseek_mla_decode(batch_size, seq_len, num_heads, "auto")
+    for seq_len in [1024, 2048]:
+        for batch_size in [64, 128, 768]:
+            bench_deepseek_mla_decode(batch_size, seq_len, 64, "auto")

diff --git a/tests/test_deepseek_mla.py b/tests/test_deepseek_mla.py
index d9d491a5c..e93aa1798 100644
--- a/tests/test_deepseek_mla.py
+++ b/tests/test_deepseek_mla.py
@@ -435,7 +435,6 @@ def test_batch_mla_page_attention(
     q = torch.cat([q_nope, q_pe], dim=-1)
     o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
     lse_ref = lse_ref.flatten(0, 1)
-    print(o, o_ref)
     torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
     if kv_len != 0:
         torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)

From ffa9439a21411d7038f9de26685ef81d2b3daeba Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 27 Feb 2025 03:19:45 +0000
Subject: [PATCH 5/8] upd

---
 include/flashinfer/attention/mla.cuh        | 2 +-
 include/flashinfer/attention/mla_hopper.cuh | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/include/flashinfer/attention/mla.cuh b/include/flashinfer/attention/mla.cuh
index 6a63ef414..47d68d655 100644
--- a/include/flashinfer/attention/mla.cuh
+++ b/include/flashinfer/attention/mla.cuh
@@ -701,7 +701,7 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
     for (uint32_t j = 0; j < 2; ++j) {
       uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
       if (lane_idx % 4 == 0 && q_idx < q_len) {
-        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
+        partial_lse[packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4] =
             math::ptx_log2(d[j]) + float(m[j]);
       }
     }
diff --git a/include/flashinfer/attention/mla_hopper.cuh b/include/flashinfer/attention/mla_hopper.cuh
index de445225a..28a533d24 100644
--- a/include/flashinfer/attention/mla_hopper.cuh
+++ b/include/flashinfer/attention/mla_hopper.cuh
@@ -498,7 +498,7 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
     for (uint32_t j = 0; j < 2; ++j) {
       uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
       if (lane_idx % 4 == 0 && q_idx < q_len) {
-        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
+        partial_lse[packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4] =
             math::ptx_log2(d[j]) + float(m[j]);
       }
     }
From 96d11b97fed0c25192be68bc52a56f9d22d89f4a Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 27 Feb 2025 03:40:34 +0000
Subject: [PATCH 6/8] Revert "upd"

This reverts commit ffa9439a21411d7038f9de26685ef81d2b3daeba.
---
 include/flashinfer/attention/mla.cuh        | 2 +-
 include/flashinfer/attention/mla_hopper.cuh | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/include/flashinfer/attention/mla.cuh b/include/flashinfer/attention/mla.cuh
index 47d68d655..6a63ef414 100644
--- a/include/flashinfer/attention/mla.cuh
+++ b/include/flashinfer/attention/mla.cuh
@@ -701,7 +701,7 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
     for (uint32_t j = 0; j < 2; ++j) {
       uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
       if (lane_idx % 4 == 0 && q_idx < q_len) {
-        partial_lse[packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4] =
+        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
             math::ptx_log2(d[j]) + float(m[j]);
       }
     }
diff --git a/include/flashinfer/attention/mla_hopper.cuh b/include/flashinfer/attention/mla_hopper.cuh
index 28a533d24..de445225a 100644
--- a/include/flashinfer/attention/mla_hopper.cuh
+++ b/include/flashinfer/attention/mla_hopper.cuh
@@ -498,7 +498,7 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
     for (uint32_t j = 0; j < 2; ++j) {
       uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
       if (lane_idx % 4 == 0 && q_idx < q_len) {
-        partial_lse[packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4] =
+        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
             math::ptx_log2(d[j]) + float(m[j]);
       }
     }

From 6ac22301c9d228ddec8b47797713389db89a2c9c Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 27 Feb 2025 03:48:44 +0000
Subject: [PATCH 7/8] upd

---
 include/flashinfer/attention/mla.cuh        | 3 ++-
 include/flashinfer/attention/mla_hopper.cuh | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/include/flashinfer/attention/mla.cuh b/include/flashinfer/attention/mla.cuh
index 6a63ef414..fc0cf239e 100644
--- a/include/flashinfer/attention/mla.cuh
+++ b/include/flashinfer/attention/mla.cuh
@@ -713,7 +713,8 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
     for (uint32_t j = 0; j < 4; ++j) {
       uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) / num_heads;
       DTypeO* o_partial_ptr =
-          partial_o + (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
+          partial_o +
+          ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
           warpgroup_idx * (HEAD_DIM_CKV / 2) + (lane_idx % 8) * upcast_size<DTypeO>();
 #pragma unroll
       for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 8; ++mma_d) {
diff --git a/include/flashinfer/attention/mla_hopper.cuh b/include/flashinfer/attention/mla_hopper.cuh
index de445225a..526a38229 100644
--- a/include/flashinfer/attention/mla_hopper.cuh
+++ b/include/flashinfer/attention/mla_hopper.cuh
@@ -479,7 +479,8 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
     for (uint32_t j = 0; j < 4; ++j) {
       uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) / num_heads;
       DTypeO* o_partial_ptr =
-          partial_o + (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
+          partial_o +
+          ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
           (warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + (lane_idx % 8) * upcast_size<DTypeO>();
       uint32_t o_smem_offset_w = get_swizzle_offset(
           (warp_idx_in_wg % 2) * 16 + 4 * j + lane_idx / 8,
          (warp_group_idx - 1) * NUM_MMA_D_CKV + lane_idx % 8);
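Reading patches 5 through 7 together: patch 5 switched the partial_lse index base from (blockIdx.x * 4 + warp_idx_in_wg) * 16 to packed_offset, patch 6 reverted that, and patch 7 instead moves partial_o onto the same blockIdx-based base. The apparent invariant (my reading of the diffs, not stated in any commit message) is that both partial buffers are addressed by the scheduler's per-CTA tile slot rather than by the (query, head) packed offset, which is what DevicePersistentMergeStates later walks via merge_partial_packed_offset_* and merge_partial_stride. Illustrative only, not flashinfer source:

#include <cstdint>

// The partial-buffer row a given lane writes: each CTA owns 4 warps x 16 rows
// of its cluster tile, keyed by block index rather than query position.
__device__ __forceinline__ uint32_t partial_row(uint32_t block_idx, uint32_t warp_idx_in_wg,
                                                uint32_t j, uint32_t lane_idx) {
  return (block_idx * 4 + warp_idx_in_wg) * 16 + 4 * j + lane_idx / 8;
}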
From 33c3e58158154fb71585c0454812f3675a1d373b Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 27 Feb 2025 03:53:11 +0000
Subject: [PATCH 8/8] upd

---
 include/flashinfer/attention/mla_hopper.cuh | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/include/flashinfer/attention/mla_hopper.cuh b/include/flashinfer/attention/mla_hopper.cuh
index 526a38229..96014a0b2 100644
--- a/include/flashinfer/attention/mla_hopper.cuh
+++ b/include/flashinfer/attention/mla_hopper.cuh
@@ -494,6 +494,7 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
         o_smem_offset_w += 64;
       }
     }
+    barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
 
 #pragma unroll
     for (uint32_t j = 0; j < 2; ++j) {
@@ -503,8 +504,6 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
           math::ptx_log2(d[j]) + float(m[j]);
       }
     }
-
-    barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
   } else {
     // write to final_o
 
@@ -526,7 +527,8 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
       o_smem_offset_w += 64;
     }
   }
+  barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
+
   if (final_lse) {
 #pragma unroll
     for (uint32_t j = 0; j < 2; ++j) {
@@ -538,8 +539,6 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
       }
     }
   }
-
-  barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
 }
 }
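Patch 8's effect in both branches is to issue barrier_arrive immediately after the last read of o_smem (the store_128b loops) and before the lse stores, which touch only global memory, so threads waiting on kBarrierO can reuse the O tile in shared memory sooner. A PTX-level sketch of the arrive/wait split such a named barrier provides (illustrative; not the flashinfer barrier_arrive wrapper):

// Producers call arrive and continue past the barrier; consumers block in
// sync until `count` threads have arrived at barrier `id`.
__device__ __forceinline__ void named_barrier_arrive(unsigned id, unsigned count) {
  asm volatile("bar.arrive %0, %1;" ::"r"(id), "r"(count));
}

__device__ __forceinline__ void named_barrier_sync(unsigned id, unsigned count) {
  asm volatile("bar.sync %0, %1;" ::"r"(id), "r"(count));
}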