diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh
index 4b4f9715ce14..4b015757ef0d 100755
--- a/examples/inference/run_benchmark.sh
+++ b/examples/inference/run_benchmark.sh
@@ -2,7 +2,7 @@ ROOT=$(realpath $(dirname $0))
 echo $ROOT
 PY_SCRIPT=${ROOT}/benchmark_llama.py
 GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
-mode="colossalai"
+mode=$1
 
 mkdir -p logs
 
diff --git a/extensions/csrc/common/micros.h b/extensions/csrc/common/micros.h
index 5400a6dc1951..12cd78046b6a 100644
--- a/extensions/csrc/common/micros.h
+++ b/extensions/csrc/common/micros.h
@@ -56,21 +56,14 @@
       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");        \
   }
 
-#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION,   \
-                                                           TYPE, NAME, ...)  \
-  switch (HIGH_PRECISION) {                                                  \
-    case false: {                                                            \
-      const bool high_precision = false;                                     \
-      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);               \
-      break;                                                                 \
-    }                                                                        \
-    case true: {                                                             \
-      const bool high_precision = true;                                      \
-      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);               \
-      break;                                                                 \
-    }                                                                        \
-    default:                                                                 \
-      AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION,   \
+                                                           TYPE, NAME, ...)  \
+  if (HIGH_PRECISION) {                                                      \
+    const bool high_precision = true;                                        \
+    DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);                 \
+  } else {                                                                   \
+    const bool high_precision = false;                                       \
+    DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);                 \
   }
 
 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
diff --git a/extensions/csrc/common/mp_type_traits.h b/extensions/csrc/common/mp_type_traits.h
index 77de7c12a97d..5275732194ab 100644
--- a/extensions/csrc/common/mp_type_traits.h
+++ b/extensions/csrc/common/mp_type_traits.h
@@ -27,17 +27,11 @@ struct MPTypeTrait<at::BFloat16> {
   using Type = float;
 };
 
-template <bool high_precision, typename T>
-struct ScalarTypeTrait;
-
-template <typename T>
-struct ScalarTypeTrait<true, T> {
-  using Type = typename MPTypeTrait<T>::Type;
-};
-
-template <typename T>
-struct ScalarTypeTrait<false, T> {
-  using Type = T;
+template <bool high_precision, typename T>
+struct ScalarTypeTrait {
+  using Type =
+      typename std::conditional<high_precision, typename MPTypeTrait<T>::Type,
+                                T>::type;
 };
 
 }  // namespace common
diff --git a/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu
index 3f6adc018b41..3300fad47796 100644
--- a/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu
+++ b/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu
@@ -4,7 +4,7 @@
 #include "utils/vector_copy_utils.h"
 #include "../common/micros.h"
 
-template <typename scalar_t, int VecSize>
+template <typename scalar_t, bool Aligned, int VecSize>
 __global__ void context_kv_cache_memcpy_kernel(
     const scalar_t* __restrict__ key,
     const scalar_t* __restrict__ value,
@@ -55,17 +55,19 @@ __global__ void context_kv_cache_memcpy_kernel(
   }
 
   // tail process
-  for (; i < hidden_size; ++i ) {
-    head_id = i / head_dim;
-    head_offset = i % head_dim;
-    key_src_id = total_token_id * key_stride + i;
-    value_src_id = total_token_id * value_stride + i;
-    target_id = block_id * hidden_size * block_size
-                + head_id * block_size * head_dim
-                + block_offset * head_dim + head_offset;
-
-    key_cache[target_id] = key[key_src_id];
-    value_cache[target_id] = value[value_src_id];
+  if (!Aligned) {
+    for (; i < hidden_size; ++i ) {
+      head_id = i / head_dim;
+      head_offset = i % head_dim;
+      key_src_id = total_token_id * key_stride + i;
+      value_src_id = total_token_id * value_stride + i;
+      target_id = block_id * hidden_size * block_size
+                  + head_id * block_size * head_dim
+                  + block_offset * head_dim + head_offset;
+
+      key_cache[target_id] = key[key_src_id];
+      value_cache[target_id] = value[value_src_id];
+    }
   }
 }
 
@@ -93,76 +95,61 @@ void apply_context_kv_cache_memcpy(
 
   int vec_size = get_vec_size<scalar_t>(key);
 
+  bool aligned = true;
   if (head_dim % vec_size != 0) {
-    // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
-    vec_size = 1;
+    aligned = false;
   }
 
   int thread_nums = head_num * head_dim / vec_size;
-
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(max_seq_len_in_batch, batch_size);
   dim3 block(std::min(thread_nums, 512));
-  switch (vec_size) {
-    case 1:
-      context_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
-          key.data_ptr<scalar_t>(),
-          value.data_ptr<scalar_t>(),
-          key_cache.data_ptr<scalar_t>(),
-          value_cache.data_ptr<scalar_t>(),
-          sequence_lengths.data_ptr<int>(),
-          cu_seqlens.data_ptr<int>(),
-          block_tables.data_ptr<int>(),
-          head_num,
-          head_dim,
-          block_size,
-          batch_size,
-          block_table_stride,
-          key_stride,
-          value_stride
-      );
-      break;
-    case 2:
-      context_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
-          key.data_ptr<scalar_t>(),
-          value.data_ptr<scalar_t>(),
-          key_cache.data_ptr<scalar_t>(),
-          value_cache.data_ptr<scalar_t>(),
-          sequence_lengths.data_ptr<int>(),
-          cu_seqlens.data_ptr<int>(),
-          block_tables.data_ptr<int>(),
-          head_num,
-          head_dim,
-          block_size,
-          batch_size,
-          block_table_stride,
-          key_stride,
-          value_stride
-      );
-      break;
-    case 4:
-      context_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
-          key.data_ptr<scalar_t>(),
-          value.data_ptr<scalar_t>(),
-          key_cache.data_ptr<scalar_t>(),
-          value_cache.data_ptr<scalar_t>(),
-          sequence_lengths.data_ptr<int>(),
-          cu_seqlens.data_ptr<int>(),
-          block_tables.data_ptr<int>(),
-          head_num,
-          head_dim,
-          block_size,
-          batch_size,
-          block_table_stride,
-          key_stride,
-          value_stride
-      );
-      break;
-    default:
-      AT_ERROR("Unsupported vectorized size ", vec_size);
-      break;
+#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size)                            \
+  do {                                                                                           \
+    context_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+        key.data_ptr<scalar_t>(),                                                                \
+        value.data_ptr<scalar_t>(),                                                              \
+        key_cache.data_ptr<scalar_t>(),                                                          \
+        value_cache.data_ptr<scalar_t>(),                                                        \
+        sequence_lengths.data_ptr<int>(),                                                        \
+        cu_seqlens.data_ptr<int>(),                                                              \
+        block_tables.data_ptr<int>(),                                                            \
+        head_num,                                                                                \
+        head_dim,                                                                                \
+        block_size,                                                                              \
+        batch_size,                                                                              \
+        block_table_stride,                                                                      \
+        key_stride,                                                                              \
+        value_stride                                                                             \
+    );                                                                                           \
+  } while(0)
+
+#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned)                          \
+  do {                                                                                           \
+    switch (vec_size) {                                                                          \
+      case 1:                                                                                    \
+        CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1);                                    \
+        break;                                                                                   \
+      case 2:                                                                                    \
+        CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2);                                    \
+        break;                                                                                   \
+      case 4:                                                                                    \
+        CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4);                                    \
+        break;                                                                                   \
+      default:                                                                                   \
+        AT_ERROR("Unsupported vectorized size ", vec_size);                                      \
+        break;                                                                                   \
+    }                                                                                            \
+  } while(0)
+
+
+  if (aligned) {
+    CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true);
+  }
+  else {
+    CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false);
   }
 
   AT_CUDA_CHECK(cudaGetLastError());
diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
index 08889b23636c..3fcceac6b942 100644
--- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
+++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
@@ -4,7 +4,7 @@
 #include "utils/vector_copy_utils.h"
 #include "../common/micros.h"
 
-template <typename scalar_t, int VecSize>
+template <typename scalar_t, bool Aligned, int VecSize>
 __global__ void decode_kv_cache_memcpy_kernel(
     const scalar_t* __restrict__ key,
     const scalar_t* __restrict__ value,
@@ -45,17 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel(
     copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
   }
 
-  for (; i < hidden_size; ++i ) {
-    const int head_id = i / head_dim;
-    const int head_offset = i % head_dim;
-    const int64_t key_src_id = seq_id * key_stride + i;
-    const int64_t value_src_id = seq_id * value_stride + i;
-    const int64_t target_id = block_id * hidden_size * block_size
-                + head_id * block_size * head_dim
-                + block_offset * head_dim + head_offset;
-
-    key_cache[target_id] = key[key_src_id];
-    value_cache[target_id] = value[value_src_id];
+  if (!Aligned) {
+    for (; i < hidden_size; ++i ) {
+      const int head_id = i / head_dim;
+      const int head_offset = i % head_dim;
+      const int64_t key_src_id = seq_id * key_stride + i;
+      const int64_t value_src_id = seq_id * value_stride + i;
+      const int64_t target_id = block_id * hidden_size * block_size
+                  + head_id * block_size * head_dim
+                  + block_offset * head_dim + head_offset;
+
+      key_cache[target_id] = key[key_src_id];
+      value_cache[target_id] = value[value_src_id];
+    }
   }
 }
 
@@ -80,70 +82,58 @@ void apply_decode_kv_cache_memcpy(
 
   int vec_size = get_vec_size<scalar_t>(key);
 
+  bool aligned = true;
   if (head_dim % vec_size != 0) {
-    // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
-    vec_size = 1;
+    aligned = false;
   }
 
   int thread_nums = head_num * head_dim / vec_size;
-
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(num_tokens);
   dim3 block(std::min(thread_nums, 512));
-  switch (vec_size) {
-    case 1:
-      decode_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
-          key.data_ptr<scalar_t>(),
-          value.data_ptr<scalar_t>(),
-          key_cache.data_ptr<scalar_t>(),
-          value_cache.data_ptr<scalar_t>(),
-          sequence_lengths.data_ptr<int>(),
-          block_tables.data_ptr<int>(),
-          head_num,
-          head_dim,
-          block_size,
-          key_stride,
-          value_stride,
-          block_table_stride
-      );
-      break;
-    case 2:
-      decode_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
-          key.data_ptr<scalar_t>(),
-          value.data_ptr<scalar_t>(),
-          key_cache.data_ptr<scalar_t>(),
-          value_cache.data_ptr<scalar_t>(),
-          sequence_lengths.data_ptr<int>(),
-          block_tables.data_ptr<int>(),
-          head_num,
-          head_dim,
-          block_size,
-          key_stride,
-          value_stride,
-          block_table_stride
-      );
-      break;
-    case 4:
-      decode_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
-          key.data_ptr<scalar_t>(),
-          value.data_ptr<scalar_t>(),
-          key_cache.data_ptr<scalar_t>(),
-          value_cache.data_ptr<scalar_t>(),
-          sequence_lengths.data_ptr<int>(),
-          block_tables.data_ptr<int>(),
-          head_num,
-          head_dim,
-          block_size,
-          key_stride,
-          value_stride,
-          block_table_stride
-      );
-      break;
-    default:
-      AT_ERROR("Unsupported vectorized size ", vec_size);
-      break;
+#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size)                            \
+  do {                                                                                          \
+    decode_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+        key.data_ptr<scalar_t>(),                                                               \
+        value.data_ptr<scalar_t>(),                                                             \
+        key_cache.data_ptr<scalar_t>(),                                                         \
+        value_cache.data_ptr<scalar_t>(),                                                       \
+        sequence_lengths.data_ptr<int>(),                                                       \
+        block_tables.data_ptr<int>(),                                                           \
+        head_num,                                                                               \
+        head_dim,                                                                               \
+        block_size,                                                                             \
+        key_stride,                                                                             \
+        value_stride,                                                                           \
+        block_table_stride                                                                      \
+    );                                                                                          \
+  } while(0)
+
+#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned, __vec_size)              \
+  do {                                                                                          \
+    switch (__vec_size) {                                                                       \
+      case 1:                                                                                   \
+        DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1);                                    \
+        break;                                                                                  \
+      case 2:                                                                                   \
+        DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2);                                    \
+        break;                                                                                  \
+      case 4:                                                                                   \
+        DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4);                                    \
+        break;                                                                                  \
+      default:                                                                                  \
+        AT_ERROR("Unsupported vectorized size ", __vec_size);                                   \
+        break;                                                                                  \
+    }                                                                                           \
+  } while(0)
+
+  if (aligned) {
+    DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true, vec_size);
+  }
+  else {
+    DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false, vec_size);
   }
 
   AT_CUDA_CHECK(cudaGetLastError());
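
Usage note: with the hardcoded mode="colossalai" replaced by mode=$1, run_benchmark.sh now takes the benchmark mode as its first positional argument, so an invocation would presumably look like "bash run_benchmark.sh colossalai" (the previously hardcoded value).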
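
The micros.h change replaces a switch over a runtime bool (where the default branch is unreachable and switching on a bool tends to draw compiler warnings) with a plain if/else; each branch pins the flag into a const bool so the macro body can use it as a template argument. A minimal standalone sketch of the same trick, where DISPATCH_BOOL and kernel_impl are illustrative names, not part of the PR:

// Standalone sketch of the bool-dispatch pattern from micros.h above.
#include <iostream>

// Stand-in for a kernel that wants the flag as a compile-time constant.
template <bool high_precision>
void kernel_impl() {
  std::cout << "high_precision = " << std::boolalpha << high_precision << "\n";
}

// Each branch pins the runtime flag into a `const bool`, which is a constant
// expression and therefore usable as a template argument in the macro body --
// the same trick DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION plays.
#define DISPATCH_BOOL(FLAG, ...)        \
  if (FLAG) {                           \
    const bool high_precision = true;   \
    __VA_ARGS__;                        \
  } else {                              \
    const bool high_precision = false;  \
    __VA_ARGS__;                        \
  }

int main() {
  bool runtime_flag = true;  // value only known at runtime
  DISPATCH_BOOL(runtime_flag, kernel_impl<high_precision>());
  return 0;
}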
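
The mp_type_traits.h change collapses a forward declaration plus two partial specializations of ScalarTypeTrait into one primary template built on std::conditional. A compile-checkable sketch, assuming a hypothetical half_t stand-in for at::Half and a trimmed-down MPTypeTrait:

// Sketch of the collapsed ScalarTypeTrait from mp_type_traits.h above.
#include <type_traits>

struct half_t {};  // hypothetical stand-in for at::Half

template <typename T>
struct MPTypeTrait {
  using Type = float;  // low-precision types accumulate in float
};

template <>
struct MPTypeTrait<float> {
  using Type = float;
};

// One primary template replaces the forward declaration plus two partial
// specializations: std::conditional selects the math type at compile time.
template <bool high_precision, typename T>
struct ScalarTypeTrait {
  using Type =
      typename std::conditional<high_precision, typename MPTypeTrait<T>::Type,
                                T>::type;
};

static_assert(std::is_same<ScalarTypeTrait<true, half_t>::Type, float>::value,
              "high_precision promotes to the MP type");
static_assert(std::is_same<ScalarTypeTrait<false, half_t>::Type, half_t>::value,
              "otherwise the scalar type passes through unchanged");

int main() { return 0; }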
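
The two kernel files make the same functional change: when head_dim is not divisible by vec_size, the old code fell back to vec_size = 1 for the entire copy, while the new code keeps the vectorized main loop and selects an Aligned=false instantiation whose scalar tail loop handles the remainder. A host-only C++ sketch of that two-level (aligned x vec_size) dispatch, with copy_rows standing in for the CUDA kernel (all names here are illustrative):

// Host-only sketch of the dispatch performed by the kernel-launch macros:
// a runtime (aligned, vec_size) pair maps onto one of six instantiations.
#include <cstdio>

template <bool Aligned, int VecSize>
void copy_rows(const float* src, float* dst, int n) {
  int i = 0;
  // vectorized main loop: VecSize elements per step
  for (; i + VecSize <= n; i += VecSize)
    for (int j = 0; j < VecSize; ++j) dst[i + j] = src[i + j];
  // scalar tail loop, compiled in only for the unaligned instantiation
  if (!Aligned)
    for (; i < n; ++i) dst[i] = src[i];
}

template <bool Aligned>
void dispatch_vec_size(int vec_size, const float* src, float* dst, int n) {
  switch (vec_size) {
    case 1: copy_rows<Aligned, 1>(src, dst, n); break;
    case 2: copy_rows<Aligned, 2>(src, dst, n); break;
    case 4: copy_rows<Aligned, 4>(src, dst, n); break;
    default: std::fprintf(stderr, "unsupported vec_size %d\n", vec_size);
  }
}

int main() {
  float src[6] = {1, 2, 3, 4, 5, 6}, dst[6] = {0};
  int head_dim = 6, vec_size = 4;  // 6 % 4 != 0 -> unaligned path
  bool aligned = (head_dim % vec_size == 0);
  if (aligned) dispatch_vec_size<true>(vec_size, src, dst, head_dim);
  else         dispatch_vec_size<false>(vec_size, src, dst, head_dim);
  for (float v : dst) std::printf("%g ", v);  // prints 1 2 3 4 5 6
  std::printf("\n");
  return 0;
}

Because Aligned is a template parameter rather than a runtime branch, the aligned instantiations contain no tail loop at all, and the unaligned ones keep the wide main loop instead of degrading the whole copy to scalar accesses.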