From 6d92e6c963dfe36b1d8af209ff11f2b4b359f4f6 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 7 Dec 2023 08:37:53 +0000
Subject: [PATCH] support v100 csrc
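
Guard every bfloat16 code path in csrc/generation behind CUDA_VERSION so the
custom ops also build with the CUDA 10.x toolkits commonly paired with V100:
cuda_bf16.h and __nv_bfloat16 only exist since CUDA 11.0. The float16 and
float32 paths are unchanged.

A minimal sketch of the pattern applied in each file below (illustrative
only; Dispatch and LaunchKernel are placeholder names, not functions from
this patch):

    #include <cuda.h>               // defines CUDA_VERSION; 11000 == 11.0
    #include "paddle/extension.h"
    #if CUDA_VERSION >= 11000
    #include <cuda_bf16.h>          // __nv_bfloat16 lives here, CUDA 11+ only
    #endif

    std::vector<paddle::Tensor> Dispatch(const paddle::Tensor& x) {
      switch (x.type()) {
    #if CUDA_VERSION >= 11000
        // Compiled out entirely on CUDA 10.x, so older toolchains never
        // see the bf16 type or the kernels instantiated for it.
        case paddle::DataType::BFLOAT16:
          return LaunchKernel<paddle::DataType::BFLOAT16>(x);
    #endif
        case paddle::DataType::FLOAT16:
          return LaunchKernel<paddle::DataType::FLOAT16>(x);
        default:
          return LaunchKernel<paddle::DataType::FLOAT32>(x);
      }
    }
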
---
 csrc/generation/dequant_int8.cu               |  6 +++++-
 csrc/generation/encode_rotary_qk.cu           |  2 ++
 csrc/generation/helper.h                      |  6 +++---
 csrc/generation/qkv_transpose_split.cu        |  2 ++
 csrc/generation/quant_int8.cu                 | 10 +++++++---
 csrc/generation/rebuild_padding.cu            |  2 ++
 csrc/generation/set_alibi_mask_value.cu       |  2 ++
 csrc/generation/set_mask_value.cu             |  5 +++++
 csrc/generation/token_penalty_multi_scores.cu |  2 ++
 csrc/generation/transpose_removing_padding.cu |  2 ++
 csrc/generation/write_cache_kv.cu             |  2 ++
 11 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/csrc/generation/dequant_int8.cu b/csrc/generation/dequant_int8.cu
index ddbe7b32df88..9976edf003c8 100644
--- a/csrc/generation/dequant_int8.cu
+++ b/csrc/generation/dequant_int8.cu
@@ -116,9 +116,11 @@ std::vector<paddle::Tensor> LaunchDequantInt8(const paddle::Tensor& input,
                     "Only bfloat16, float16 and float32 are supported. ");
 
   switch (data_type) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16:
       return DispatchLaunchDequantInt8<paddle::DataType::BFLOAT16>(input, scale);
       break;
+#endif
     case paddle::DataType::FLOAT16:
       return DispatchLaunchDequantInt8<paddle::DataType::FLOAT16>(input, scale);
       break;
@@ -145,8 +147,10 @@ std::vector<paddle::DataType> DequantInt8Dtype(const paddle::DataType& input_dty
   paddle::DataType data_type;
   if (dtype == "float32")
     data_type = paddle::DataType::FLOAT32;
+#if CUDA_VERSION >= 11000
   else if (dtype == "bfloat16")
     data_type = paddle::DataType::BFLOAT16;
+#endif
   else if (dtype == "float16")
     data_type = paddle::DataType::FLOAT16;
   else
@@ -163,4 +167,4 @@ PD_BUILD_OP(dequant_int8)
     .Attrs({"dtype: std::string"})
     .SetKernelFn(PD_KERNEL(DequantInt8))
     .SetInferShapeFn(PD_INFER_SHAPE(DequantInt8Shape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(DequantInt8Dtype));
\ No newline at end of file
+    .SetInferDtypeFn(PD_INFER_DTYPE(DequantInt8Dtype));
diff --git a/csrc/generation/encode_rotary_qk.cu b/csrc/generation/encode_rotary_qk.cu
index d5f55a172592..6c91ea290472 100644
--- a/csrc/generation/encode_rotary_qk.cu
+++ b/csrc/generation/encode_rotary_qk.cu
@@ -193,11 +193,13 @@ void RotaryQK(const paddle::Tensor& q,
               const int32_t rotary_emb_dims,
               bool use_neox) {
   switch (q.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return LaunchRotaryQK<paddle::DataType::BFLOAT16>(
         q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return LaunchRotaryQK<paddle::DataType::FLOAT16>(
         q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox
diff --git a/csrc/generation/helper.h b/csrc/generation/helper.h
index 4a74709aecae..a3a59d0a2d34 100644
--- a/csrc/generation/helper.h
+++ b/csrc/generation/helper.h
@@ -15,7 +15,6 @@
 #pragma once
 
 #include "paddle/extension.h"
-#include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
 constexpr int kBlockSize = 256;
@@ -70,13 +69,14 @@ class PDTraits<paddle::DataType::FLOAT16> {
   typedef half DataType;
   typedef paddle::float16 data_t;
 };
-
+#if CUDA_VERSION >= 11000
 template <>
 class PDTraits<paddle::DataType::BFLOAT16> {
 public:
   typedef __nv_bfloat16 DataType;
   typedef paddle::bfloat16 data_t;
 };
+#endif
 
 template <typename T, int Size>
 struct alignas(sizeof(T) * Size) AlignedVector {
@@ -100,4 +100,4 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
   *addr_vec = vec;
 }
 
-constexpr int VEC_16B = 16;
\ No newline at end of file
+constexpr int VEC_16B = 16;
diff --git a/csrc/generation/qkv_transpose_split.cu b/csrc/generation/qkv_transpose_split.cu
index d283ac81913e..5776fdd14604 100644
--- a/csrc/generation/qkv_transpose_split.cu
+++ b/csrc/generation/qkv_transpose_split.cu
@@ -135,6 +135,7 @@ std::vector<paddle::Tensor> QKVTransposeSplit(const paddle::Tensor& qkv,
                                               int num_head,
                                               int head_size) {
   switch (qkv.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return qkv_transpose_split<paddle::DataType::BFLOAT16>(
         qkv,
@@ -145,6 +146,7 @@ std::vector<paddle::Tensor> QKVTransposeSplit(const paddle::Tensor& qkv,
         head_size
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return qkv_transpose_split<paddle::DataType::FLOAT16>(
         qkv,
diff --git a/csrc/generation/quant_int8.cu b/csrc/generation/quant_int8.cu
index 1e76f3563ae9..0ed8ee373fa5 100644
--- a/csrc/generation/quant_int8.cu
+++ b/csrc/generation/quant_int8.cu
@@ -23,8 +23,9 @@
 #include
 #include
 #include
+#if CUDA_VERSION >= 11000
 #include <cuda_bf16.h>
-
+#endif
 
 constexpr int DequantKernelVecSize = 4;
@@ -52,11 +53,12 @@ __forceinline__ __device__ half add_mul(half a, half b, half c) {
   return __hmul(__hadd(a, b), c);
 }
 
+#if CUDA_VERSION >= 11000
 template<>
 __forceinline__ __device__ __nv_bfloat16 add_mul<__nv_bfloat16>(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
   return __hmul(__hadd(a, b), c);
 }
-
+#endif
 
 
 template
@@ -210,11 +212,13 @@ std::vector<paddle::Tensor> QuantInt8(const paddle::Tensor& input,
                                       float min_bound) {
   // printf("#### quant int8 scale:%f \n",scale);
   switch (input.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return LaunchQuantInt8<paddle::DataType::BFLOAT16>(
         input, shift, smooth, scale, round_type, max_bound, min_bound
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return LaunchQuantInt8<paddle::DataType::FLOAT16>(
         input, shift, smooth, scale, round_type, max_bound, min_bound
@@ -256,4 +260,4 @@ PD_BUILD_OP(quant_int8)
     .Attrs({"scale: float","round_type: int","max_bound: float", "min_bound: float"})
     .SetKernelFn(PD_KERNEL(QuantInt8))
     .SetInferShapeFn(PD_INFER_SHAPE(QuantInt8Shape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(QuantInt8Dtype));
\ No newline at end of file
+    .SetInferDtypeFn(PD_INFER_DTYPE(QuantInt8Dtype));
diff --git a/csrc/generation/rebuild_padding.cu b/csrc/generation/rebuild_padding.cu
index 3c8dcc9be47f..4652a081fd42 100644
--- a/csrc/generation/rebuild_padding.cu
+++ b/csrc/generation/rebuild_padding.cu
@@ -109,6 +109,7 @@ std::vector<paddle::Tensor> RebuildPadding(const paddle::Tensor& tmp_out,
                                            const paddle::Tensor& seq_lens,
                                            const paddle::Tensor& input_ids) {
   switch (tmp_out.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return rebuild_padding<paddle::DataType::BFLOAT16>(
         tmp_out,
@@ -117,6 +118,7 @@ std::vector<paddle::Tensor> RebuildPadding(const paddle::Tensor& tmp_out,
         input_ids
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return rebuild_padding<paddle::DataType::FLOAT16>(
         tmp_out,
diff --git a/csrc/generation/set_alibi_mask_value.cu b/csrc/generation/set_alibi_mask_value.cu
index 8036f1096ebd..f0ee20d700fa 100644
--- a/csrc/generation/set_alibi_mask_value.cu
+++ b/csrc/generation/set_alibi_mask_value.cu
@@ -76,6 +76,7 @@ std::vector<paddle::Tensor> SetMaskValue(const paddle::Tensor& input_data,
                                          const paddle::Tensor& alibi_slopes,
                                          const paddle::Tensor& tgt_pos) {
   switch (input_data.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return set_mask_value<paddle::DataType::BFLOAT16>(
         input_data,
@@ -85,6 +86,7 @@ std::vector<paddle::Tensor> SetMaskValue(const paddle::Tensor& input_data,
         tgt_pos
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return set_mask_value<paddle::DataType::FLOAT16>(
         input_data,
diff --git a/csrc/generation/set_mask_value.cu b/csrc/generation/set_mask_value.cu
index bcd63a277de7..87a22e041329 100644
--- a/csrc/generation/set_mask_value.cu
+++ b/csrc/generation/set_mask_value.cu
@@ -14,6 +14,7 @@
 
 #include "paddle/extension.h"
 
+
 template <paddle::DataType D>
 class PDTraits;
 
@@ -31,12 +32,14 @@
 public:
   typedef paddle::float16 data_t;
 };
 
+#if CUDA_VERSION >= 11000
 template <>
 class PDTraits<paddle::DataType::BFLOAT16> {
 public:
   typedef __nv_bfloat16 DataType;
   typedef paddle::bfloat16 data_t;
 };
+#endif
 
 template <typename T>
 __global__ void set_value_by_id(const int *seq_lens, const bool *stop_flags, T *output_data, int *sequence_lengths, int bs, int length) {
@@ -77,6 +80,7 @@ std::vector<paddle::Tensor> set_mask_value(const paddle::Tensor& input_data, con
 std::vector<paddle::Tensor> SetMaskValue(const paddle::Tensor& input_data, const paddle::Tensor& stop_flags, const paddle::Tensor& seq_lens) {
   switch (input_data.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return set_mask_value<paddle::DataType::BFLOAT16>(
         input_data,
@@ -84,6 +88,7 @@ std::vector<paddle::Tensor> SetMaskValue(const paddle::Tensor& input_data, const
         seq_lens
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return set_mask_value<paddle::DataType::FLOAT16>(
         input_data,
diff --git a/csrc/generation/token_penalty_multi_scores.cu b/csrc/generation/token_penalty_multi_scores.cu
index 3ef010501921..d9836e87a845 100644
--- a/csrc/generation/token_penalty_multi_scores.cu
+++ b/csrc/generation/token_penalty_multi_scores.cu
@@ -156,6 +156,7 @@ std::vector<paddle::Tensor> TokenPenaltyMultiScores(const paddle::Tensor& pre_id
                                                     const paddle::Tensor& eos_token_id) {
 
   switch (logits.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return token_penalty_multi_scores_kernel<paddle::DataType::BFLOAT16>(
         pre_ids,
@@ -168,6 +169,7 @@ std::vector<paddle::Tensor> TokenPenaltyMultiScores(const paddle::Tensor& pre_id
         eos_token_id
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
         pre_ids,
diff --git a/csrc/generation/transpose_removing_padding.cu b/csrc/generation/transpose_removing_padding.cu
index 5b6b16a7faa2..018677ed79c0 100644
--- a/csrc/generation/transpose_removing_padding.cu
+++ b/csrc/generation/transpose_removing_padding.cu
@@ -125,6 +125,7 @@ std::vector<paddle::Tensor> ApplyTransposeRemovingPadding(const paddle::Tensor&
                                                           const paddle::Tensor& seq_lens,
                                                           const paddle::Tensor& padding_offset) {
   switch (input.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return apply_transpose_remove_padding<paddle::DataType::BFLOAT16>(
         input,
@@ -132,6 +133,7 @@ std::vector<paddle::Tensor> ApplyTransposeRemovingPadding(const paddle::Tensor&
         padding_offset
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return apply_transpose_remove_padding<paddle::DataType::FLOAT16>(
         input,
diff --git a/csrc/generation/write_cache_kv.cu b/csrc/generation/write_cache_kv.cu
index 62ebf854b0e0..01ba1a749f8f 100644
--- a/csrc/generation/write_cache_kv.cu
+++ b/csrc/generation/write_cache_kv.cu
@@ -154,11 +154,13 @@ void WriteCacheKV(const paddle::Tensor& input_k,
                   const paddle::Tensor& cache_kv,
                   const paddle::Tensor& sequence_lengths_shape) {
   switch (cache_kv.type()) {
+#if CUDA_VERSION >= 11000
     case paddle::DataType::BFLOAT16: {
       return LaunchWriteCacheKV<paddle::DataType::BFLOAT16>(
         input_k, input_v, cache_kv, sequence_lengths_shape
       );
     }
+#endif
     case paddle::DataType::FLOAT16: {
       return LaunchWriteCacheKV<paddle::DataType::FLOAT16>(
         input_k, input_v, cache_kv, sequence_lengths_shape