The writing style of tail processing and the logic related to macro definitions have been optimized. (#5519)
isky-cd authored Mar 28, 2024
1 parent e6496dd commit 934e31a
Showing 5 changed files with 131 additions and 167 deletions.
2 changes: 1 addition & 1 deletion examples/inference/run_benchmark.sh
@@ -2,7 +2,7 @@ ROOT=$(realpath $(dirname $0))
echo $ROOT
PY_SCRIPT=${ROOT}/benchmark_llama.py
GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
-mode="colossalai"
+mode=$1

mkdir -p logs

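Note on the change above: the benchmark mode is now read from the script's first positional argument instead of being hard-coded to "colossalai". A typical invocation might look like this (illustrative only; the set of accepted mode names is defined by benchmark_llama.py, which is not shown in this diff):

    bash examples/inference/run_benchmark.sh colossalai   # the previously hard-coded mode
    bash examples/inference/run_benchmark.sh vllm         # hypothetical alternative mode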
23 changes: 8 additions & 15 deletions extensions/csrc/common/micros.h
@@ -56,21 +56,14 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}

-#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
-                                                           TYPE, NAME, ...) \
-  switch (HIGH_PRECISION) { \
-    case false: { \
-      const bool high_precision = false; \
-      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
-      break; \
-    } \
-    case true: { \
-      const bool high_precision = true; \
-      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
-      break; \
-    } \
-    default: \
-      AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
+                                                           TYPE, NAME, ...) \
+  if (HIGH_PRECISION) { \
+    const bool high_precision = true; \
+    DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
+  } else { \
+    const bool high_precision = false; \
+    DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
+  }
}

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
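Note on the hunk above: besides being shorter, the if/else removes the switch's unreachable default branch (a bool can only ever be true or false, so the AT_ERROR there could never fire). The rewrite preserves the essential trick: the runtime flag HIGH_PRECISION is re-materialized as a compile-time constant high_precision in each branch, so the dispatched body can use it as a template argument. A minimal, self-contained C++ sketch of that pattern (toy body only; the real DISPATCH_FLOAT_HALF_AND_BFLOAT is not reproduced here):

    #include <cstdio>

    // toy dispatched body: the flag arrives as a template parameter, so each
    // dispatcher branch instantiates a distinct, fully specialized function
    template <bool high_precision>
    void kernel_body() {
        std::printf("instantiated with high_precision = %s\n",
                    high_precision ? "true" : "false");
    }

    // mirrors the rewritten macro: the if/else turns a runtime bool into a
    // compile-time constant that the body may use as a template argument
    #define DISPATCH_WITH_HIGH_PRECISION(HIGH_PRECISION, ...) \
        if (HIGH_PRECISION) {                                 \
            const bool high_precision = true;                 \
            __VA_ARGS__;                                      \
        } else {                                              \
            const bool high_precision = false;                \
            __VA_ARGS__;                                      \
        }

    int main() {
        bool runtime_flag = true;  // value known only at runtime
        DISPATCH_WITH_HIGH_PRECISION(runtime_flag, kernel_body<high_precision>());
        return 0;
    }

The point of the sketch: kernel_body<high_precision> would not compile if high_precision were a plain runtime bool; the branch-local const bool is what makes the template instantiation legal.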
16 changes: 5 additions & 11 deletions extensions/csrc/common/mp_type_traits.h
@@ -27,17 +27,11 @@ struct MPTypeTrait<at::BFloat16> {
using Type = float;
};

-template <bool high_precision, typename scalar_t>
-struct ScalarTypeTrait;
-
-template <typename T>
-struct ScalarTypeTrait<true, T> {
-  using Type = typename MPTypeTrait<T>::Type;
-};
-
-template <typename T>
-struct ScalarTypeTrait<false, T> {
-  using Type = T;
+template <bool high_precision, typename T>
+struct ScalarTypeTrait {
+  using Type =
+      typename std::conditional<high_precision, typename MPTypeTrait<T>::Type,
+                                T>::type;
};

} // namespace common
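The single std::conditional definition above is equivalent to the two removed specializations. A quick compile-time check of that equivalence (a sketch with a toy MPTypeTrait, since the real at::Half / at::BFloat16 specializations require torch headers; short stands in for a half-precision type):

    #include <type_traits>

    // toy stand-in for the real MPTypeTrait
    template <typename T> struct MPTypeTrait { using Type = T; };
    template <> struct MPTypeTrait<short> { using Type = float; };

    template <bool high_precision, typename T>
    struct ScalarTypeTrait {
      using Type = typename std::conditional<
          high_precision, typename MPTypeTrait<T>::Type, T>::type;
    };

    // matches the removed ScalarTypeTrait<true, T> specialization ...
    static_assert(std::is_same<ScalarTypeTrait<true, short>::Type, float>::value,
                  "high_precision promotes through MPTypeTrait");
    // ... and the removed ScalarTypeTrait<false, T> specialization
    static_assert(std::is_same<ScalarTypeTrait<false, short>::Type, short>::value,
                  "low precision keeps the original type");

    int main() { return 0; }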
133 changes: 60 additions & 73 deletions extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu
@@ -4,7 +4,7 @@
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"

-template<typename scalar_t, int VecSize>
+template<typename scalar_t, bool Aligned, int VecSize>
__global__ void context_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
@@ -55,17 +55,19 @@ __global__ void context_kv_cache_memcpy_kernel(
}

// tail process
-    for (; i < hidden_size; ++i ) {
-        head_id = i / head_dim;
-        head_offset = i % head_dim;
-        key_src_id = total_token_id * key_stride + i;
-        value_src_id = total_token_id * value_stride + i;
-        target_id = block_id * hidden_size * block_size
-                    + head_id * block_size * head_dim
-                    + block_offset * head_dim + head_offset;
-
-        key_cache[target_id] = key[key_src_id];
-        value_cache[target_id] = value[value_src_id];
+    if (!Aligned) {
+        for (; i < hidden_size; ++i ) {
+            head_id = i / head_dim;
+            head_offset = i % head_dim;
+            key_src_id = total_token_id * key_stride + i;
+            value_src_id = total_token_id * value_stride + i;
+            target_id = block_id * hidden_size * block_size
+                        + head_id * block_size * head_dim
+                        + block_offset * head_dim + head_offset;
+
+            key_cache[target_id] = key[key_src_id];
+            value_cache[target_id] = value[value_src_id];
+        }
}

}
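Note on the hunk above: hoisting the alignment decision into a bool Aligned template parameter means the tail loop costs nothing in the common, aligned case: if (!Aligned) is constant-false in the <..., true, ...> instantiation and the compiler removes the branch entirely. A reduced, self-contained CUDA sketch of the same structure (a toy element-wise copy, not the actual cache-copy kernel; unlike the real launcher below, which also drops vec_size to 1 when unaligned, this variant keeps the vector width and mops up the remainder in the tail):

    #include <cuda_runtime.h>

    template <bool Aligned, int VecSize>
    __global__ void copy_kernel(const float* __restrict__ src,
                                float* __restrict__ dst, int n) {
        // vectorized main body: each thread moves VecSize contiguous elements
        int i = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
        if (i + VecSize <= n) {
    #pragma unroll
            for (int j = 0; j < VecSize; ++j) dst[i + j] = src[i + j];
        }
        // tail: Aligned is a compile-time constant, so the <true, VecSize>
        // instantiation contains no trace of this branch after compilation
        if (!Aligned) {
            if (blockIdx.x == 0 && threadIdx.x == 0) {
                for (int k = (n / VecSize) * VecSize; k < n; ++k) dst[k] = src[k];
            }
        }
    }

    // host side: decide once, at launch time, which instantiation to run
    void launch_copy(const float* src, float* dst, int n, cudaStream_t stream) {
        constexpr int vec_size = 4;
        int threads = 256;
        int blocks = (n / vec_size + threads - 1) / threads;
        if (blocks == 0) blocks = 1;
        if (n % vec_size == 0) {
            copy_kernel<true, vec_size><<<blocks, threads, 0, stream>>>(src, dst, n);
        } else {
            copy_kernel<false, vec_size><<<blocks, threads, 0, stream>>>(src, dst, n);
        }
    }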
@@ -93,76 +95,61 @@ void apply_context_kv_cache_memcpy(

int vec_size = get_vec_size<scalar_t>(key);

+    bool aligned = true;
    if (head_dim % vec_size != 0) {
        // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
        vec_size = 1;
+        aligned = false;
    }

int thread_nums = head_num * head_dim / vec_size;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(max_seq_len_in_batch, batch_size);
dim3 block(std::min(thread_nums, 512));

-    switch (vec_size) {
-        case 1:
-            context_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                cu_seqlens.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                batch_size,
-                block_table_stride,
-                key_stride,
-                value_stride
-            );
-            break;
-        case 2:
-            context_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                cu_seqlens.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                batch_size,
-                block_table_stride,
-                key_stride,
-                value_stride
-            );
-            break;
-        case 4:
-            context_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                cu_seqlens.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                batch_size,
-                block_table_stride,
-                key_stride,
-                value_stride
-            );
-            break;
-        default:
-            AT_ERROR("Unsupported vectorized size ", vec_size);
-            break;
+#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size)                                \
+    do {                                                                                              \
+        context_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+            key.data_ptr<scalar_t>(),                                                                 \
+            value.data_ptr<scalar_t>(),                                                               \
+            key_cache.data_ptr<scalar_t>(),                                                           \
+            value_cache.data_ptr<scalar_t>(),                                                         \
+            sequence_lengths.data_ptr<int>(),                                                         \
+            cu_seqlens.data_ptr<int>(),                                                               \
+            block_tables.data_ptr<int>(),                                                             \
+            head_num,                                                                                 \
+            head_dim,                                                                                 \
+            block_size,                                                                               \
+            batch_size,                                                                               \
+            block_table_stride,                                                                       \
+            key_stride,                                                                               \
+            value_stride                                                                              \
+        );                                                                                            \
+    } while(0)
+
+#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned)                              \
+    do {                                                                                              \
+        switch (vec_size) {                                                                           \
+            case 1:                                                                                   \
+                CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1);                                 \
+                break;                                                                                \
+            case 2:                                                                                   \
+                CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2);                                 \
+                break;                                                                                \
+            case 4:                                                                                   \
+                CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4);                                 \
+                break;                                                                                \
+            default:                                                                                  \
+                AT_ERROR("Unsupported vectorized size ", vec_size);                                   \
+                break;                                                                                \
+        }                                                                                             \
+    } while(0)
+
+
+    if (aligned) {
+        CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true);
+    }
+    else {
+        CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false);
+    }

AT_CUDA_CHECK(cudaGetLastError());
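Note on the launcher above: get_vec_size (declared in utils/vector_copy_utils.h, which this diff does not show) supplies the preferred vector width, and the guard at the top of the function falls back to scalar copies (vec_size = 1) whenever head_dim is not a multiple of that width. As a loudly labeled assumption rather than the actual implementation, a helper in the same spirit might pick the widest load of at most 16 bytes that the element type and pointer alignment allow:

    #include <cstdint>

    // hypothetical sketch of a get_vec_size-style helper, not the real one
    template <typename T>
    int get_vec_size_sketch(const T* ptr) {
        // widest element count such that one vector load moves <= 16 bytes
        int vec = 16 / static_cast<int>(sizeof(T));  // e.g. 4 for float
        // shrink until the address is aligned to the full vector width
        std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(ptr);
        while (vec > 1 && addr % (vec * sizeof(T)) != 0) vec /= 2;
        return vec;
    }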
124 changes: 57 additions & 67 deletions extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
@@ -4,7 +4,7 @@
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"

-template<typename scalar_t, int VecSize>
+template<typename scalar_t, bool Aligned, int VecSize>
__global__ void decode_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
@@ -45,17 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel(
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}

-    for (; i < hidden_size; ++i ) {
-        const int head_id = i / head_dim;
-        const int head_offset = i % head_dim;
-        const int64_t key_src_id = seq_id * key_stride + i;
-        const int64_t value_src_id = seq_id * value_stride + i;
-        const int64_t target_id = block_id * hidden_size * block_size
-                                  + head_id * block_size * head_dim
-                                  + block_offset * head_dim + head_offset;
-
-        key_cache[target_id] = key[key_src_id];
-        value_cache[target_id] = value[value_src_id];
+    if (!Aligned) {
+        for (; i < hidden_size; ++i ) {
+            const int head_id = i / head_dim;
+            const int head_offset = i % head_dim;
+            const int64_t key_src_id = seq_id * key_stride + i;
+            const int64_t value_src_id = seq_id * value_stride + i;
+            const int64_t target_id = block_id * hidden_size * block_size
+                                      + head_id * block_size * head_dim
+                                      + block_offset * head_dim + head_offset;
+
+            key_cache[target_id] = key[key_src_id];
+            value_cache[target_id] = value[value_src_id];
+        }
}

}
@@ -80,70 +82,58 @@ void apply_decode_kv_cache_memcpy(

int vec_size = get_vec_size<scalar_t>(key);

+    bool aligned = true;
    if (head_dim % vec_size != 0) {
        // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
        vec_size = 1;
+        aligned = false;
    }

int thread_nums = head_num * head_dim / vec_size;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(num_tokens);
dim3 block(std::min(thread_nums, 512));

-    switch (vec_size) {
-        case 1:
-            decode_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                key_stride,
-                value_stride,
-                block_table_stride
-            );
-            break;
-        case 2:
-            decode_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                key_stride,
-                value_stride,
-                block_table_stride
-            );
-            break;
-        case 4:
-            decode_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                key_stride,
-                value_stride,
-                block_table_stride
-            );
-            break;
-        default:
-            AT_ERROR("Unsupported vectorized size ", vec_size);
-            break;
+#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size)                                \
+    do {                                                                                             \
+        decode_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+            key.data_ptr<scalar_t>(),                                                                \
+            value.data_ptr<scalar_t>(),                                                              \
+            key_cache.data_ptr<scalar_t>(),                                                          \
+            value_cache.data_ptr<scalar_t>(),                                                        \
+            sequence_lengths.data_ptr<int>(),                                                        \
+            block_tables.data_ptr<int>(),                                                            \
+            head_num,                                                                                \
+            head_dim,                                                                                \
+            block_size,                                                                              \
+            key_stride,                                                                              \
+            value_stride,                                                                            \
+            block_table_stride                                                                       \
+        );                                                                                           \
+    } while(0)
+
+#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned, __vec_size)                  \
+    do {                                                                                             \
+        switch (__vec_size) {                                                                        \
+            case 1:                                                                                  \
+                DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1);                                 \
+                break;                                                                               \
+            case 2:                                                                                  \
+                DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2);                                 \
+                break;                                                                               \
+            case 4:                                                                                  \
+                DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4);                                 \
+                break;                                                                               \
+            default:                                                                                 \
+                AT_ERROR("Unsupported vectorized size ", __vec_size);                                \
+                break;                                                                               \
+        }                                                                                            \
+    } while(0)
+
+    if (aligned) {
+        DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true, vec_size);
+    }
+    else {
+        DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false, vec_size);
+    }

AT_CUDA_CHECK(cudaGetLastError());
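A last note on the new launch macros in both files: wrapping each body in do { ... } while(0) is the standard way to make a multi-statement macro behave as a single statement, so it composes safely with if/else at any call site, including the aligned/unaligned dispatch above. A minimal C++ illustration (toy printf body, not the launch code):

    #include <cstdio>

    // without the do/while, only the first statement would be guarded by an
    // unbraced `if`, and a following `else` would fail to parse
    #define LAUNCH_UNSAFE(x) std::printf("step 1: %d\n", x); std::printf("step 2: %d\n", x)
    // the do { ... } while (0) wrapper turns the body into one statement
    #define LAUNCH_SAFE(x) do { std::printf("step 1: %d\n", x); std::printf("step 2: %d\n", x); } while (0)

    int main() {
        bool aligned = false;
        if (aligned)
            LAUNCH_SAFE(1);   // swapping in LAUNCH_UNSAFE here would not compile
        else
            LAUNCH_SAFE(0);
        return 0;
    }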
