From 04aca9e55bd91ea4dd8d1231aa66df7848b08f03 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 1 Apr 2024 13:47:14 +0800 Subject: [PATCH] [Inference/Kernel]Add get_cos_and_sin Kernel (#5528) * Add get_cos_and_sin kernel * fix code comments * fix code typos * merge common codes of get_cos_and_sin kernel. * Fixed a typo * Changed 'assert allclose' to 'assert equal'. --- .../modeling/models/nopadding_llama.py | 18 +- .../csrc/cuda/get_cos_and_sin_kernel.cu | 215 ++++++++++++++++++ extensions/csrc/cuda/pybind/inference.cpp | 14 +- extensions/inference/inference_ops_cuda.py | 1 + .../test_ops/cuda/test_get_cos_and_sin.py | 53 +++++ 5 files changed, 295 insertions(+), 6 deletions(-) create mode 100644 extensions/csrc/cuda/get_cos_and_sin_kernel.cu create mode 100644 tests/test_infer/test_ops/cuda/test_get_cos_and_sin.py diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 37a714c8312c..c5b61385f822 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -101,12 +101,22 @@ def llama_model_forward( use_cuda_kernel = False hidden_states = self.embed_tokens(input_tokens_ids) - if use_cuda_kernel and inputmetadata != torch.float32 and use_flash_attn2: - cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) + if use_cuda_kernel: + if inputmetadata != torch.float32 and use_flash_attn2: + cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) + + hidden_dim = self._cos_cached.size(-1) + total_length = hidden_states.size(0) + cos = torch.empty((total_length, hidden_dim), dtype=self._cos_cached.dtype, device=self._cos_cached.device) + sin = torch.empty((total_length, hidden_dim), dtype=self._sin_cached.dtype, device=self._sin_cached.device) + inference_ops.get_cos_and_sin( + self._cos_cached, self._sin_cached, cos, sin, 
sequence_lengths, kv_seq_len, inputmetadata.is_prompts + ) + cos_sin = (cos, sin) + else: cu_seqlens = None - - cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts) + cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts) sm_scale = 1.0 / (inputmetadata.head_dim**0.5) diff --git a/extensions/csrc/cuda/get_cos_and_sin_kernel.cu b/extensions/csrc/cuda/get_cos_and_sin_kernel.cu new file mode 100644 index 000000000000..15aea740e6f9 --- /dev/null +++ b/extensions/csrc/cuda/get_cos_and_sin_kernel.cu @@ -0,0 +1,215 @@ +#include +#include + +#include "utils/vector_copy_utils.h" +#include "../common/micros.h" +#include "stdio.h" + +template +__device__ void apply_cos_and_sin_memcopy( + scalar_t* __restrict__ cos, + scalar_t* __restrict__ sin, + const scalar_t* __restrict__ cos_cache_ptr, + const scalar_t* __restrict__ sin_cache_ptr, + const int* __restrict__ sequence_lengths, + const int head_dim, + const int dest_offset_id, + const int src_offset_id + ) { + + int begin_id = threadIdx.x * VecSize; + + for (; begin_id <= head_dim - VecSize; begin_id += blockDim.x){ + copy_vector(cos + dest_offset_id + begin_id, cos_cache_ptr + src_offset_id + begin_id); + copy_vector(sin + dest_offset_id + begin_id, sin_cache_ptr + src_offset_id + begin_id); + } + + if (!Aligned) { + for (; begin_id < head_dim; ++begin_id ) { + cos[dest_offset_id + begin_id] = cos_cache_ptr[src_offset_id + begin_id]; + sin[dest_offset_id + begin_id] = sin_cache_ptr[src_offset_id + begin_id]; + } + } +} + +template +__global__ void apply_get_context_cos_and_sin_kernel( + scalar_t* __restrict__ cos, + scalar_t* __restrict__ sin, + const scalar_t* __restrict__ cos_cache_ptr, + const scalar_t* __restrict__ sin_cache_ptr, + const int* __restrict__ sequence_lengths, + const int* __restrict__ cumsum_lengths, + const int batch_size, + const int head_dim +) { + int token_id = blockIdx.x; + if ( token_id >= 
sequence_lengths[blockIdx.y] ) { + return ; + } + + int src_offset_id = token_id * head_dim; + int dest_offset_id = src_offset_id; + + if (blockIdx.y > 0) { + dest_offset_id += cumsum_lengths[blockIdx.y - 1] * head_dim; + } + + apply_cos_and_sin_memcopy( + cos, + sin, + cos_cache_ptr, + sin_cache_ptr, + sequence_lengths, + head_dim, + dest_offset_id, + src_offset_id + ); + +} + +template +__global__ void apply_get_decode_cos_and_sin_kernel( + scalar_t* __restrict__ cos, + scalar_t* __restrict__ sin, + const scalar_t* __restrict__ cos_cache_ptr, + const scalar_t* __restrict__ sin_cache_ptr, + const int* __restrict__ sequence_lengths, + const int batch_size, + const int head_dim +) { + int src_offset_id = ( sequence_lengths[blockIdx.y] - 1 ) * head_dim; + int dest_offset_id = blockIdx.y * head_dim; + + apply_cos_and_sin_memcopy( + cos, + sin, + cos_cache_ptr, + sin_cache_ptr, + sequence_lengths, + head_dim, + dest_offset_id, + src_offset_id + ); +} + +template +void apply_get_cos_and_sin( + at::Tensor& cos_cache, // [max_rotary_position, head_dim] + at::Tensor& sin_cache, // [max_rotary_position, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + int max_seq_len_in_batch, + bool is_prompts +) { + int token_num = cos.size(0); + int head_dim = cos.size(1); + int batch_size = sequence_lengths.size(0); + + at::Tensor cumsum_lengths; + + int vec_size = get_vec_size(cos); + + bool aligned = true; + if (head_dim % vec_size != 0) { + aligned = false; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int block_size_y; + int block_size_x; + + if (is_prompts) { + block_size_y = batch_size; + block_size_x = max_seq_len_in_batch; + // TODO: The cumsum operation can be fused into get_cos_and_sin kernel later on. 
+ cumsum_lengths = torch::cumsum(sequence_lengths, 0, torch::kInt32); + } + else{ + block_size_y = batch_size; + block_size_x = 1; + } + + int thread_nums = (head_dim + vec_size - 1) / vec_size; + + dim3 grid(block_size_x, block_size_y); + dim3 block(std::min(thread_nums, 512)); + +#define GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, __vec_size) \ + do { \ + if (is_prompts){ \ + apply_get_context_cos_and_sin_kernel<<>>( \ + cos.data_ptr(), \ + sin.data_ptr(), \ + cos_cache.data_ptr(), \ + sin_cache.data_ptr(), \ + sequence_lengths.data_ptr(), \ + cumsum_lengths.data_ptr(), \ + batch_size, \ + head_dim \ + ); \ + } \ + else { \ + apply_get_decode_cos_and_sin_kernel<<>>( \ + cos.data_ptr(), \ + sin.data_ptr(), \ + cos_cache.data_ptr(), \ + sin_cache.data_ptr(), \ + sequence_lengths.data_ptr(), \ + batch_size, \ + head_dim \ + ); \ + } \ + } while(0) + +#define GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \ + do { \ + switch (vec_size) { \ + case 1: \ + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 1); \ + break; \ + case 2: \ + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 2); \ + break; \ + case 4: \ + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 4); \ + break; \ + default: \ + AT_ERROR("Unsupported vectorized size ", vec_size); \ + break; \ + } \ + } while(0) + + if (aligned) { + GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(true); + } + else { + GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(false); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void get_cos_and_sin( + at::Tensor& cos_cache, // [max_rotary_position, head_dim] + at::Tensor& sin_cache, // [max_rotary_position, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + int max_seq_len_in_batch, + bool is_prompts +) { + DISPATCH_FLOAT_HALF_AND_BFLOAT( + cos.scalar_type(), + "get_cos_and_sin", + apply_get_cos_and_sin( + cos_cache, + sin_cache, + cos, + sin, + sequence_lengths, + max_seq_len_in_batch, + is_prompts + );) +} 
diff --git a/extensions/csrc/cuda/pybind/inference.cpp b/extensions/csrc/cuda/pybind/inference.cpp index 541146e3a60d..45745e6a3e29 100644 --- a/extensions/csrc/cuda/pybind/inference.cpp +++ b/extensions/csrc/cuda/pybind/inference.cpp @@ -51,6 +51,13 @@ void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] float epsilon); +void get_cos_and_sin(at::Tensor& cos_cache, // [max_rotary_position, head_dim] + at::Tensor& sin_cache, // [max_rotary_position, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + int max_seq_len_in_batch, bool is_prompts); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, "Copy the GPU memory of kvcache during the decode stage."); @@ -60,10 +67,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy, - "performing Rotary Embedding-related calculations and KVCache Memcopy."); + "Performing Rotary Embedding-related calculations and KVCache Memcopy."); m.def("rotary_embedding", &rotary_embedding, - "performing Rotary Embedding-related calculations."); + "Performing Rotary Embedding-related calculations."); m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply"); @@ -72,4 +79,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm, "In-place fused Add and RMS Normalization."); + + m.def("get_cos_and_sin", &get_cos_and_sin, + "Get cos and sin from the cache."); } diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py index 4e0afc819c51..09ebfdabde88 100644 --- a/extensions/inference/inference_ops_cuda.py +++ b/extensions/inference/inference_ops_cuda.py @@ -16,6 +16,7 @@ def sources_files(self): "cuda/fused_rotary_emb_and_cache_kernel.cu", "cuda/activation_kernel.cu", 
"cuda/rms_layernorm_kernel.cu", + "cuda/get_cos_and_sin_kernel.cu", ] ] return ret diff --git a/tests/test_infer/test_ops/cuda/test_get_cos_and_sin.py b/tests/test_infer/test_ops/cuda/test_get_cos_and_sin.py new file mode 100644 index 000000000000..c632cfe302e7 --- /dev/null +++ b/tests/test_infer/test_ops/cuda/test_get_cos_and_sin.py @@ -0,0 +1,53 @@ +import numpy as np +import pytest +import torch + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin + +inference_ops = InferenceOpsLoader().load() + + +def numpy_equal(x, y): + x_numpy = x.detach().cpu().numpy() + y_numpy = y.detach().cpu().numpy() + + np.testing.assert_equal(x_numpy, y_numpy) + + +@pytest.mark.parametrize("BATCH_SIZE", [4]) +@pytest.mark.parametrize("MAX_SEQ_LEN", [64]) +@pytest.mark.parametrize("HEAD_DIM", [64]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_get_cos_and_sin(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype): + MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN + cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda") + sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda") + lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda").to(torch.int32) + + max_seq_len_in_batch = lengths.max() + + # prefill + cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype) + + cos = torch.zeros_like(cos_ref) + sin = torch.zeros_like(sin_ref) + + inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, True) + + numpy_equal(cos, cos_ref) + numpy_equal(sin, sin_ref) + + # decoding + ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype) + + cos = torch.zeros_like(ncos_ref) + sin = torch.zeros_like(nsin_ref) + + inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, False) + 
numpy_equal(cos, ncos_ref) + numpy_equal(sin, nsin_ref) + + +if __name__ == "__main__": + test_get_cos_and_sin(16, 4096, 256, torch.float16)