support v100 csrc
root committed Dec 7, 2023
1 parent 5eb2eae commit 6d92e6c
Showing 11 changed files with 34 additions and 7 deletions.
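The change is the same in every file: bfloat16 support depends on <cuda_bf16.h> and the __nv_bfloat16 type, which ship only with CUDA 11.0 and later, so each bfloat16 include, trait specialization, and BFLOAT16 dispatch case is wrapped in an #if CUDA_VERSION >= 11000 guard. This presumably lets the csrc extensions still compile on the pre-11.0 CUDA toolkits commonly paired with V100 machines, with the bfloat16 paths simply compiled out. A minimal sketch of the guard pattern follows; Example and DispatchExample are illustrative names, not from the repository — only the #if/#endif guards mirror the actual diff:

    // Minimal sketch of the guard pattern this commit applies (illustrative names).
    #include "paddle/extension.h"
    #include <cuda.h>       // defines CUDA_VERSION
    #include <cuda_fp16.h>
    #if CUDA_VERSION >= 11000
    #include <cuda_bf16.h>  // __nv_bfloat16 exists only from CUDA 11.0
    #endif

    template <paddle::DataType D>
    std::vector<paddle::Tensor> DispatchExample(const paddle::Tensor& x);

    std::vector<paddle::Tensor> Example(const paddle::Tensor& x) {
      switch (x.type()) {
    #if CUDA_VERSION >= 11000
        case paddle::DataType::BFLOAT16:  // compiled only on CUDA 11+
          return DispatchExample<paddle::DataType::BFLOAT16>(x);
    #endif
        case paddle::DataType::FLOAT16:
          return DispatchExample<paddle::DataType::FLOAT16>(x);
        default:
          return DispatchExample<paddle::DataType::FLOAT32>(x);
      }
    }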
6 changes: 5 additions & 1 deletion csrc/generation/dequant_int8.cu
@@ -116,9 +116,11 @@ std::vector<paddle::Tensor> LaunchDequantInt8(const paddle::Tensor& input,
            "Only bfloat16, float16 and float32 are supported. ");
 
   switch (data_type) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16:
       return DispatchLaunchDequantInt8<paddle::DataType::BFLOAT16>(input, scale);
       break;
+#endif
     case paddle::DataType::FLOAT16:
       return DispatchLaunchDequantInt8<paddle::DataType::FLOAT16>(input, scale);
       break;
@@ -145,8 +147,10 @@ std::vector<paddle::DataType> DequantInt8Dtype(const paddle::DataType& input_dty
   paddle::DataType data_type;
   if (dtype == "float32")
     data_type = paddle::DataType::FLOAT32;
+#if CUDA_VERSION >= 11000
   else if (dtype == "bfloat16")
     data_type = paddle::DataType::BFLOAT16;
+#endif
   else if (dtype == "float16")
     data_type = paddle::DataType::FLOAT16;
   else
@@ -163,4 +167,4 @@ PD_BUILD_OP(dequant_int8)
     .Attrs({"dtype: std::string"})
     .SetKernelFn(PD_KERNEL(DequantInt8))
     .SetInferShapeFn(PD_INFER_SHAPE(DequantInt8Shape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(DequantInt8Dtype));
\ No newline at end of file
+    .SetInferDtypeFn(PD_INFER_DTYPE(DequantInt8Dtype));
2 changes: 2 additions & 0 deletions csrc/generation/encode_rotary_qk.cu
@@ -193,11 +193,13 @@ void RotaryQK(const paddle::Tensor& q,
               const int32_t rotary_emb_dims,
               bool use_neox) {
   switch (q.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return LaunchRotaryQK<paddle::DataType::BFLOAT16>(
         q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return LaunchRotaryQK<paddle::DataType::FLOAT16>(
         q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox
6 changes: 3 additions & 3 deletions csrc/generation/helper.h
@@ -15,7 +15,6 @@
 #pragma once
 
 #include "paddle/extension.h"
-#include <cub/cub.cuh>
 #include <curand_kernel.h>
 
 constexpr int kBlockSize = 256;
@@ -70,13 +69,14 @@ class PDTraits<paddle::DataType::FLOAT16> {
   typedef half DataType;
   typedef paddle::float16 data_t;
 };
-
+#if CUDA_VERSION >= 11000
 template <>
 class PDTraits<paddle::DataType::BFLOAT16> {
 public:
   typedef __nv_bfloat16 DataType;
   typedef paddle::bfloat16 data_t;
 };
+#endif
 
 template <typename T, int Size>
 struct alignas(sizeof(T) * Size) AlignedVector {
@@ -100,4 +100,4 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
   *addr_vec = vec;
 }
 
-constexpr int VEC_16B = 16;
\ No newline at end of file
+constexpr int VEC_16B = 16;
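PDTraits is the bridge between paddle::DataType enum values and the device- and framework-side element types the kernels use; since the BFLOAT16 specialization above exists only under the guard, every template instantiation with paddle::DataType::BFLOAT16 must be guarded at the call site too, which is exactly what the switch statements throughout this commit do. A sketch of how a launcher typically consumes these traits — LaunchExample is a hypothetical name and its body is illustrative, not from this diff:

    // Hypothetical launcher consuming PDTraits (illustrative, not in the diff).
    template <paddle::DataType D>
    void LaunchExample(const paddle::Tensor& x) {
      // Device-side element type: half, __nv_bfloat16 (CUDA 11+), or float.
      using DataType_ = typename PDTraits<D>::DataType;
      // Framework-side element type used to read the tensor buffer.
      using data_t = typename PDTraits<D>::data_t;
      const DataType_* ptr =
          reinterpret_cast<const DataType_*>(x.data<data_t>());
      // ... set up grid/block dims and launch a __global__ kernel on ptr ...
    }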
2 changes: 2 additions & 0 deletions csrc/generation/qkv_transpose_split.cu
@@ -135,6 +135,7 @@ std::vector<paddle::Tensor> QKVTransposeSplit(const paddle::Tensor& qkv,
                                               int num_head,
                                               int head_size) {
   switch (qkv.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return qkv_transpose_split<paddle::DataType::BFLOAT16>(
         qkv,
@@ -145,6 +146,7 @@ std::vector<paddle::Tensor> QKVTransposeSplit(const paddle::Tensor& qkv,
         head_size
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return qkv_transpose_split<paddle::DataType::FLOAT16>(
         qkv,
10 changes: 7 additions & 3 deletions csrc/generation/quant_int8.cu
@@ -23,8 +23,9 @@
 #include<stdio.h>
 #include<algorithm>
 #include<cuda_fp16.h>
+#if CUDA_VERSION >= 11000
 #include<cuda_bf16.h>
-
+#endif
 
 constexpr int DequantKernelVecSize = 4;
 
@@ -52,11 +53,12 @@ __forceinline__ __device__ half add_mul<half>(half a, half b, half c) {
   return __hmul(__hadd(a, b), c);
 }
 
+#if CUDA_VERSION >= 11000
 template<>
 __forceinline__ __device__ __nv_bfloat16 add_mul<__nv_bfloat16>(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
   return __hmul(__hadd(a, b), c);
 }
-
+#endif
 
 
 template <typename data_t>
@@ -210,11 +212,13 @@ std::vector<paddle::Tensor> QuantInt8(const paddle::Tensor& input,
                                       float min_bound) {
   // printf("#### quant int8 scale:%f \n",scale);
   switch (input.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return LaunchQuantInt8<paddle::DataType::BFLOAT16>(
         input, shift, smooth, scale, round_type, max_bound, min_bound
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return LaunchQuantInt8<paddle::DataType::FLOAT16>(
         input, shift, smooth, scale, round_type, max_bound, min_bound
@@ -256,4 +260,4 @@ PD_BUILD_OP(quant_int8)
     .Attrs({"scale: float","round_type: int","max_bound: float", "min_bound: float"})
     .SetKernelFn(PD_KERNEL(QuantInt8))
     .SetInferShapeFn(PD_INFER_SHAPE(QuantInt8Shape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(QuantInt8Dtype));
\ No newline at end of file
+    .SetInferDtypeFn(PD_INFER_DTYPE(QuantInt8Dtype));
2 changes: 2 additions & 0 deletions csrc/generation/rebuild_padding.cu
@@ -109,6 +109,7 @@ std::vector<paddle::Tensor> RebuildPadding(const paddle::Tensor& tmp_out,
                                            const paddle::Tensor& seq_lens,
                                            const paddle::Tensor& input_ids) {
   switch (tmp_out.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return rebuild_padding<paddle::DataType::BFLOAT16>(
         tmp_out,
@@ -117,6 +118,7 @@ std::vector<paddle::Tensor> RebuildPadding(const paddle::Tensor& tmp_out,
         input_ids
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return rebuild_padding<paddle::DataType::FLOAT16>(
         tmp_out,
2 changes: 2 additions & 0 deletions csrc/generation/set_alibi_mask_value.cu
@@ -76,6 +76,7 @@ std::vector<paddle::Tensor> SetMaskValue(const paddle::Tensor& input_data,
                                          const paddle::Tensor& alibi_slopes,
                                          const paddle::Tensor& tgt_pos) {
   switch (input_data.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return set_mask_value<paddle::DataType::BFLOAT16>(
         input_data,
@@ -85,6 +86,7 @@ std::vector<paddle::Tensor> SetMaskValue(const paddle::Tensor& input_data,
         tgt_pos
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return set_mask_value<paddle::DataType::FLOAT16>(
         input_data,
5 changes: 5 additions & 0 deletions csrc/generation/set_mask_value.cu
@@ -14,6 +14,7 @@
 
 #include "paddle/extension.h"
 
+
 template <paddle::DataType D>
 class PDTraits;
 
@@ -31,12 +32,14 @@ public:
   typedef paddle::float16 data_t;
 };
 
+#if CUDA_VERSION >= 11000
 template <>
 class PDTraits<paddle::DataType::BFLOAT16> {
 public:
   typedef __nv_bfloat16 DataType;
   typedef paddle::bfloat16 data_t;
 };
+#endif
 
 template <typename T>
 __global__ void set_value_by_id(const int *seq_lens, const bool *stop_flags, T *output_data, int *sequence_lengths, int bs, int length) {
@@ -77,13 +80,15 @@ std::vector<paddle::Tensor> set_mask_value(const paddle::Tensor& input_data, con
 
 std::vector<paddle::Tensor> SetMaskValue(const paddle::Tensor& input_data, const paddle::Tensor& stop_flags, const paddle::Tensor& seq_lens) {
   switch (input_data.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
      return set_mask_value<paddle::DataType::BFLOAT16>(
        input_data,
        stop_flags,
        seq_lens
      );
    }
+#endif
    case paddle::DataType::FLOAT16: {
      return set_mask_value<paddle::DataType::FLOAT16>(
        input_data,
2 changes: 2 additions & 0 deletions csrc/generation/token_penalty_multi_scores.cu
@@ -156,6 +156,7 @@ std::vector<paddle::Tensor> TokenPenaltyMultiScores(const paddle::Tensor& pre_id
                                                     const paddle::Tensor& eos_token_id) {
 
   switch (logits.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return token_penalty_multi_scores_kernel<paddle::DataType::BFLOAT16>(
         pre_ids,
@@ -168,6 +169,7 @@ std::vector<paddle::Tensor> TokenPenaltyMultiScores(const paddle::Tensor& pre_id
         eos_token_id
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
         pre_ids,
2 changes: 2 additions & 0 deletions csrc/generation/transpose_removing_padding.cu
@@ -125,13 +125,15 @@ std::vector<paddle::Tensor> ApplyTransposeRemovingPadding(const paddle::Tensor&
                                                           const paddle::Tensor& seq_lens,
                                                           const paddle::Tensor& padding_offset) {
   switch (input.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return apply_transpose_remove_padding<paddle::DataType::BFLOAT16>(
         input,
         seq_lens,
         padding_offset
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return apply_transpose_remove_padding<paddle::DataType::FLOAT16>(
         input,
2 changes: 2 additions & 0 deletions csrc/generation/write_cache_kv.cu
@@ -154,11 +154,13 @@ void WriteCacheKV(const paddle::Tensor& input_k,
                   const paddle::Tensor& cache_kv,
                   const paddle::Tensor& sequence_lengths_shape) {
   switch (cache_kv.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return LaunchWriteCacheKV<paddle::DataType::BFLOAT16>(
         input_k, input_v, cache_kv, sequence_lengths_shape
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return LaunchWriteCacheKV<paddle::DataType::FLOAT16>(
         input_k, input_v, cache_kv, sequence_lengths_shape
