forked from hpcaitech/ColossalAI
Commit
1 parent 46d1d61, commit 100771a
Showing 4 changed files with 228 additions and 53 deletions.
169 changes: 169 additions & 0 deletions in examples/inference/benchmark_ops/profile_flash_decoding_attention.py
@@ -0,0 +1,169 @@
import ctypes

import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
    generate_caches_and_block_tables_v2,
    generate_caches_and_block_tables_vllm,
)

_cudart = ctypes.CDLL("libcudart.so")

try:
    import triton  # noqa
except ImportError:
    print("please install triton from https://github.com/openai/triton")

inference_ops = InferenceOpsLoader().load()

def start():
    # As shown at http://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__PROFILER.html,
    # the return value will unconditionally be 0. This check is just in case it changes in
    # the future.
    ret = _cudart.cudaProfilerStart()
    if ret != 0:
        raise Exception("cudaProfilerStart() returned %d" % ret)


def stop():
    ret = _cudart.cudaProfilerStop()
    if ret != 0:
        raise Exception("cudaProfilerStop() returned %d" % ret)

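# Illustrative helper, not used by the benchmark below: the region delimited by start()/stop()
# can be wrapped in a context manager, and is typically paired with Nsight Systems'
# capture-range mode (e.g. `nsys profile --capture-range=cudaProfilerApi ...`) so that only
# the code between cudaProfilerStart() and cudaProfilerStop() is recorded.
from contextlib import contextmanager


@contextmanager
def cuda_profiler_range():
    start()
    try:
        yield
    finally:
        stop()
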
# Triton benchmark plot attributes
configs = [
    triton.testing.Benchmark(
        x_names=["MAX_NUM_BLOCKS_PER_SEQ"],
        x_vals=[2**i for i in range(3, 4)],
        line_arg="provider",
        line_vals=[
            # "vllm_paged_decoding_attention",
            "cuda_flash_decoding_attention",
        ],
        line_names=[
            # "vllm_paged_decoding_attention",
            "cuda_flash_decoding_attention",
        ],
        styles=[("red", "-")],
        # styles=[("red", "-"), ("yellow", "-")],
        ylabel="ms",
        plot_name="FlashDecodingAttention benchmarking results",
        args={"BATCH_SIZE": 16, "BLOCK_SIZE": 32, "HEAD_SIZE": 128, "KV_GROUP_NUM": 2},
    )
]

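# A minimal sketch (kept as a comment) of how `configs` would typically be consumed with
# triton.testing.perf_report / do_bench; this script instead profiles a single call between
# cudaProfilerStart() and cudaProfilerStop():
#
#   @triton.testing.perf_report(configs)
#   def bench(MAX_NUM_BLOCKS_PER_SEQ, provider, BATCH_SIZE, BLOCK_SIZE, HEAD_SIZE, KV_GROUP_NUM):
#       fn = lambda: benchmark_flash_decoding_attention(
#           provider, BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, KV_GROUP_NUM
#       )
#       return triton.testing.do_bench(fn, warmup=10, rep=100)
#
#   bench.run(save_path=".", print_data=True)
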
def prepare_data(
    BATCH_SIZE: int,
    HEAD_SIZE: int,
    NUM_ATTN_HEADS: int,
    NUM_KV_HEADS: int,
    MAX_SEQ_LEN: int,
    dtype=torch.float16,
    device="cuda",
):
    # Use the provided maximum sequence length for each sequence when testing with the same context length,
    # otherwise generate random context lengths.
    # Returns:
    #   q               [BATCH_SIZE, NUM_ATTN_HEADS, 1, HEAD_SIZE]
    #   k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE]
    kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device)
    num_tokens = torch.sum(kv_lengths).item()

    q_size = (BATCH_SIZE, 1, NUM_ATTN_HEADS, HEAD_SIZE)
    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
    kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE)
    kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
    k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2)

    return q, k_unpad, v_unpad, kv_lengths

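# For the configuration profiled below (BATCH_SIZE=16, NUM_ATTN_HEADS=16, KV_GROUP_NUM=2 ->
# NUM_KV_HEADS=8, HEAD_SIZE=128, MAX_SEQ_LEN = 32 * 8 = 256), prepare_data() yields:
#   q:               [16, 16, 1, 128]        (after .transpose(1, 2))
#   k_unpad/v_unpad: [num_tokens, 8, 128]    with num_tokens = sum(kv_lengths) <= 16 * 255
#   kv_lengths:      [16], int32, values drawn from [1, MAX_SEQ_LEN - 1]
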
def benchmark_flash_decoding_attention(
    provider: str,
    BATCH_SIZE: int,
    BLOCK_SIZE: int,
    MAX_NUM_BLOCKS_PER_SEQ: int,
    HEAD_SIZE: int,
    KV_GROUP_NUM: int,
):
    try:
        from vllm._C import ops as vllm_ops
    except ImportError:
        raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")

    dtype = torch.float16

    NUM_ATTN_HEADS = 16

    NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
    assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
    MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
    device = get_current_device()

    q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
        BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
    )

    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
        k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
    )

    vllm_k_cache, vllm_v_cache, _ = generate_caches_and_block_tables_vllm(
        k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
    )

    block_tables = block_tables.to(device=device)
    max_seq_len_across_batch = kv_seq_lengths.max().item()
    # Ceiling division: one KV split per cache block of the longest sequence in the batch.
    kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
    output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
    sm_scale = 1.0 / (HEAD_SIZE**0.5)

    mid_output = torch.empty(
        size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
    )
    mid_output_lse = torch.empty(
        size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
    )

    # Profile the vLLM paged attention v1 kernel and ColossalAI's flash_decoding_attention
    # kernel back to back inside the CUDA profiler range.
    start()
    alibi_slopes = None
    vllm_ops.paged_attention_v1(
        output,
        q.squeeze(2),
        vllm_k_cache,
        vllm_v_cache,
        NUM_KV_HEADS,
        sm_scale,
        block_tables,
        kv_seq_lengths,
        BLOCK_SIZE,
        max_seq_len_across_batch,
        alibi_slopes,
        "auto",
    )

    inference_ops.flash_decoding_attention(
        output,
        q.squeeze(2),
        k_cache,
        v_cache,
        kv_seq_lengths,
        block_tables,
        BLOCK_SIZE,
        max_seq_len_across_batch,
        mid_output,
        mid_output_lse,
        sm_scale,
    )

    stop()
    # ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)

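# Reference sketch only (plain PyTorch, not necessarily the kernel's exact semantics): in
# flash decoding, the per-split partials in `mid_output` / `mid_output_lse` above are
# typically reduced by re-weighting each split with exp(lse_split - logsumexp(lse_splits)).
def _combine_kv_splits_reference(mid_out: torch.Tensor, mid_lse: torch.Tensor) -> torch.Tensor:
    # mid_out: [batch, num_heads, num_splits, head_size]; mid_lse: [batch, num_heads, num_splits]
    lse_total = torch.logsumexp(mid_lse, dim=-1, keepdim=True)  # [batch, num_heads, 1]
    weights = torch.exp(mid_lse - lse_total).unsqueeze(-1)  # [batch, num_heads, num_splits, 1]
    return (weights * mid_out).sum(dim=-2)  # [batch, num_heads, head_size]
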
if __name__ == "__main__":
    benchmark_flash_decoding_attention("cuda_flash_decoding_attention", 16, 32, 8, 128, 2)