Skip to content

Commit

Permalink
optimize flashdecodingattention
Browse files Browse the repository at this point in the history
  • Loading branch information
SunflowerAries committed Apr 19, 2024
1 parent 46d1d61 commit 100771a
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 53 deletions.
55 changes: 28 additions & 27 deletions colossalai/inference/modeling/models/nopadding_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,19 +592,20 @@ def forward(
block_tables,
high_precision,
)
# inference_ops.flash_decoding_attention(
# attn_output,
# query_states,
# k_cache,
# v_cache,
# sequence_lengths,
# block_tables,
# block_size,
# kv_seq_len,
# fd_inter_tensor.mid_output,
# fd_inter_tensor.mid_output_lse,
# sm_scale,
# )
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
sm_scale,
)
attn_output = output_tensor
else:
if is_verifier:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
Expand All @@ -626,20 +627,20 @@ def forward(
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
q_len=q_len,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
q_len=q_len,
)

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)
Expand Down
169 changes: 169 additions & 0 deletions examples/inference/benchmark_ops/profile_flash_decoding_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import ctypes

import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_vllm,
)

_cudart = ctypes.CDLL("libcudart.so")

try:
import triton # noqa
except ImportError:
print("please install triton from https://github.com/openai/triton")

inference_ops = InferenceOpsLoader().load()


def start():
# As shown at http://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__PROFILER.html,
# the return value will unconditionally be 0. This check is just in case it changes in
# the future.
ret = _cudart.cudaProfilerStart()
if ret != 0:
raise Exception("cudaProfilerStart() returned %d" % ret)


def stop():
ret = _cudart.cudaProfilerStop()
if ret != 0:
raise Exception("cudaProfilerStop() returned %d" % ret)


# Triton benchmark plot attributions
configs = [
triton.testing.Benchmark(
x_names=["MAX_NUM_BLOCKS_PER_SEQ"],
x_vals=[2**i for i in range(3, 4)],
line_arg="provider",
line_vals=[
# "vllm_paged_decoding_attention",
"cuda_flash_decoding_attention",
],
line_names=[
# "vllm_paged_decoding_attention",
"cuda_flash_decoding_attention",
],
styles=[("red", "-")],
# styles=[("red", "-"), ("yellow", "-")],
ylabel="ms",
plot_name=f"FlashDecodingAttention benchmarking results",
args={"BATCH_SIZE": 16, "BLOCK_SIZE": 32, "HEAD_SIZE": 128, "KV_GROUP_NUM": 2},
)
]


def prepare_data(
BATCH_SIZE: int,
HEAD_SIZE: int,
NUM_ATTN_HEADS: int,
NUM_KV_HEADS: int,
MAX_SEQ_LEN: int,
dtype=torch.float16,
device="cuda",
):
# Use the provided maximum sequence length for each sequence when testing with teh same context length,
# otherwise generate random context lengths.
# returns
# q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE]
# k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE]
kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device)
num_tokens = torch.sum(kv_lengths).item()

q_size = (BATCH_SIZE, 1, NUM_ATTN_HEADS, HEAD_SIZE)
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE)
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2)

return q, k_unpad, v_unpad, kv_lengths


def benchmark_flash_decoding_attention(
provider: str,
BATCH_SIZE: int,
BLOCK_SIZE: int,
MAX_NUM_BLOCKS_PER_SEQ: int,
HEAD_SIZE: int,
KV_GROUP_NUM: int,
):
try:
from vllm._C import ops as vllm_ops
except ImportError:
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")

dtype = torch.float16

NUM_ATTN_HEADS = 16

NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()

q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)

k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)

vllm_k_cache, vllm_v_cache, _ = generate_caches_and_block_tables_vllm(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)

block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)

mid_output = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
)
mid_output_lse = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
)

start()
alibi_slopes = None
vllm_ops.paged_attention_v1(
output,
q.squeeze(2),
vllm_k_cache,
vllm_v_cache,
NUM_KV_HEADS,
sm_scale,
block_tables,
kv_seq_lengths,
BLOCK_SIZE,
max_seq_len_across_batch,
alibi_slopes,
"auto",
)

inference_ops.flash_decoding_attention(
output,
q.squeeze(2),
k_cache,
v_cache,
kv_seq_lengths,
block_tables,
BLOCK_SIZE,
max_seq_len_across_batch,
mid_output,
mid_output_lse,
sm_scale,
)

stop()
# ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)


if __name__ == "__main__":
benchmark_flash_decoding_attention("cuda_flash_decoding_attention", 16, 32, 8, 128, 2)
21 changes: 8 additions & 13 deletions extensions/csrc/cuda/attention/attention_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,14 @@ struct Qk_dot {
}
};

