diff --git a/csrc/batch_mla_run.cu b/csrc/batch_mla_run.cu
index 9e624ac8..af44cb34 100644
--- a/csrc/batch_mla_run.cu
+++ b/csrc/batch_mla_run.cu
@@ -95,7 +95,7 @@ void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int
   params.final_lse =
       maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
   params.partial_o =
-      GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset);
+      GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
   params.partial_lse =
       GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);
diff --git a/csrc/batch_mla_sm90_run.cu b/csrc/batch_mla_sm90_run.cu
index 44fb1907..200d8bf6 100644
--- a/csrc/batch_mla_sm90_run.cu
+++ b/csrc/batch_mla_sm90_run.cu
@@ -96,7 +96,7 @@ void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer,
   params.final_lse =
       maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
   params.partial_o =
-      GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset);
+      GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
   params.partial_lse =
       GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);
diff --git a/include/flashinfer/attention/mla.cuh b/include/flashinfer/attention/mla.cuh
index 171bdef7..fc0cf239 100644
--- a/include/flashinfer/attention/mla.cuh
+++ b/include/flashinfer/attention/mla.cuh
@@ -620,9 +620,9 @@ __device__ void DevicePersistentMergeStates(
     typename KTraits::IdType* merge_packed_offset_end,
     typename KTraits::IdType* merge_partial_packed_offset_start,
     typename KTraits::IdType* merge_partial_packed_offset_end,
-    typename KTraits::IdType* merge_partial_stride, float* partial_o, float* partial_lse,
-    typename KTraits::DTypeO* final_o, float* final_lse, const uint32_t o_stride_n,
-    const uint32_t o_stride_h, const uint_fastdiv& num_heads) {
+    typename KTraits::IdType* merge_partial_stride, typename KTraits::DTypeO* partial_o,
+    float* partial_lse, typename KTraits::DTypeO* final_o, float* final_lse,
+    const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv& num_heads) {
   constexpr uint32_t VEC_SIZE = 4;  // partial o has data type float
   constexpr uint32_t NUM_THRS_PER_ROW = KTraits::HEAD_DIM_CKV / VEC_SIZE;
   constexpr uint32_t ROWS_PER_ITERATION = (KTraits::NUM_THREADS) / NUM_THRS_PER_ROW;
@@ -645,8 +645,8 @@ __device__ void DevicePersistentMergeStates(
          partial_packed_offset < partial_offset_end; partial_packed_offset += stride) {
       vec_t o_partial;
       float lse_partial;
-      o_partial.load(partial_o + partial_packed_offset * KTraits::HEAD_DIM_CKV +
-                     (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE);
+      o_partial.cast_load(partial_o + partial_packed_offset * KTraits::HEAD_DIM_CKV +
+                          (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE);
       lse_partial = partial_lse[partial_packed_offset];
       st.merge(o_partial, lse_partial, 1);
     }
@@ -662,10 +662,11 @@
 template <typename KTraits>
 __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_storage,
                                         typename KTraits::DTypeO* final_o, float* final_lse,
-                                        float* partial_o, float* partial_lse, float (*o_frag)[8],
-                                        typename KTraits::DTypeQKAccum* m, float* d,
-                                        const uint32_t o_stride_n, const uint32_t o_stride_h,
-                                        const uint32_t q_len, const uint32_t packed_offset,
+                                        typename KTraits::DTypeO* partial_o, float* partial_lse,
+                                        float (*o_frag)[8], typename KTraits::DTypeQKAccum* m,
+                                        float* d, const uint32_t o_stride_n,
+                                        const uint32_t o_stride_h, const uint32_t q_len,
+                                        const uint32_t packed_offset,
                                         const uint_fastdiv& num_heads) {
   using DTypeO = typename KTraits::DTypeO;
   constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
@@ -673,6 +674,26 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
   constexpr uint32_t UPCAST_STRIDE_FINAL_O = KTraits::UPCAST_STRIDE_FINAL_O;
   const uint32_t lane_idx = threadIdx.x, warpgroup_idx = threadIdx.z, warp_idx_in_wg = threadIdx.y;
   smem_t o_smem(smem_storage->o_smem);
+#pragma unroll
+  for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
+    uint32_t o_frag_f16[8 / 2];
+    vec_cast::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_d]);
+#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
+    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset(
+        warp_idx_in_wg * 16 + lane_idx % 16,
+        warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2 + lane_idx / 16);
+    o_smem.template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
+#else
+    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset(
+        warp_idx_in_wg * 16 + lane_idx / 4, warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2);
+    ((uint32_t*)(o_smem.base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
+    ((uint32_t*)(o_smem.base + o_smem_offset_w + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
+        o_frag_f16[1];
+    ((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
+    ((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
+        o_frag_f16[3];
+#endif
+  }

   if (partial_o != nullptr) {
     // write to partial_o
@@ -685,24 +706,27 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
       }
     }

+    // step 1. smem to gmem
+    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset(
+        warp_idx_in_wg * 16 + lane_idx / 8, warpgroup_idx * NUM_MMA_D_CKV + lane_idx % 8);
 #pragma unroll
-    for (uint32_t j = 0; j < 2; ++j) {
-      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
+    for (uint32_t j = 0; j < 4; ++j) {
+      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) / num_heads;
+      DTypeO* o_partial_ptr =
+          partial_o +
+          ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
+          warpgroup_idx * (HEAD_DIM_CKV / 2) + (lane_idx % 8) * upcast_size();
 #pragma unroll
-      for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
+      for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 8; ++mma_d) {
         if (q_idx < q_len) {
-          *reinterpret_cast(
-              partial_o +
-              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
-              warpgroup_idx * (HEAD_DIM_CKV / 2) + mma_d * 16 + (lane_idx % 4) * 2) =
-              *reinterpret_cast(&o_frag[mma_d][j * 2]);
-          *reinterpret_cast(
-              partial_o +
-              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
-              warpgroup_idx * (HEAD_DIM_CKV / 2) + mma_d * 16 + 8 + (lane_idx % 4) * 2) =
-              *reinterpret_cast(&o_frag[mma_d][4 + j * 2]);
+          o_smem.template store_128b(o_smem_offset_w, o_partial_ptr);
         }
+        o_partial_ptr += 8 * upcast_size();
+        o_smem_offset_w = o_smem.template advance_offset_by_column<8>(o_smem_offset_w, mma_d);
       }
+      o_smem_offset_w =
+          o_smem.template advance_offset_by_row<4, UPCAST_STRIDE_FINAL_O>(o_smem_offset_w) -
+          NUM_MMA_D_CKV;
     }
   } else {
     // write to final_o
@@ -718,28 +742,6 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
       }
     }

-    // step 0. rmem to smem
-#pragma unroll
-    for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
-      uint32_t o_frag_f16[8 / 2];
-      vec_cast::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_d]);
-#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
-      uint32_t o_smem_offset_w = o_smem.template get_permuted_offset(
-          warp_idx_in_wg * 16 + lane_idx % 16,
-          warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2 + lane_idx / 16);
-      o_smem.template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
-#else
-      uint32_t o_smem_offset_w = o_smem.template get_permuted_offset(
-          warp_idx_in_wg * 16 + lane_idx / 4, warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2);
-      ((uint32_t*)(o_smem.base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
-      ((uint32_t*)(o_smem.base + o_smem_offset_w + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
-          o_frag_f16[1];
-      ((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
-      ((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1) +
-                   8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] = o_frag_f16[3];
-#endif
-    }
-
     // step 1. smem to gmem
     uint32_t o_smem_offset_w = o_smem.template get_permuted_offset(
         warp_idx_in_wg * 16 + lane_idx / 8, warpgroup_idx * NUM_MMA_D_CKV + lane_idx % 8);
@@ -794,7 +796,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
   DTypeKV* ckv = params.ckv;
   DTypeKV* kpe = params.kpe;
   IdType* kv_indices = params.kv_indices;
-  float* partial_o = params.partial_o;
+  DTypeO* partial_o = params.partial_o;
   float* partial_lse = params.partial_lse;
   DTypeO* final_o = params.final_o;
   float* final_lse = params.final_lse;
diff --git a/include/flashinfer/attention/mla_hopper.cuh b/include/flashinfer/attention/mla_hopper.cuh
index 79d65db8..96014a0b 100644
--- a/include/flashinfer/attention/mla_hopper.cuh
+++ b/include/flashinfer/attention/mla_hopper.cuh
@@ -441,11 +441,14 @@ __device__ __forceinline__ void normalize_d_(typename KTraits::SharedStorage* sm
 }

 template <typename KTraits>
-__device__ __forceinline__ void write_o(
-    typename KTraits::SharedStorage* smem_storage, const uint32_t stage_counter,
-    typename KTraits::DTypeO* final_o, float* final_lse, float* partial_o, float* partial_lse,
-    float(*o_frag), float* m, float* d, const uint32_t o_stride_n, const uint32_t o_stride_h,
-    const uint32_t q_len, const uint32_t packed_offset, const uint_fastdiv& num_heads) {
+__device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_storage,
+                                        const uint32_t stage_counter,
+                                        typename KTraits::DTypeO* final_o, float* final_lse,
+                                        typename KTraits::DTypeO* partial_o, float* partial_lse,
+                                        float(*o_frag), float* m, float* d,
+                                        const uint32_t o_stride_n, const uint32_t o_stride_h,
+                                        const uint32_t q_len, const uint32_t packed_offset,
+                                        const uint_fastdiv& num_heads) {
   using DTypeO = typename KTraits::DTypeO;
   constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
   constexpr uint32_t HEAD_DIM_CKV = KTraits::HEAD_DIM_CKV;
@@ -457,61 +460,52 @@ __device__ __forceinline__ void write_o(
   o_smem[0] = smem_storage->kv_o_smem[stage_counter % KTraits::NUM_STAGES].o;
   o_smem[1] = smem_storage->kv_o_smem[(stage_counter + 1) % KTraits::NUM_STAGES].o;

+  // step 0. rmem to smem
+#pragma unroll
+  for (uint32_t k = 0; k < HEAD_DIM_CKV / 32; ++k) {
+    uint32_t o_frag_f16[8 / 2];
+    vec_cast::cast<8>((DTypeO*)o_frag_f16, &o_frag[k * 8]);
+    uint32_t o_smem_offset_w = get_swizzle_offset(
+        (warp_idx_in_wg % 2) * 16 + lane_idx % 16,
+        (warp_group_idx - 1) * NUM_MMA_D_CKV + k * 2 + lane_idx / 16);
+    o_smem[warp_idx_in_wg / 2].template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
+  }
+
   if (partial_o != nullptr) {
     // NOTE(Zihao): o_smem is not used if write to partial_o, and we can avoid the barrier
-    barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
     // write to partial_o
+
 #pragma unroll
-    for (uint32_t j = 0; j < 2; ++j) {
-      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
-      if (lane_idx % 4 == 0 && q_idx < q_len) {
-        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
-            math::ptx_log2(d[j]) + float(m[j]);
+    for (uint32_t j = 0; j < 4; ++j) {
+      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) / num_heads;
+      DTypeO* o_partial_ptr =
+          partial_o +
+          ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
+          (warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + (lane_idx % 8) * upcast_size();
+      uint32_t o_smem_offset_w = get_swizzle_offset(
+          (warp_idx_in_wg % 2) * 16 + 4 * j + lane_idx / 8,
+          (warp_group_idx - 1) * NUM_MMA_D_CKV + lane_idx % 8);
+#pragma unroll
+      for (uint32_t k = 0; k < HEAD_DIM_CKV / 128; ++k) {
+        if (q_idx < q_len) {
+          o_smem[warp_idx_in_wg / 2].template store_128b(o_smem_offset_w, o_partial_ptr);
+        }
+        o_partial_ptr += 8 * upcast_size();
+        o_smem_offset_w += 64;
       }
     }
+    barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);

 #pragma unroll
     for (uint32_t j = 0; j < 2; ++j) {
       uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
-#pragma unroll
-      for (uint32_t k = 0; k < HEAD_DIM_CKV / 32; ++k) {
-        if (q_idx < q_len) {
-          *reinterpret_cast(
-              partial_o +
-              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
-              (warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + k * 16 + (lane_idx % 4) * 2) =
-              *reinterpret_cast(&o_frag[k * 8 + j * 2]);
-          *reinterpret_cast(
-              partial_o +
-              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
-              (warp_group_idx - 1) * (HEAD_DIM_CKV / 2) + k * 16 + 8 + (lane_idx % 4) * 2) =
-              *reinterpret_cast(&o_frag[k * 8 + 4 + j * 2]);
-        }
+      if (lane_idx % 4 == 0 && q_idx < q_len) {
+        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
+            math::ptx_log2(d[j]) + float(m[j]);
       }
     }
   } else {
     // write to final_o
-    if (final_lse) {
-#pragma unroll
-      for (uint32_t j = 0; j < 2; ++j) {
-        uint32_t q, r;
-        num_heads.divmod(packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4, q, r);
-        if (lane_idx % 4 == 0 && q < q_len) {
-          final_lse[q * num_heads + r] = math::ptx_log2(d[j]) + float(m[j]);
-        }
-      }
-    }
-
-    // step 0. rmem to smem
-#pragma unroll
-    for (uint32_t k = 0; k < HEAD_DIM_CKV / 32; ++k) {
-      uint32_t o_frag_f16[8 / 2];
-      vec_cast::cast<8>((DTypeO*)o_frag_f16, &o_frag[k * 8]);
-      uint32_t o_smem_offset_w = get_swizzle_offset(
-          (warp_idx_in_wg % 2) * 16 + lane_idx % 16,
-          (warp_group_idx - 1) * NUM_MMA_D_CKV + k * 2 + lane_idx / 16);
-      o_smem[warp_idx_in_wg / 2].template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
-    }

     // step 1. smem to gmem
 #pragma unroll
@@ -534,6 +528,17 @@ __device__ __forceinline__ void write_o(
       }
     }
     barrier_arrive(KTraits::NUM_COPY_THREADS + KTraits::NUM_MMA_THREADS, NamedBarriers::kBarrierO);
+
+    if (final_lse) {
+#pragma unroll
+      for (uint32_t j = 0; j < 2; ++j) {
+        uint32_t q, r;
+        num_heads.divmod(packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4, q, r);
+        if (lane_idx % 4 == 0 && q < q_len) {
+          final_lse[q * num_heads + r] = math::ptx_log2(d[j]) + float(m[j]);
+        }
+      }
+    }
   }
 }
@@ -586,7 +591,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
   DTypeKV* ckv = params.ckv;
   DTypeKV* kpe = params.kpe;
   IdType* kv_indices = params.kv_indices;
-  float* partial_o = params.partial_o;
+  DTypeO* partial_o = params.partial_o;
   float* partial_lse = params.partial_lse;
   DTypeO* final_o = params.final_o;
   float* final_lse = params.final_lse;
diff --git a/include/flashinfer/attention/mla_params.cuh b/include/flashinfer/attention/mla_params.cuh
index 9b63e340..cad1acff 100644
--- a/include/flashinfer/attention/mla_params.cuh
+++ b/include/flashinfer/attention/mla_params.cuh
@@ -32,7 +32,7 @@ struct MLAParams {
   DTypeQ* q_pe;
   DTypeKV* ckv;
   DTypeKV* kpe;
-  float* partial_o;
+  DTypeO* partial_o;
   float* partial_lse;
   DTypeO* final_o;
   float* final_lse;
diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh
index 30046f41..ae943a71 100644
--- a/include/flashinfer/attention/scheduler.cuh
+++ b/include/flashinfer/attention/scheduler.cuh
@@ -1104,7 +1104,21 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
       total_kv_lens += effective_kv_len;
     }
   }
-  int kv_len_limit = ceil_div(std::max(ceil_div(total_kv_lens, num_clusters), 1L), 256L) * 256L;
+
+  auto f = [](int x) {
+    if (x <= 8) {
+      return 32;
+    } else if (x <= 16) {
+      return 64;
+    } else if (x <= 32) {
+      return 128;
+    } else if (x <= 64) {
+      return 192;
+    }
+    return ceil_div(x, 256) * 256;
+  };
+
+  int kv_len_limit = f(std::max(ceil_div(total_kv_lens, num_clusters), 1L));

   // step 1. load-balancing scheduling algorithm
   MinHeap cluster_cost_heap(num_clusters);
@@ -1295,9 +1309,10 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
   FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
                                        cudaMemcpyHostToDevice, stream));

+  constexpr size_t sizeof_dtype_o = 2;
   AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
   plan_info.partial_o_offset = float_allocator.aligned_alloc_offset(
-      2 * num_clusters * cluster_tile_q * sizeof(float) * head_dim_o, 16, "mla_partial_o");
+      2 * num_clusters * cluster_tile_q * sizeof_dtype_o * head_dim_o, 16, "mla_partial_o");
   plan_info.partial_lse_offset = float_allocator.aligned_alloc_offset(
       2 * num_clusters * cluster_tile_q * sizeof(float), 16, "mla_partial_lse");
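
A standalone sketch of the two planner-side effects above, for quick sanity-checking (not part of the patch; every constant below is an assumed example value, not one taken from the library): the new bucketing function gives small per-cluster KV limits when the average KV length is short instead of always rounding up to a multiple of 256, and storing mla_partial_o in the 2-byte output dtype instead of fp32 halves that slice of the float workspace.

// Standalone host-side sketch; mirrors the arithmetic in MLAPlan above with made-up sizes.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Same bucketing as the new lambda `f`: short average KV lengths get a small
// per-cluster limit instead of always rounding up to a multiple of 256.
static int64_t kv_len_limit(int64_t x) {
  if (x <= 8) return 32;
  if (x <= 16) return 64;
  if (x <= 32) return 128;
  if (x <= 64) return 192;
  return ceil_div(x, 256) * 256;
}

int main() {
  // Assumed example workload, not values from the library.
  const int64_t num_clusters = 132, total_kv_lens = 4096;
  const int64_t avg = std::max<int64_t>(ceil_div(total_kv_lens, num_clusters), 1);
  std::printf("avg kv len per cluster = %lld\n", (long long)avg);                   // 32
  std::printf("old kv_len_limit = %lld\n", (long long)(ceil_div(avg, 256) * 256));  // 256
  std::printf("new kv_len_limit = %lld\n", (long long)kv_len_limit(avg));           // 128

  // mla_partial_o holds 2 * num_clusters * cluster_tile_q rows of head_dim_o
  // elements; the element type moves from fp32 (4 B) to the fp16/bf16 output
  // dtype (2 B), so this part of the float workspace is halved.
  const int64_t cluster_tile_q = 64, head_dim_o = 512;  // assumed example sizes
  const int64_t elems = 2 * num_clusters * cluster_tile_q * head_dim_o;
  std::printf("partial_o bytes: fp32 = %lld, fp16 = %lld\n",
              (long long)(elems * 4), (long long)(elems * 2));
  return 0;
}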