diff --git a/csrc/gpu/tune_cublaslt_gemm.cu b/csrc/gpu/tune_cublaslt_gemm.cu
index 69c4be8fccc716..2a5a10f5a54380 100644
--- a/csrc/gpu/tune_cublaslt_gemm.cu
+++ b/csrc/gpu/tune_cublaslt_gemm.cu
@@ -18,11 +18,11 @@ limitations under the License. */
 
 #include
 #include
+#include
 #include
 #include
 #include
 #include
-#include
 
 #include "helper.h"
 
@@ -105,6 +105,13 @@ static inline bool time_compare_algo_para(const algoSelect_t& algo_para_a,
   return (algo_para_a.time < algo_para_b.time);
 }
 
+// Get the remaining memory of the current GPU (in bytes)
+size_t get_remaining_memory() {
+  size_t free, total;
+  CUDA_CHECK(cudaMemGetInfo(&free, &total));
+  return free;
+}
+
 template
 static void TestMatmulRun(cublasLtHandle_t ltHandle,
                           cublasLtMatmulDesc_t matmulDesc,
@@ -122,7 +129,10 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
   cublasLtMatmulHeuristicResult_t heurResult;
   cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
       ltHandle, matmulDesc, A_desc, B_desc, C_desc, C_desc, &algo, &heurResult);
-  if (algoStatus == CUBLAS_STATUS_SUCCESS) {
+
+  auto remainingMemorySize = 0.95 * get_remaining_memory();
+  if (algoStatus == CUBLAS_STATUS_SUCCESS &&
+      remainingMemorySize > heurResult.workspaceSize) {
     ScaleT alpha = static_cast<ScaleT>(1), beta = static_cast<ScaleT>(0);
     void* workSpace;
     CUDA_CHECK(cudaMalloc(&workSpace, heurResult.workspaceSize));
@@ -166,8 +176,13 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
     }
     CUDA_CHECK(cudaFree(workSpace));
   } else {
-    std::cerr << "not enough workspace! current workspace is "
-              << heurResult.workspaceSize;
+    std::cerr << "Not enough workspace! Required "
+              << static_cast<double>(heurResult.workspaceSize) / 1024.0 /
+                     1024.0 / 1024.0
+              << " GiB, but remaining "
+              << static_cast<double>(remainingMemorySize) / 1024.0 / 1024.0 /
+                     1024.0
+              << " GiB" << std::endl;
     perfResults.status = CUBLAS_STATUS_NOT_SUPPORTED;  // Not enough workspace
   }
 }
@@ -442,7 +457,7 @@ void FindAlgo(const cublasLtHandle_t& ltHandle,
     if (perfResults[i].status != CUBLAS_STATUS_SUCCESS) {
       std::clog << "algo " << algos[i].algoId << " tile " << algos[i].tile
                 << " stages " << algos[i].stages << " splitK_val "
-                << algos[i].splitK_val;
+                << algos[i].splitK_val << std::endl;
       algos[i].time = std::numeric_limits::max();
       std::cerr << " TestMatmulRun with status " << perfResults[i].status
                 << std::endl;
@@ -467,7 +482,7 @@ class DevContext {};
 class CPUContext : public DevContext {};
 
 class CUBLASLTContext : public DevContext {
- public:
+public:
   CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); }
 
   cublasLtHandle_t handle;
@@ -709,64 +724,51 @@ void GEMMInt8(const CUBLASLTContext& dev_ctx,
   CUDA_CHECK(cudaFree(workSpace));
 }
 
-void TuneCublasltGemm(const paddle::Tensor& M,
-                      const paddle::Tensor& K,
+void TuneCublasltGemm(const paddle::Tensor& K,
                       const paddle::Tensor& N,
+                      const int M_start,
+                      const int M_end,
                       const std::string& dtype,
-                      bool is_test,
-                      bool is_read_from_file,
+                      const bool is_test,
+                      const bool is_read_from_file,
                       const std::string& path) {
-  // Ensure that M, K, and N are all one-dimensional Tensors. is_test !=
-  // is_read_from_file
-  assert(M.dims().size() == 1 && K.dims().size() == 1 && N.dims().size() == 1);
+  assert(M_end >= M_start);
+  assert(M_start >= 1);
+  assert(K.dims().size() == 1 && N.dims().size() == 1);
   assert(is_test != is_read_from_file);
 
-  auto M_cpu = M.copy_to(paddle::CPUPlace(), false);
   auto K_cpu = K.copy_to(paddle::CPUPlace(), false);
   auto N_cpu = N.copy_to(paddle::CPUPlace(), false);
-  int64_t* M_data = M_cpu.data<int64_t>();
   int64_t* K_data = K_cpu.data<int64_t>();
   int64_t* N_data = N_cpu.data<int64_t>();
 
-  int M_size = M.numel();
   int K_size = K.numel();
   int N_size = N.numel();
   assert(K_size == N_size);
 
-  int m_data = (int)M_data[0];
-  assert(m_data > 0);
-
   std::vector<int> mm;
-
-  int m = 1, step = 1;
-  while (m <= m_data) {
-    mm.push_back(m);
-    m += step;
-
+  int m = M_start, step = 1;
+  while (m <= M_end) {
     // update step
-    switch (m) {
-      case 4:
-        step = 4;
-        break;
-      case 16:
-        step = 16;
-        break;
-      case 64:
-        step = 32;
-        break;
-      case 256:
-        step = 64;
-        break;
-      case 512:
-        step = 128;
-        break;
-      case 1024:
-        step = 1024;
-        break;
-      case 8192:
-        step = 4096;
-        break;
+    if (m >= 8192) {
+      step = 4096;
+    } else if (m >= 1024) {
+      step = 1024;
+    } else if (m >= 512) {
+      step = 128;
+    } else if (m >= 256) {
+      step = 64;
+    } else if (m >= 64) {
+      step = 32;
+    } else if (m >= 16) {
+      step = 16;
+    } else if (m >= 4) {
+      step = 4;
+    } else {
+      step = 1;
     }
+    mm.push_back(m);
+    m += step;
   }
 
   for (int j = 0; j < mm.size(); j++) {
@@ -792,16 +794,18 @@ void TuneCublasltGemm(const paddle::Tensor& M,
                path);
     } else {
       // other dtype
-      std::cout << "Not currently supported" << std::endl;
+      throw std::runtime_error(dtype + " is not currently supported");
     }
   }
 }
 
 PD_BUILD_OP(tune_cublaslt_gemm)
-    .Inputs({"M", "K", "N"})
+    .Inputs({"K", "N"})
     .Outputs({})
-    .Attrs({"dtype: std::string",
+    .Attrs({"M_start: int",
+            "M_end: int",
+            "dtype: std::string",
             "is_test: bool",
             "is_read_from_file: bool",
             "path: std::string"})
diff --git a/csrc/utils/tune_cublaslt_int8_gemm.py b/csrc/utils/tune_cublaslt_int8_gemm.py
index add9051af25fa0..7e69d0602a109d 100644
--- a/csrc/utils/tune_cublaslt_int8_gemm.py
+++ b/csrc/utils/tune_cublaslt_int8_gemm.py
@@ -15,7 +15,8 @@
 import paddle
 from paddlenlp_ops import tune_cublaslt_gemm
 
-M_tensor = paddle.to_tensor([32768])
+M_start = 1
+M_end = 32768
 
 # llama3.1-8b
 k1 = [4096, 4096, 4096, 14336]
@@ -36,7 +37,11 @@
 K_tensor = paddle.to_tensor(k1 + k2 + k3 + k4)
 N_tensor = paddle.to_tensor(n1 + n2 + n3 + n4)
 
-Dtype = "int8"
 Path = "./cublaslt_gemm_search.csv"
 
-tune_cublaslt_gemm(M_tensor, K_tensor, N_tensor, Dtype, True, False, Path)
+tune_cublaslt_gemm(K_tensor, N_tensor, M_start, M_end, "int8", True, False, Path)
+
+# shape calculation formula
+# [qkv, out_linear, ffn1, ffn2]
+# k = [hidden_size, hidden_size, hidden_size, intermediate_size//mp_size]
+# n = [((num_attention_heads//mp_size)+2*(num_key_value_heads//mp_size))*(hidden_size//num_attention_heads), hidden_size, 2*(intermediate_size//mp_size), hidden_size]
diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py
index d55db079a6e691..6e71b16e5b3ea3 100644
--- a/paddlenlp/experimental/transformers/llama/modeling.py
+++ b/paddlenlp/experimental/transformers/llama/modeling.py
@@ -876,6 +876,8 @@ def set_state_dict(self, state_dict):
                 ffn_hidden_size=self.intermediate_size,
                 num_key_value_heads=self.num_key_value_heads,
                 mp_size=self.config.tensor_parallel_degree,
+                concat_qkv=True,
+                concat_ffn1=True,
             )
             self.transformer_block.weight_scales = weight_scales_loader.scale
             self.transformer_block.act_scales = act_scale_loader.scale
@@ -1097,16 +1099,24 @@
                     dtype=paddle.get_default_dtype(),
                 )
                 self.transformer_block.linear_shifts[idx].set_value(
-                    paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
+                    paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype(
+                        paddle.get_default_dtype()
+                    )
                 )
                 self.transformer_block.linear_smooths[idx].set_value(
-                    paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
+                    paddle.to_tensor(
+                        state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
+                    ).astype(paddle.get_default_dtype())
                 )
                 self.transformer_block.ffn2_shifts[idx].set_value(
-                    paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)])
+                    paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
+                        paddle.get_default_dtype()
+                    )
                 )
                 self.transformer_block.ffn2_smooths[idx].set_value(
-                    paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
+                    paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype(
+                        paddle.get_default_dtype()
+                    )
                 )
 
         if self.shift:
diff --git a/paddlenlp/experimental/transformers/mixtral/modeling.py b/paddlenlp/experimental/transformers/mixtral/modeling.py
index d8ba9198394b37..67cb3183132fdb 100644
--- a/paddlenlp/experimental/transformers/mixtral/modeling.py
+++ b/paddlenlp/experimental/transformers/mixtral/modeling.py
@@ -716,16 +716,24 @@ def set_state_dict(self, state_dict):
         if "a8w8" in self.quant_type:
             if self.shift_smooth_all_linears:
                 self.transformer_block.linear_shifts[idx].set_value(
-                    paddle.to_tensor(state_dict["mixtral.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
+                    paddle.to_tensor(
+                        state_dict["mixtral.layers.{}.self_attn.o_proj.shift_bias".format(idx)]
+                    ).astype(paddle.get_default_dtype())
                 )
                 self.transformer_block.linear_smooths[idx].set_value(
-                    paddle.to_tensor(state_dict["mixtral.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
+                    paddle.to_tensor(
+                        state_dict["mixtral.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
+                    ).astype(paddle.get_default_dtype())
                 )
                 self.transformer_block.ffn2_shifts[idx].set_value(
-                    paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.shift_bias".format(idx)])
+                    paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
+                        paddle.get_default_dtype()
+                    )
                 )
                 self.transformer_block.ffn2_smooths[idx].set_value(
-                    paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
+                    paddle.to_tensor(
+                        state_dict["mixtral.layers.{}.mlp.down_proj.smooth_weight".format(idx)]
+                    ).astype(paddle.get_default_dtype())
                 )
 
         if self.shift:
diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py
index b8b748ac59911c..f02f3fc5c27790 100644
--- a/paddlenlp/experimental/transformers/qwen2/modeling.py
+++ b/paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -453,6 +453,8 @@ def set_state_dict(self, state_dict):
                 ffn_hidden_size=self.intermediate_size,
                 num_key_value_heads=self.num_key_value_heads,
                 mp_size=self.config.tensor_parallel_degree,
+                concat_qkv=True,
+                concat_ffn1=True,
             )
             self.transformer_block.weight_scales = weight_scales_loader.scale
             self.transformer_block.act_scales = act_scale_loader.scale
@@ -704,16 +706,24 @@
                     dtype=paddle.get_default_dtype(),
                 )
                 self.transformer_block.linear_shifts[idx].set_value(
-                    paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
+                    paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype(
+                        paddle.get_default_dtype()
+                    )
                 )
                 self.transformer_block.linear_smooths[idx].set_value(
-                    paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
+                    paddle.to_tensor(
+                        state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
+                    ).astype(paddle.get_default_dtype())
                 )
                 self.transformer_block.ffn2_shifts[idx].set_value(
-                    paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)])
+                    paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
+                        paddle.get_default_dtype()
+                    )
                 )
                 self.transformer_block.ffn2_smooths[idx].set_value(
-                    paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
+                    paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype(
+                        paddle.get_default_dtype()
+                    )
                 )
 
         if self.shift:
diff --git a/paddlenlp/experimental/transformers/utils.py b/paddlenlp/experimental/transformers/utils.py
index d24904c1f31bf5..8a8eed55d58d26 100644
--- a/paddlenlp/experimental/transformers/utils.py
+++ b/paddlenlp/experimental/transformers/utils.py
@@ -108,6 +108,8 @@ def __init__(
         ffn_hidden_size,
         num_key_value_heads=-1,
         mp_size=1,
+        concat_qkv=False,
+        concat_ffn1=False,
     ):
         self.key_map = key_map_dict
         self.scale = {}
@@ -126,6 +128,17 @@ def __init__(
                 n = num_head * dim_head
             self.scale[scale_type] = np.full([num_of_layers, n], fill_value=0.1, dtype="float32")
 
+        # concat qkv and ffn1
+        if concat_qkv:
+            self.scale["qkv_weight_scale"] = np.full(
+                [num_of_layers, qkv_out_size // mp_size], fill_value=0.1, dtype="float32"
+            )
+
+        if concat_ffn1:
+            self.scale["ffn1_weight_scale"] = np.full(
+                [num_of_layers, ffn_hidden_size * 2 // mp_size], fill_value=0.1, dtype="float32"
+            )
+
 
 class EmptyCacheScale:
     """
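Note on the new M sweep in csrc/gpu/tune_cublaslt_gemm.cu above: the if/else ladder replaces the old switch, so the step size now holds between breakpoints (step 1 up to M=4, step 4 up to 16, 16 up to 64, and so on, up to step 4096 beyond 8192). A minimal Python sketch of the same schedule, for illustration only; the helper name tuned_m_values is not part of the patch.

# Illustrative re-implementation of the M sweep used by TuneCublasltGemm.
def tuned_m_values(m_start, m_end):
    mm = []
    m, step = m_start, 1
    while m <= m_end:
        # update step, mirroring the if/else ladder in the CUDA source
        if m >= 8192:
            step = 4096
        elif m >= 1024:
            step = 1024
        elif m >= 512:
            step = 128
        elif m >= 256:
            step = 64
        elif m >= 64:
            step = 32
        elif m >= 16:
            step = 16
        elif m >= 4:
            step = 4
        else:
            step = 1
        mm.append(m)
        m += step
    return mm

print(tuned_m_values(1, 64))  # [1, 2, 3, 4, 8, 12, 16, 32, 48, 64]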
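As a quick sanity check of the shape formula noted at the end of csrc/utils/tune_cublaslt_int8_gemm.py, the sketch below plugs in the published llama3.1-8b configuration (hidden_size=4096, num_attention_heads=32, num_key_value_heads=8, intermediate_size=14336) with mp_size=1; the resulting k list matches k1 in that script. This snippet is illustrative only and not part of the patch.

# Recompute the [qkv, out_linear, ffn1, ffn2] GEMM shapes from the comment in
# tune_cublaslt_int8_gemm.py, assuming llama3.1-8b hyper-parameters and mp_size=1.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
intermediate_size = 14336
mp_size = 1  # tensor-parallel degree

k = [hidden_size, hidden_size, hidden_size, intermediate_size // mp_size]
n = [
    ((num_attention_heads // mp_size) + 2 * (num_key_value_heads // mp_size))
    * (hidden_size // num_attention_heads),
    hidden_size,
    2 * (intermediate_size // mp_size),
    hidden_size,
]

print(k)  # [4096, 4096, 4096, 14336] -> matches k1 for llama3.1-8b
print(n)  # [6144, 4096, 28672, 4096]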