template <int NUM_WARPS, int NUM_THREADS_PER_TOKEN>
inline __device__ float block_max(float* red_smem, float max) {
int warp = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;

template <int NUM_WARPS>
inline __device__ float block_max(int warp, int lane, float* red_smem,
float max) {
// Perform reduction across the threads in the same warp to get the max value
// for each warp, the 1st out of NUM_THREADS_PER_TOKEN thread already has the
// max value among every NUM_THREADS_PER_TOKEN threads.
#pragma unroll
for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_TOKEN; mask >>= 1) {
for (int mask = (WARP_SIZE >> 1); mask > 0; mask >>= 1) {
max = fmaxf(max, SHFL_XOR_SYNC(max, mask));
}

Expand All @@ -112,10 +110,8 @@ inline __device__ float block_max(float* red_smem, float max) {
// here we need another block_sum instead of using block_reduce
// since we need manage shared memory in a explicit way
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
int warp = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;

inline __device__ float block_sum(int warp, int lane, float* red_smem,
float sum) {
// Compute the sum per warp.
#pragma unroll
for (int mask = (WARP_SIZE >> 1); mask > 0; mask >>= 1) {
Expand All @@ -141,10 +137,9 @@ inline __device__ float block_sum(float* red_smem, float sum) {

// here VecT is a vector of float, whose size is N
template <typename VecT, int NUM_WARPS, int NUM_THREADS_PER_GROUP, int N>
inline __device__ void block_sum(float* red_smem, VecT& acc) {
inline __device__ void block_sum(int warp, int lane, float* red_smem,
VecT& acc) {
float* acc_ptr = reinterpret_cast<float*>(&acc);
int warp = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;

#pragma unroll
for (int i = 0; i < N; i++) {
Expand Down
36 changes: 23 additions & 13 deletions extensions/csrc/cuda/flash_decoding_attention_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ __global__ void flash_decoding_attention_kernel(
constexpr int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE);
constexpr int NUM_ROUNDS_PER_TOKEN = NUM_VECS_PER_TOKEN / NUM_THREADS_PER_TOKEN;
constexpr int WARP_STRIDE = WARP_SIZE * NUM_ROUNDS_PER_TOKEN;
constexpr int NUM_ROUNDS_PER_BLOCK = BLOCK_SIZE * NUM_VECS_PER_TOKEN / WARP_STRIDE;
constexpr int NUM_THREAD_GROUPS = WARP_SIZE / NUM_THREADS_PER_TOKEN;

using K_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
using V_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
Expand Down Expand Up @@ -106,37 +108,45 @@ __global__ void flash_decoding_attention_kernel(

scalar_t* q_shared_ptr = reinterpret_cast<scalar_t*>(q_shared);
// each warp access a whole block

K_vec q_vecs[NUM_ROUNDS_PER_TOKEN];
#pragma unroll
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
q_vecs[i] = (reinterpret_cast<K_vec*>(q_shared_ptr))[(lane + i * WARP_SIZE) % NUM_VECS_PER_TOKEN];
}

for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
float qks[NUM_ROUNDS_PER_BLOCK];

#pragma unroll
for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) {
const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN;
for (int idx = lane, cnt = 0; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE, cnt += 1) {
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride
+ idx * VEC_SIZE;

K_vec k_vecs[NUM_ROUNDS_PER_TOKEN];
K_vec q_vecs[NUM_ROUNDS_PER_TOKEN];

// we must calculate at least one row of hidden vectors
#pragma unroll
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
k_vecs[i] = (reinterpret_cast<const K_vec*>(k_ptr))[i * WARP_SIZE];
q_vecs[i] = (reinterpret_cast<K_vec*>(q_shared_ptr))[(idx + i * WARP_SIZE) % NUM_VECS_PER_TOKEN];
}

float qk = scale * Qk_dot<scalar_t, NUM_THREADS_PER_TOKEN>::dot(q_vecs, k_vecs);
qks[cnt] = scale * Qk_dot<scalar_t, NUM_THREADS_PER_TOKEN>::dot(q_vecs, k_vecs);
}

if (thread_group_offset == 0) {
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qk;
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
#pragma unroll
for (int idx = thread_group_offset * NUM_THREAD_GROUPS + (lane / NUM_THREADS_PER_TOKEN), cnt = thread_group_offset; idx < BLOCK_SIZE; idx += WARP_SIZE, cnt += NUM_THREADS_PER_TOKEN) {
const int token_idx = block_idx * BLOCK_SIZE + idx;
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qks[cnt];
qk_max = mask ? qk_max : fmaxf(qk_max, qks[cnt]);
}
}

// there exists a __syncthreads within this function
qk_max = block_max<NUM_WARPS, NUM_THREADS_PER_TOKEN>(red_shared_mem, qk_max);
qk_max = block_max<NUM_WARPS>(warp_idx, lane, red_shared_mem, qk_max);

// Get the sum of the exp values.
float exp_sum = 0.f;
Expand All @@ -146,7 +156,7 @@ __global__ void flash_decoding_attention_kernel(
exp_sum += val;
}

exp_sum = block_sum<NUM_WARPS>(&red_shared_mem[NUM_WARPS], exp_sum);
exp_sum = block_sum<NUM_WARPS>(warp_idx, lane, &red_shared_mem[NUM_WARPS], exp_sum);
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
logits[i] *= inv_sum;
Expand Down Expand Up @@ -199,7 +209,7 @@ __global__ void flash_decoding_attention_kernel(

#pragma unroll
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
block_sum<Float_vec, NUM_WARPS, NUM_THREADS_PER_TOKEN, VEC_SIZE>(out_shared_mem, accs[i]);
block_sum<Float_vec, NUM_WARPS, NUM_THREADS_PER_TOKEN, VEC_SIZE>(warp_idx, lane, out_shared_mem, accs[i]);
}

scalar_t* out_ptr = out + seq_idx * q_stride + head_idx * HEAD_SIZE;
Expand Down

0 comments on commit 100771a

Please sign in to comment.