diff --git a/cinn/frontend/cinn_builder.cc b/cinn/frontend/cinn_builder.cc index c5e572f05f..f76029fb9d 100644 --- a/cinn/frontend/cinn_builder.cc +++ b/cinn/frontend/cinn_builder.cc @@ -216,17 +216,21 @@ Variable CinnBuilder::Reverse(const Variable& operand, const std::vector& a return instr.GetOutput(0); } -std::vector CinnBuilder::BnMeanVarianceReduce(const Variable& x) { - Instruction instr("bn_mean_variance_reduce", {x}); +std::vector CinnBuilder::BnMeanVariance(const Variable& x) { + Instruction instr("bn_mean_variance", {x}); + // optimize bn forward reduce computation, set reduce dimension (NCHW support only, to be deprecated). + instr.SetAttr("dim", std::vector{0, 2, 3}); + instr.SetAttr("keep_dim", false); InferShape(instr); AppendInstruction(instr); return instr.GetOutputs(); } -std::vector CinnBuilder::BnGradBiasScaleReduce(const Variable& x, - const Variable& x_mean, - const Variable& y_grad) { - Instruction instr("bn_grad_bias_scale_reduce", {x, x_mean, y_grad}); +std::vector CinnBuilder::BnGradBiasScale(const Variable& x, const Variable& x_mean, const Variable& y_grad) { + Instruction instr("bn_grad_bias_scale", {x, x_mean, y_grad}); + // optimize bn backward reduce computation, set reduce dimension (NCHW support only, to be deprecated). + instr.SetAttr("dim", std::vector{0, 2, 3}); + instr.SetAttr("keep_dim", false); InferShape(instr); AppendInstruction(instr); return instr.GetOutputs(); diff --git a/cinn/frontend/cinn_builder.h b/cinn/frontend/cinn_builder.h index 417618495d..add98cb35b 100644 --- a/cinn/frontend/cinn_builder.h +++ b/cinn/frontend/cinn_builder.h @@ -179,9 +179,9 @@ class CinnBuilder : public BaseBuilder { Variable Reverse(const Variable& operand, const std::vector& axis); - std::vector BnMeanVarianceReduce(const Variable& x); + std::vector BnMeanVariance(const Variable& x); - std::vector BnGradBiasScaleReduce(const Variable& x, const Variable& x_mean, const Variable& y_grad); + std::vector BnGradBiasScale(const Variable& x, const Variable& x_mean, const Variable& y_grad); private: Variable UnaryOp(const std::string& op_type, const Variable& operand); diff --git a/cinn/frontend/decomposer/batch_norm.cc b/cinn/frontend/decomposer/batch_norm.cc index 8039c4f207..be53885da1 100644 --- a/cinn/frontend/decomposer/batch_norm.cc +++ b/cinn/frontend/decomposer/batch_norm.cc @@ -60,15 +60,13 @@ struct BatchNormHelper { std::vector MeanAndVariance(Variable x) { #ifdef CINN_WITH_CUDA // To optimize the bn forward by merge the reduce computation of mean and variance, - // build a fusion op 'BnMeanVarianceReduce' by hand as the fusion pass is not support now. + // build a fusion op 'BnMeanVariance' by hand as the fusion pass is not supported yet. // When the fusion pass is rebuild, this op is to be removed.
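+ // vars[0] = sum(x) and vars[1] = sum(x * x) reduced over {N, H, W}; below, mean = vars[0] / count and variance = vars[1] / count - mean^2.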
- auto vars = builder->BnMeanVarianceReduce(x); + auto vars = builder->BnMeanVariance(x); auto element_count_1d_0 = GetTensorFromScalar(element_count, "element_count", param_shape); auto element_count_1d_1 = GetTensorFromScalar(element_count, "element_count", param_shape); - auto mean = builder->Div(builder->Reduce(vars[0], ReduceKind::kSum, std::vector(1, vars[0]->shape.size() - 1)), - element_count_1d_0); - auto mean_squre = builder->Div( - builder->Reduce(vars[1], ReduceKind::kSum, std::vector(1, vars[1]->shape.size() - 1)), element_count_1d_1); + auto mean = builder->Div(vars[0], element_count_1d_0); + auto mean_squre = builder->Div(vars[1], element_count_1d_1); auto variance = builder->Sub(mean_squre, builder->Mul(mean, builder->Identity(mean))); #else @@ -82,11 +80,9 @@ struct BatchNormHelper { std::vector GradBiasAndScale(Variable x, Variable x_mean, Variable y_grad) { #ifdef CINN_WITH_CUDA - // Using fusion op "BnGradBiasScaleReduce" as the same reason with "BnMeanVarianceReduce". + // Use the fusion op "BnGradBiasScale" for the same reason as "BnMeanVariance". // It also will be removed. - auto vars = builder->BnGradBiasScaleReduce(x, x_mean, y_grad); - return {builder->Reduce(vars[0], ReduceKind::kSum, std::vector(1, vars[0]->shape.size() - 1)), - builder->Reduce(vars[1], ReduceKind::kSum, std::vector(1, vars[1]->shape.size() - 1))}; + return builder->BnGradBiasScale(x, x_mean, y_grad); #else auto mean_4d = builder->BroadcastTo(x_mean, x->shape, {channel_dim}); auto x_mean_diff = builder->Sub(x, mean_4d); diff --git a/cinn/hlir/op/reduction.cc b/cinn/hlir/op/reduction.cc index 119d45203e..f4d0f82a9a 100644 --- a/cinn/hlir/op/reduction.cc +++ b/cinn/hlir/op/reduction.cc @@ -41,29 +41,34 @@ using pe::ReduceSum; using PeFunc = std::function &, bool, Expr, const std::string &)>; std::vector GetShape(const ir::Tensor &x) { - auto last_reduce_dim = x->shape[2].as_int32() * x->shape[2].as_int32(); - // split into last_reduce_dim into {n,k} + auto last_reduce_dim = x->shape[2].as_int32() * x->shape[3].as_int32(); + // Split last_reduce_dim into {n, k} std::vector new_shape = {x->shape[0].as_int32(), x->shape[1].as_int32()}; - if (last_reduce_dim <= 128) { + // As the max block size is 1024, use 1024 as the limit + if (last_reduce_dim <= 1024) { new_shape.push_back(last_reduce_dim); } else { - for (int idx = 256; idx > 128; --idx) { + // As the reduce dimension size exceeds 1024, find the largest value in [1, 1024] that divides + // last_reduce_dim.
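+ // e.g. h * w = 32 * 32 = 1024 keeps one reduce dimension, giving {n, c, 1024}; h * w = 64 * 64 = 4096 splits into {n, c, 4, 1024}.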
+ for (int idx = 1024;; --idx) { if (last_reduce_dim % idx == 0) { new_shape.push_back(last_reduce_dim / idx); new_shape.push_back(idx); break; } } + + CHECK_EQ(new_shape.size(), 4) << "Can't find a new shape that satisfy the requirement!"; } return new_shape; } -std::shared_ptr StrategyForBnMeanVarianceReduce(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForBnMeanVariance(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { CHECK_EQ(inputs.size(), 1) << "bn_mean_variance should has 1 input!"; auto input = inputs[0]; CHECK_EQ(input->shape.size(), 4) << "bn_mean_variance input shape should be 4 dimension!"; @@ -78,41 +83,76 @@ std::shared_ptr StrategyForBnMeanVarianceReduce(const framework::Nod CHECK(A.as_tensor()); auto x = A.as_tensor_ref(); - auto stages = CreateStages({x}); - auto x_reshape = pe::Reshape(x, new_shape, stages, UniqName("bn_mean_variance_x_reshape_out")); - auto x_square = pe::Multiply(x_reshape, x_reshape, UniqName("bn_mean_variance_x_square")); - + auto stages = CreateStages({x}); + auto x_reshape = pe::Reshape(x, new_shape, stages, UniqName("bn_mean_variance_x_reshape_out")); + auto x_square = pe::Multiply(x_reshape, x_reshape, UniqName("bn_mean_variance_x_square")); auto reduce_dim = new_shape.size() == 3 ? std::vector{0} : std::vector{0, 2}; - auto out0 = pe::ReduceSum(x_reshape, reduce_dim, false, Expr(0.0f), UniqName("bn_mean_variance_out0")); - auto out1 = pe::ReduceSum(x_square, reduce_dim, false, Expr(0.0f), UniqName("bn_mean_variance_out1")); + + auto x_sum_local = pe::ReduceSum(x_reshape, reduce_dim, false, Expr(0.0f), UniqName("bn_mean_variance_out0")); + auto x_square_sum_local = pe::ReduceSum(x_square, reduce_dim, false, Expr(0.0f), UniqName("bn_mean_variance_out1")); + + auto x_sum_out = pe::BlockReduceSumInternal(x_sum_local, 1); + auto x_square_out = pe::BlockReduceSumInternal(x_square_sum_local, 1); + + CHECK_EQ(x_sum_out.size(), 2); + CHECK_EQ(x_square_out.size(), 2); stages->InsertLazily(x_reshape); stages->InsertLazily(x_square); - stages->InsertLazily(out0); - stages->InsertLazily(out1); + stages->InsertLazily(x_sum_local); + stages->InsertLazily(x_square_sum_local); + stages->InsertLazily(x_sum_out[1]); + stages->InsertLazily(x_square_out[1]); + stages->InsertLazily(x_sum_out[0]); + stages->InsertLazily(x_square_out[0]); + stages[x_reshape]->ComputeInline(); stages[x_square]->ComputeInline(); - *ret = CINNValuePack{{CINNValue(out0), CINNValue(out1), CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(x_sum_local), + CINNValue(x_square_sum_local), + CINNValue(x_sum_out[1]), + CINNValue(x_square_out[1]), + CINNValue(x_sum_out[0]), + CINNValue(x_square_out[0]), + CINNValue(stages)}}; }); framework::CINNSchedule bn_mean_variance_schedule([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input argument of bn_mean_variance schedule is empty! 
Please check."; CINNValuePack arg_pack = args[0]; - CHECK_EQ(arg_pack.size(), 3UL); + CHECK_EQ(arg_pack.size(), 7UL); if (target.arch == Target::Arch::NVGPU) { - Expr out0 = arg_pack[0]; - Expr out1 = arg_pack[1]; - poly::StageMap stages = arg_pack.back(); - CHECK(out0.as_tensor()); - CHECK(out1.as_tensor()); - pe::CudaScheduleReduce(stages, out0.as_tensor_ref(), target); - pe::CudaScheduleReduce(stages, out1.as_tensor_ref(), target); + Expr x_sum_local = arg_pack[0]; + Expr x_square_sum_local = arg_pack[1]; + Expr x_sum_tmp = arg_pack[2]; + Expr x_square_sum_tmp = arg_pack[3]; + Expr x_sum = arg_pack[4]; + Expr x_square_sum = arg_pack[5]; + poly::StageMap stages = arg_pack.back(); + CHECK(x_sum_local.as_tensor()); + CHECK(x_square_sum_local.as_tensor()); + CHECK(x_sum_tmp.as_tensor()); + CHECK(x_square_sum_tmp.as_tensor()); + CHECK(x_sum.as_tensor()); + CHECK(x_square_sum.as_tensor()); + + pe::CudaScheduleBlockReduce(stages, + x_sum_local.as_tensor_ref(), + x_sum_tmp.as_tensor_ref(), + x_sum.as_tensor_ref(), + common::DefaultNVGPUTarget()); + + // set x_square compute at x + stages[x_square_sum_local.as_tensor_ref()]->SetBuffer("local"); if (new_shape.size() == 3) { - stages[out0.as_tensor_ref()]->SimpleComputeAt(stages[out1.as_tensor_ref()], 2); + stages[x_square_sum_local.as_tensor_ref()]->SimpleComputeAt(stages[x_sum_local.as_tensor_ref()], 2); } else { - stages[out0.as_tensor_ref()]->SimpleComputeAt(stages[out1.as_tensor_ref()], 3); + stages[x_square_sum_local.as_tensor_ref()]->SimpleComputeAt(stages[x_sum_local.as_tensor_ref()], 3); } + stages[x_square_sum_tmp.as_tensor_ref()]->SetBuffer("local"); + stages[x_square_sum_tmp.as_tensor_ref()]->SimpleComputeAt(stages[x_sum_tmp.as_tensor_ref()], 1); + stages[x_square_sum.as_tensor_ref()]->SimpleComputeAt(stages[x_sum.as_tensor_ref()], 0); } else if (target.arch == Target::Arch::X86) { Expr out = arg_pack[0]; poly::StageMap stages = arg_pack[1]; @@ -132,11 +172,11 @@ std::shared_ptr StrategyForBnMeanVarianceReduce(const framework::Nod return strategy; } -std::shared_ptr StrategyForBnGradBiasScaleReduce(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForBnGradBiasScale(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { CHECK_EQ(inputs.size(), 3) << "bn_grad_bias_scale should has 3 input!"; auto input = inputs[0]; CHECK_EQ(input->shape.size(), 4) << "bn_grad_bias_scale input shape should be 4 dimension!"; @@ -167,40 +207,76 @@ std::shared_ptr StrategyForBnGradBiasScaleReduce(const framework::No auto reduce_dim = new_shape.size() == 3 ? 
std::vector{0} : std::vector{0, 2}; - auto out0 = pe::ReduceSum(y_grad_reshape, reduce_dim, false, Expr(0.0f), UniqName("bn_grad_bias_scale_out0")); - auto out1 = pe::ReduceSum(grad_x_mean_diff, reduce_dim, false, Expr(0.0f), UniqName("bn_grad_bias_scale_out1")); + auto reduce_local_bias = + pe::ReduceSum(y_grad_reshape, reduce_dim, false, Expr(0.0f), UniqName("bn_grad_bias_scale_out0")); + auto reduce_local_diff = + pe::ReduceSum(grad_x_mean_diff, reduce_dim, false, Expr(0.0f), UniqName("bn_grad_bias_scale_out1")); + + auto reduce_sum_bias = pe::BlockReduceSumInternal(reduce_local_bias, 1); + auto reduce_sum_diff = pe::BlockReduceSumInternal(reduce_local_diff, 1); + + CHECK_EQ(reduce_sum_bias.size(), 2); + CHECK_EQ(reduce_sum_diff.size(), 2); - // auto stages = CreateStages({x_mean_diff, grad_x_mean_diff, out0, out1}); stages->InsertLazily(x_reshape); stages->InsertLazily(y_grad_reshape); stages->InsertLazily(x_mean_diff); stages->InsertLazily(grad_x_mean_diff); - stages->InsertLazily(out0); - stages->InsertLazily(out1); + stages->InsertLazily(reduce_local_bias); + stages->InsertLazily(reduce_local_diff); + stages->InsertLazily(reduce_sum_bias[1]); + stages->InsertLazily(reduce_sum_diff[1]); + stages->InsertLazily(reduce_sum_bias[0]); + stages->InsertLazily(reduce_sum_diff[0]); + stages[x_reshape]->ComputeInline(); stages[y_grad_reshape]->ComputeInline(); stages[x_mean_diff]->ComputeInline(); stages[grad_x_mean_diff]->ComputeInline(); - *ret = CINNValuePack{{CINNValue(out0), CINNValue(out1), CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(reduce_local_bias), + CINNValue(reduce_local_diff), + CINNValue(reduce_sum_bias[1]), + CINNValue(reduce_sum_diff[1]), + CINNValue(reduce_sum_bias[0]), + CINNValue(reduce_sum_diff[0]), + CINNValue(stages)}}; }); framework::CINNSchedule bn_grad_bias_scale_schedule([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input argument of bn_grad_bias_scale schedule is empty! 
Please check.\n"; CINNValuePack arg_pack = args[0]; - CHECK_EQ(arg_pack.size(), 3UL); + CHECK_EQ(arg_pack.size(), 7UL); if (target.arch == Target::Arch::NVGPU) { - Expr out0 = arg_pack[0]; - Expr out1 = arg_pack[1]; + Expr reduce_local_bias = arg_pack[0]; + Expr reduce_local_diff = arg_pack[1]; + Expr reduce_sum_bias_tmp = arg_pack[2]; + Expr reduce_sum_diff_tmp = arg_pack[3]; + Expr reduce_sum_bias = arg_pack[4]; + Expr reduce_sum_diff = arg_pack[5]; + poly::StageMap stages = arg_pack.back(); - CHECK(out0.as_tensor()); - CHECK(out1.as_tensor()); - pe::CudaScheduleReduce(stages, out0.as_tensor_ref(), target); - pe::CudaScheduleReduce(stages, out1.as_tensor_ref(), target); + CHECK(reduce_local_bias.as_tensor()); + CHECK(reduce_local_diff.as_tensor()); + CHECK(reduce_sum_bias_tmp.as_tensor()); + CHECK(reduce_sum_diff_tmp.as_tensor()); + CHECK(reduce_sum_bias.as_tensor()); + CHECK(reduce_sum_diff.as_tensor()); + + pe::CudaScheduleBlockReduce(stages, + reduce_local_bias.as_tensor_ref(), + reduce_sum_bias_tmp.as_tensor_ref(), + reduce_sum_bias.as_tensor_ref(), + common::DefaultNVGPUTarget()); + + stages[reduce_local_diff.as_tensor_ref()]->SetBuffer("local"); if (new_shape.size() == 3) { - stages[out0.as_tensor_ref()]->SimpleComputeAt(stages[out1.as_tensor_ref()], 2); + stages[reduce_local_diff.as_tensor_ref()]->SimpleComputeAt(stages[reduce_local_bias.as_tensor_ref()], 2); } else { - stages[out0.as_tensor_ref()]->SimpleComputeAt(stages[out1.as_tensor_ref()], 3); + stages[reduce_local_diff.as_tensor_ref()]->SimpleComputeAt(stages[reduce_local_bias.as_tensor_ref()], 3); } + stages[reduce_sum_diff_tmp.as_tensor_ref()]->SetBuffer("local"); + stages[reduce_sum_diff_tmp.as_tensor_ref()]->SimpleComputeAt(stages[reduce_sum_bias_tmp.as_tensor_ref()], 1); + stages[reduce_sum_diff.as_tensor_ref()]->SimpleComputeAt(stages[reduce_sum_bias.as_tensor_ref()], 0); } else if (target.arch == Target::Arch::X86) { Expr out = arg_pack[0]; poly::StageMap stages = arg_pack[1]; @@ -220,27 +296,6 @@ std::shared_ptr StrategyForBnGradBiasScaleReduce(const framework::No return strategy; } -std::vector InferShapeForBNReduce(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(inputs_shape.size() == 3UL || inputs_shape.size() == 1UL); - CHECK_EQ(inputs_shape[0].size(), 4UL); - // compute the succesive dimension size - auto last_reduce_dim = inputs_shape[0][2] * inputs_shape[0][3]; - // split into last_reduce_dim into {n,k} - std::vector output_shape = {inputs_shape[0][1]}; - if (last_reduce_dim <= 128) { - output_shape.push_back(last_reduce_dim); - } else { - for (int idx = 256; idx > 128; --idx) { - if (last_reduce_dim % idx == 0) { - output_shape.push_back(idx); - break; - } - } - } - return {output_shape, output_shape}; -} - #define StrategyForReduction(op_name__, pe__, pe_func__) \ std::shared_ptr StrategyFor##pe__(const framework::NodeAttr &attrs, \ const std::vector &inputs, \ @@ -262,36 +317,87 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, if (attrs.attr_store.count("dim")) { dim = absl::get>(attrs.attr_store.at("dim")); std::sort(dim.begin(), dim.end()); + // check dim + CHECK_LE(dim.size(), inputs[0]->shape.size()); + CHECK_LT(dim.back(), inputs[0]->shape.size()); + for (int idx = 1; idx < dim.size(); ++idx) { + CHECK_NE(dim[idx - 1], dim[idx]); + } + } else { + LOG(FATAL) << "reduce dimension is not set!"; } + if (attrs.attr_store.count("keep_dim")) { keep_dim = absl::get(attrs.attr_store.at("keep_dim")); } + + // compute reduce args + int succesive_dim_idx = 0; 
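+ // Scan the reduce dims from back to front: while they are contiguous, multiply their extents into last_succesive_dim; at the first gap, record succesive_dim_idx and clear reduce_dim_succesive.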
+ bool reduce_dim_succesive = true; + int last_succesive_dim = inputs[0]->shape.back().as_int32(); + for (int idx = dim.size() - 2; idx >= 0; --idx) { + if (dim[idx] != dim[idx + 1] - 1) { + succesive_dim_idx = idx + 1; + reduce_dim_succesive = false; + break; + } else { + last_succesive_dim *= inputs[0]->shape[dim[idx]].as_int32(); + } + } + framework::CINNCompute reduction_compute([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check."; CINNValuePack a = args[0]; CHECK_EQ(a.size(), 1U) << "1 input tensor for " << op_name << " compute"; - Expr A_expr = a[0]; - CHECK(A_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - // if do reduce on last axis and reduce axis size > 1. - // two step to do reduce: 1.[n,c,h,w] -> [c,w]; 2.[c,w] -> [c] - if (dim.back() == inputs[0]->shape.size() - 1 && dim.size() > 1 && target == common::DefaultNVGPUTarget()) { - // reduce without last dimension - std::vector dims_without_last(dim.begin(), --dim.end()); - auto out0 = pe_func(A, dims_without_last, keep_dim, Expr(), UniqName(op_name + "_out0")); - // reduce on last dimension - std::vector dims_last(1, out0->shape.size() - 1); - auto out1 = pe_func(out0, dims_last, keep_dim, Expr(), UniqName(op_name + "_out1")); - auto stages = CreateStages({A, out0, out1}); - *ret = CINNValuePack{{CINNValue(out1), CINNValue(out0), CINNValue(stages)}}; - } else if (dim.back() == inputs[0]->shape.size() - 1 && target == common::DefaultNVGPUTarget()) { - auto res = pe::WarpReduceSum(A, dim.size()); - auto stages = CreateStages(res); - *ret = CINNValuePack{{CINNValue(res[0]), CINNValue(res[1]), CINNValue(stages)}}; + Expr x_expr = a[0]; + CHECK(x_expr.as_tensor()); + ir::Tensor x = x_expr.as_tensor_ref(); + if (target == common::DefaultNVGPUTarget() && dim.back() == inputs[0]->shape.size() - 1) { + // the reduce dimension is succesive + if (reduce_dim_succesive) { + if (last_succesive_dim < 256) { + VLOG(3) << "Do WarpReduceSum Compute!"; + // if the succesive reduce dimension size < 256 + auto res = pe::WarpReduceSum(x, dim.size(), keep_dim); + CHECK_EQ(res.size(), 2); + auto stages = CreateStages(res); + *ret = CINNValuePack{{CINNValue(res[0]), CINNValue(res[1]), CINNValue(stages)}}; + } else { + VLOG(3) << "Do BlockReduceSum Compute!"; + // if the succesive reduce dimension size > 256 + int block_size = last_succesive_dim > 1024 ? 
512 : 128; + auto res = pe::BlockReduceSum(x, dim.size(), block_size, keep_dim); + CHECK_EQ(res.size(), 2); + auto stages = CreateStages(res); + *ret = CINNValuePack{{CINNValue(res[0]), CINNValue(res[1]), CINNValue(stages)}}; + } + } else /* the reduce dimension is not succesive */ { + VLOG(3) << "Do ReduceSum And BlockReduceSumInternal Compute!"; + // compute the parallel reduce dimension size + int last_succesive_dim_tmp = last_succesive_dim; + std::vector reduce_without_last_diemension(dim.begin(), dim.begin() + succesive_dim_idx); + for (int idx = dim[succesive_dim_idx]; idx < inputs[0]->shape.size(); idx++) { + if (last_succesive_dim_tmp > 1024) { + last_succesive_dim_tmp /= inputs[0]->shape[idx].as_int32(); + reduce_without_last_diemension.push_back(idx); + } else { + break; + } + } + // TODO(sunli) : support last dimension size over 1024 + CHECK_LE(last_succesive_dim_tmp, 1024) << "last dimension size over 1024"; + // first: do reduce without last dimension + auto out = pe_func(x, reduce_without_last_diemension, keep_dim, Expr(), UniqName(op_name + "_out")); + // second: do reduce on last dimension + auto res = pe::BlockReduceSumInternal(out, dim.size() - reduce_without_last_diemension.size()); + CHECK_EQ(res.size(), 2); + auto stages = CreateStages({res[0], res[1], out}); + *ret = CINNValuePack{{CINNValue(res[0]), CINNValue(res[1]), CINNValue(out), CINNValue(stages)}}; + } } else { - // do reduce on last dimension - auto out = pe_func(A, dim, keep_dim, Expr(), UniqName(op_name + "_out")); - auto stages = CreateStages({A, out}); + VLOG(3) << "Do ReduceSum Compute!"; + auto out = pe_func(x, dim, keep_dim, Expr(), UniqName(op_name + "_out")); + auto stages = CreateStages({out}); *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; } }); @@ -299,22 +405,43 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, framework::CINNSchedule reduction_schedule([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input argument of " << op_name << " schedule is empty! 
Please check."; CINNValuePack arg_pack = args[0]; - CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL); - + CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL || arg_pack.size() == 4UL); if (target.arch == Target::Arch::NVGPU) { - if (dim.size() == 1 && dim.back() == inputs[0]->shape.size() - 1) { - Expr out = arg_pack[0]; - Expr tmp_out = arg_pack[1]; - poly::StageMap stages = arg_pack.back(); - pe::CudaScheduleWarpReduce(stages, tmp_out.as_tensor_ref(), out.as_tensor_ref(), target); + if (dim.back() == inputs[0]->shape.size() - 1) { + if (reduce_dim_succesive) { + CHECK_EQ(arg_pack.size(), 3UL); + Expr out = arg_pack[0]; + Expr tmp_out = arg_pack[1]; + poly::StageMap stages = arg_pack.back(); + if (last_succesive_dim < 256) { + VLOG(3) << "Do CudaScheduleWarpReduce Schedule!"; + pe::CudaScheduleWarpReduce( + stages, tmp_out.as_tensor_ref(), out.as_tensor_ref(), common::DefaultNVGPUTarget()); + } else { + VLOG(3) << "Do CudaScheduleBlockReduceInternal Schedule!"; + pe::CudaScheduleBlockReduceInternal( + stages, tmp_out.as_tensor_ref(), out.as_tensor_ref(), common::DefaultNVGPUTarget()); + } + } else { + CHECK_EQ(arg_pack.size(), 4UL); + Expr out = arg_pack[0]; + Expr tmp_out = arg_pack[1]; + Expr reduce_tmp_out = arg_pack[2]; + poly::StageMap stages = arg_pack.back(); + + VLOG(3) << "Do CudaScheduleBlockReduce Schedule!"; + pe::CudaScheduleBlockReduce(stages, + reduce_tmp_out.as_tensor_ref(), + tmp_out.as_tensor_ref(), + out.as_tensor_ref(), + common::DefaultNVGPUTarget()); + } } else { + CHECK_EQ(arg_pack.size(), 2UL); + Expr out = arg_pack[0]; poly::StageMap stages = arg_pack.back(); - Expr out0 = arg_pack.size() == 2 ? arg_pack[0] : arg_pack[1]; - pe::CudaScheduleReduce(stages, out0.as_tensor_ref(), target); - if (arg_pack.size() == 3) { - Expr out1 = arg_pack[0]; - pe::CudaScheduleReduce(stages, out1.as_tensor_ref(), target); - } + VLOG(3) << "Do CudaScheduleReduce Schedule!"; + pe::CudaScheduleReduce(stages, out.as_tensor_ref(), inputs[0]->shape.size() - dim.back() - 1, target); } } *ret = arg_pack; @@ -328,7 +455,7 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, std::vector InferShapeForReduction(const std::vector &inputs_shape, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 1UL); + CHECK(inputs_shape.size() == 1UL || inputs_shape.size() == 3UL); std::vector dim; bool keep_dim = false; if (attrs.find("dim") != attrs.end()) { @@ -340,46 +467,28 @@ std::vector InferShapeForReduction(const std::vector &inputs_s } CHECK(!dim.empty()) << "should have reduce dim, please check!"; CHECK_LE(dim.size(), inputs_shape[0].size()) << "reduce dim should no more than the input size"; - std::vector out_shapes, out_shapes_internal; + std::vector out_shapes; auto ndim = inputs_shape[0].size(); - if (keep_dim) { - for (size_t i = 0; i < ndim; ++i) { - if (std::find(dim.begin(), dim.end(), i) != dim.end()) { + for (size_t i = 0; i < ndim; ++i) { + if (std::find(dim.begin(), dim.end(), i) != dim.end()) { + if (keep_dim) { out_shapes.push_back(1); - } else { - out_shapes.push_back(inputs_shape[0][i]); - } - } - - if (std::find(dim.begin(), dim.end(), inputs_shape[0].size() - 1) != dim.end()) { - out_shapes_internal = out_shapes; - out_shapes_internal.back() = inputs_shape[0].back(); - } - } else { - for (size_t i = 0; i < ndim; ++i) { - if (std::find(dim.begin(), dim.end(), i) == dim.end()) { - out_shapes.push_back(inputs_shape[0][i]); } - } - - if (std::find(dim.begin(), dim.end(), inputs_shape[0].size() - 1) != dim.end()) { - out_shapes_internal = 
out_shapes; - out_shapes_internal.push_back(inputs_shape[0].back()); + } else { + out_shapes.push_back(inputs_shape[0][i]); } } + if (out_shapes.empty()) { out_shapes.push_back(1); } - if (out_shapes_internal.empty()) { - out_shapes_internal.push_back(1); - } - return {out_shapes, out_shapes_internal}; + return {out_shapes}; } std::vector InferDtypeForReduction(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; - std::vector res{inputs_type[0], inputs_type[0]}; + std::vector res{inputs_type[0]}; return res; } @@ -394,9 +503,27 @@ std::vector> InferLayoutForReduction(const std::vector< new_input_layouts[0] = "NCHW"; VLOG(3) << "alter input layout from " << input_layouts[0] << " to " << new_input_layouts[0]; } - new_input_layouts.push_back(""); - return {{"", ""}, new_input_layouts}; + return {{""}, new_input_layouts}; +} + +std::vector InferShapeForBnOptimize(const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + auto shapes = InferShapeForReduction(inputs_shape, attrs); + CHECK_GE(shapes.size(), 1) << "shapes's size less than 1, please check!"; + return {shapes[0], shapes[0]}; +} + +std::vector InferDtypeForBnOptimize(const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; + return {inputs_type[0], inputs_type[0]}; +} + +std::vector> InferLayoutForBnOptimize(const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + return {{"", ""}, {"", ""}}; } StrategyForReduction(reduce_sum, ReduceSum, PeFunc); @@ -410,17 +537,24 @@ StrategyForReduction(reduce_min, ReduceMin, PeFunc); } // namespace hlir } // namespace cinn +// TODO(sunli) : repair element-wise + reduce fusion on gpu +#ifdef CINN_WITH_CUDA +#define REDUCE_OP_PATTERN_KIND cinn::hlir::framework::OpPatternKind::kOpaque +#else +#define REDUCE_OP_PATTERN_KIND cinn::hlir::framework::OpPatternKind::kCommReduce +#endif + CINN_REGISTER_HELPER(reduce_ops) { -#define CINN_REGISTER_REDUCTION(op__, op_stragegy__) \ - CINN_REGISTER_OP(op__) \ - .describe(#op__ " function") \ - .set_num_inputs(1) \ - .set_num_outputs(2) \ - .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForReduction)) \ - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForReduction)) \ - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForReduction)) \ - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kCommReduce) \ +#define CINN_REGISTER_REDUCTION(op__, op_stragegy__) \ + CINN_REGISTER_OP(op__) \ + .describe(#op__ " function") \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForReduction)) \ + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForReduction)) \ + .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForReduction)) \ + .set_attr("OpPattern", REDUCE_OP_PATTERN_KIND) \ .set_support_level(4); CINN_REGISTER_REDUCTION(reduce_sum, ReduceSum); @@ -430,27 +564,25 @@ CINN_REGISTER_HELPER(reduce_ops) { #undef CINN_REGISTER_REDUCTION - CINN_REGISTER_OP(bn_mean_variance_reduce) + CINN_REGISTER_OP(bn_mean_variance) .describe("This operator implements the optimization of bn reduce") 
.set_num_inputs(1) .set_num_outputs(2) - .set_attr("CINNStrategy", - cinn::hlir::op::StrategyForBnMeanVarianceReduce) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBNReduce)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForReduction)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForReduction)) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForBnMeanVariance) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBnOptimize)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBnOptimize)) + .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForBnOptimize)) .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque) .set_support_level(4); - CINN_REGISTER_OP(bn_grad_bias_scale_reduce) + CINN_REGISTER_OP(bn_grad_bias_scale) .describe("This operator implements the optimization of bn grad reduce") .set_num_inputs(3) .set_num_outputs(2) - .set_attr("CINNStrategy", - cinn::hlir::op::StrategyForBnGradBiasScaleReduce) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBNReduce)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForReduction)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForReduction)) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForBnGradBiasScale) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBnOptimize)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBnOptimize)) + .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForBnOptimize)) .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque) .set_support_level(4); diff --git a/cinn/hlir/op/reduction_test.cc b/cinn/hlir/op/reduction_test.cc index 2bd36a7799..b9f9301ae3 100644 --- a/cinn/hlir/op/reduction_test.cc +++ b/cinn/hlir/op/reduction_test.cc @@ -50,48 +50,36 @@ using framework::shape_t; using framework::StrategyFunction; using runtime::cuda::CUDAModule; -void CpuReduceSum(const float* x, - std::vector* sum0, - std::vector* sum1, - const int n, - const int c, - const int h, - const int w) { - memset(sum0->data(), 0, sizeof(float) * c * w); - memset(sum1->data(), 0, sizeof(float) * c); - for (int idx = 0; idx < n; ++idx) { - for (int idy = 0; idy < c; ++idy) { - for (int idz = 0; idz < h; ++idz) { - for (int ida = 0; ida < w; ++ida) { - sum0->at(idy * w + ida) += x[idx * c * h * w + idy * h * w + idz * w + ida]; - sum1->at(idy) += x[idx * c * h * w + idy * h * w + idz * w + ida]; - } - } - } - } -} - -std::pair GenHostAndPtx(const std::vector& shape, +std::pair GenReduceCode(const std::vector& shape, const std::vector& dim, - const std::string func_name = "reduce_sum") { + const std::string& func_name, + bool keep_dim = false) { + // code gen auto reduce_sum = Operator::Get("reduce_sum"); auto strategy = Operator::GetAttrs("CINNStrategy")[reduce_sum]; + // input tensor std::vector shape_as_expr; for (auto value : shape) { shape_as_expr.emplace_back(value); } Placeholder X("X", shape_as_expr); + // set attrs NodeAttr attrs; - attrs.attr_store["dim"] = dim; + attrs.attr_store["dim"] = dim; + attrs.attr_store["keep_dim"] = keep_dim; std::vector inputs{X.tensor()}; std::vector out_type{Float(32)}; std::vector output_shape; - for (auto value : shape) { - if (std::find(dim.begin(), dim.end(), value) == dim.end()) { - output_shape.push_back(value); + for (int idx = 0; idx < shape.size(); ++idx) { + if (std::find(dim.begin(), dim.end(), idx) != dim.end()) { + if 
(keep_dim) { + output_shape.push_back(1); + } + } else { + output_shape.push_back(shape[idx]); } } @@ -105,14 +93,17 @@ std::pair GenHostAndPtx(const std::vector& shape, // the last element is a StageMap for (int i = 0; i < rets->size() - 1; i++) { Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); + if (!temp.as_tensor_ref()->buffer.defined()) { + inputs.push_back(temp.as_tensor_ref()); + } } + auto func = lang::LowerVec(func_name, rets.back(), inputs, {}, {}, nullptr, target); for (auto& f : func) { LOG(INFO) << "Test Strategy Codegen:\n" << f; } - Module::Builder builder("reduce_sum_0", target); + Module::Builder builder(func_name + "_builder", target); for (auto& f : func) { builder.AddFunction(f); } @@ -132,86 +123,140 @@ std::pair GenHostAndPtx(const std::vector& shape, auto source_code = codegen.Compile(builder.Build()); LOG(INFO) << "compiled code:\n\n\n" << source_code; - using runtime::cuda::CUDAModule; - backends::NVRTC_Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); + return std::pair(host_module, source_code); +} + +TEST(Operator, Operator_Reduction_Case_0) { + std::vector shape = {16, 16, 8, 16}; + std::vector dim = {2, 3}; - return std::pair(host_module, ptx); + GenReduceCode(shape, dim, "reduce_cast_0"); } -TEST(Operator, Operator_Reduction_Sum0) { - int n = 128, c = 128, h = 32, w = 32; - std::vector shape = {n, c, h, w}; +TEST(Operator, Operator_Reduction_Case_0_0) { + std::vector shape = {16, 16, 8, 16}; + std::vector dim = {2, 3}; + + GenReduceCode(shape, dim, "reduce_cast_0_0", true); +} + +TEST(Operator, Operator_Reduction_Case_1) { + std::vector shape = {16, 16, 32, 32}; + std::vector dim = {2, 3}; + + GenReduceCode(shape, dim, "reduce_cast_1"); +} + +TEST(Operator, Operator_Reduction_Case_1_1) { + std::vector shape = {16, 16, 32, 32}; + std::vector dim = {2, 3}; + + GenReduceCode(shape, dim, "reduce_cast_1_1", true); +} + +TEST(Operator, Operator_Reduction_Case_2) { + std::vector shape = {16, 16, 32, 32}; + std::vector dim = {1}; + + GenReduceCode(shape, dim, "reduce_cast_2", true); +} + +TEST(Operator, Operator_Reduction_Case_3) { + std::vector shape = {16, 16, 64, 64}; + std::vector dim = {1}; + + GenReduceCode(shape, dim, "reduce_cast_3"); +} + +TEST(Operator, Operator_Reduction_Case_4) { + std::vector shape = {16, 16, 16, 16}; std::vector dim = {0, 2, 3}; - auto module_ptx = GenHostAndPtx(shape, dim); - CUDA_CALL(cudaSetDevice(0)); - CUDAModule cuda_module(module_ptx.second, CUDAModule::Kind::PTX); - void* reduce_sum_kernel = cuda_module.GetFunction(0, "reduce_sum"); - CHECK(reduce_sum_kernel); - void* reduce_sum_1_kernel = cuda_module.GetFunction(0, "reduce_sum_1"); - CHECK(reduce_sum_1_kernel); + GenReduceCode(shape, dim, "reduce_cast_4"); +} - void* stream = nullptr; - backends::RuntimeSymbolRegistry::Global().RegisterFn("reduce_sum_kernel_ptr_", - reinterpret_cast(&reduce_sum_kernel)); - backends::RuntimeSymbolRegistry::Global().RegisterFn("reduce_sum_1_kernel_ptr_", - reinterpret_cast(&reduce_sum_1_kernel)); - backends::RuntimeSymbolRegistry::Global().RegisterVar("reduce_sum_kernel_stream_ptr_", stream); - backends::RuntimeSymbolRegistry::Global().RegisterVar("reduce_sum_1_kernel_stream_ptr_", stream); +TEST(Operator, Operator_Reduction_Case_4_4) { + std::vector shape = {16, 16, 16, 16}; + std::vector dim = {0, 2, 3}; - auto jit = backends::SimpleJIT::Create(); - jit->Link(module_ptx.first); + GenReduceCode(shape, dim, "reduce_cast_4_4", true); +} - auto fn_reduce_sum = jit->Lookup("reduce_sum"); - 
CHECK(fn_reduce_sum); - auto fn_reduce_sum_1 = jit->Lookup("reduce_sum_1"); - CHECK(fn_reduce_sum_1); +TEST(Operator, Operator_Reduction_Case_5) { + std::vector shape = {16, 16, 16, 16, 16, 32}; + std::vector dim = {1, 3, 5}; - auto func_0 = reinterpret_cast(fn_reduce_sum); - auto func_1 = reinterpret_cast(fn_reduce_sum_1); + GenReduceCode(shape, dim, "reduce_cast_5"); +} + +TEST(Operator, Operator_Reduction_Case_5_5) { + std::vector shape = {16, 16, 16, 16, 16, 32}; + std::vector dim = {1, 3, 5}; + + GenReduceCode(shape, dim, "reduce_cast_5_5", true); +} + +void CpuReduceSum(const float* x, + std::vector* sum0, + std::vector* sum1, + const int n, + const int c, + const int h, + const int w) { + memset(sum0->data(), 0, sizeof(float) * c * w); + memset(sum1->data(), 0, sizeof(float) * c); + for (int idx = 0; idx < n; ++idx) { + for (int idy = 0; idy < c; ++idy) { + for (int idz = 0; idz < h; ++idz) { + for (int ida = 0; ida < w; ++ida) { + sum0->at(idy * w + ida) += x[idx * c * h * w + idy * h * w + idz * w + ida]; + sum1->at(idy) += x[idx * c * h * w + idy * h * w + idz * w + ida]; + } + } + } + } +} + +TEST(Operator, Operator_Reduction_Case_6) { + int n = 128, c = 128, h = 32, w = 32; + std::vector shape = {n, c, h, w}; + std::vector dim = {0, 2, 3}; + + // get source code + auto source_code = GenReduceCode(shape, dim, "reduce_case_6").second; + + // nv jit compile to ptx + backends::NVRTC_Compiler compiler; + auto ptx = compiler(source_code); + CHECK(!ptx.empty()); + + // cuda_module load ptx + runtime::cuda::CUDAModule cuda_module(ptx, CUDAModule::Kind::PTX); srand(time(NULL)); + CUDA_CALL(cudaSetDevice(0)); + + // auto func_0 = reinterpret_cast(fn_reduce_sum); auto buffer_x = common::BufferBuilder(Float(32), {n, c, h, w}).set_random().Build(); - auto buffer_y = common::BufferBuilder(Float(32), {c, w}).set_random().Build(); auto buffer_z = common::BufferBuilder(Float(32), {c}).set_random().Build(); - void *dev_x = nullptr, *dev_y = nullptr, *dev_z = nullptr; + void *dev_x = nullptr, *dev_z = nullptr; CUDA_CALL(cudaMalloc(&dev_x, buffer_x->memory_size)); - CUDA_CALL(cudaMalloc(&dev_y, buffer_y->memory_size)); CUDA_CALL(cudaMalloc(&dev_z, buffer_z->memory_size)); - CUDA_CALL(cudaMemcpy(dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); - cinn_buffer_t _x; - cinn_buffer_t _y; - cinn_buffer_t _z; - - _x.memory = static_cast(dev_x); - _y.memory = static_cast(dev_y); - _z.memory = static_cast(dev_z); - - _x.memory_size = buffer_x->memory_size; - _y.memory_size = buffer_y->memory_size; - _z.memory_size = buffer_z->memory_size; + dim3 grid(128, 1, 1); + dim3 block(1024, 1, 1); + void* args[] = {&dev_x, &dev_z}; - cinn_pod_value_t x_arg(&_x), y_arg(&_y), z_arg(&_z); - cinn_pod_value_t args0[] = {x_arg, y_arg}; - cinn_pod_value_t args1[] = {x_arg, y_arg, z_arg}; - - func_0(args0, 2); - func_1(args1, 3); - - CUDA_CALL(cudaMemcpy(buffer_y->memory, dev_y, buffer_y->memory_size, cudaMemcpyDeviceToHost)); + cuda_module.LaunchKernel(0, "reduce_case_6", grid, block, args); CUDA_CALL(cudaMemcpy(buffer_z->memory, dev_z, buffer_z->memory_size, cudaMemcpyDeviceToHost)); std::vector sum0(c * w); std::vector sum1(c); CpuReduceSum(reinterpret_cast(buffer_x->memory), &sum0, &sum1, n, c, h, w); - std::vector, float*>> results = {{sum0, reinterpret_cast(buffer_y->memory)}, - {sum1, reinterpret_cast(buffer_z->memory)}}; + std::vector, float*>> results = {{sum1, reinterpret_cast(buffer_z->memory)}}; for (auto& res : results) { for (int idx = 0; idx < res.first.size(); ++idx) { 
ASSERT_LT(abs(res.first[idx] - res.second[idx]) / res.first[idx], 1e-4); @@ -219,33 +264,43 @@ TEST(Operator, Operator_Reduction_Sum0) { } CUDA_CALL(cudaFree(dev_x)); - CUDA_CALL(cudaFree(dev_y)); CUDA_CALL(cudaFree(dev_z)); } -TEST(Operator, Operator_Reduction_Sum1) { - int n = 32, c = 32, h = 128, w = 128; +TEST(Operator, Operator_Reduction_Case_7) { + int n = 32, c = 32, h = 16, w = 16; std::vector shape = {n, c, h, w}; std::vector dim = {0, 1}; - auto module_ptx = GenHostAndPtx(shape, dim, "reduce_sum_test_1"); + std::string func_name = "reduce_cast_7"; + // get source code + auto host_source = GenReduceCode(shape, dim, func_name); + + // compile to ptx + backends::NVRTC_Compiler compiler; + auto ptx = compiler(host_source.second); + CHECK(!ptx.empty()); + + // load ptx CUDA_CALL(cudaSetDevice(0)); - CUDAModule cuda_module(module_ptx.second, CUDAModule::Kind::PTX); - void* reduce_sum_test_1_kernel = cuda_module.GetFunction(0, "reduce_sum_test_1"); - CHECK(reduce_sum_test_1_kernel); + runtime::cuda::CUDAModule cuda_module(ptx, runtime::cuda::CUDAModule::Kind::PTX); + void* reduce_sum_kernel = cuda_module.GetFunction(0, func_name); + CHECK(reduce_sum_kernel); + // register cufunction and stream void* stream = nullptr; - backends::RuntimeSymbolRegistry::Global().RegisterFn("reduce_sum_test_1_kernel_ptr_", - reinterpret_cast(&reduce_sum_test_1_kernel)); - backends::RuntimeSymbolRegistry::Global().RegisterVar("reduce_sum_test_1_kernel_stream_ptr_", stream); + backends::RuntimeSymbolRegistry::Global().RegisterFn(func_name + "_kernel_ptr_", + reinterpret_cast(&reduce_sum_kernel)); + backends::RuntimeSymbolRegistry::Global().RegisterVar(func_name + "_kernel_stream_ptr_", stream); + // gen host code auto jit = backends::SimpleJIT::Create(); - jit->Link(module_ptx.first); + jit->Link(host_source.first); - auto fn_reduce_sum_test_1 = jit->Lookup("reduce_sum_test_1"); - CHECK(fn_reduce_sum_test_1); + auto fn_reduce_sum = jit->Lookup(func_name); + CHECK(fn_reduce_sum); - auto func_0 = reinterpret_cast(fn_reduce_sum_test_1); + auto func_0 = reinterpret_cast(fn_reduce_sum); srand(time(NULL)); auto buffer_x = common::BufferBuilder(Float(32), {n, c, h, w}).set_random().Build(); diff --git a/cinn/hlir/pe/reduction.cc b/cinn/hlir/pe/reduction.cc index 0fde9f33c4..9c1fff1ce1 100644 --- a/cinn/hlir/pe/reduction.cc +++ b/cinn/hlir/pe/reduction.cc @@ -202,15 +202,21 @@ Tensor ReduceMin( } std::vector WarpReduce(const ir::Tensor& A, - int last_reduce_dim_num, + const int last_reduce_dim_num, + const bool keep_dim, const std::string& reduce_type, const std::string& output_name) { - Expr lane(1); - for (int idx = A->shape.size() - 1; idx >= (A->shape.size() - last_reduce_dim_num); --idx) { - lane = lane * A->shape[idx].as_int32(); + // compute shape size without last reduce dimension. + int shape_size_without_reduce_dim = A->shape.size() - last_reduce_dim_num; + + // compute reduce dimension size. + Expr reduce_width(1); + for (int idx = shape_size_without_reduce_dim; idx < A->shape.size(); ++idx) { + reduce_width = reduce_width * A->shape[idx].as_int32(); } - std::vector tmp_shape(A->shape.begin(), A->shape.begin() + A->shape.size() - last_reduce_dim_num); + // comput tmp output shape. 
+ std::vector tmp_shape(A->shape.begin(), A->shape.begin() + shape_size_without_reduce_dim); tmp_shape.push_back(Expr(32)); auto tmp_out = Compute( tmp_shape, @@ -221,15 +227,23 @@ std::vector WarpReduce(const ir::Tensor& A, } CHECK_EQ(A->shape.size(), tmp_indexs.size()); Expr offset = common::IndiceToAbsOffset(A->shape, tmp_indexs); - return lang::CallExtern(reduce_type, {A, offset, lane}); + return lang::CallExtern(reduce_type, {A, offset, reduce_width}); }, UniqName(output_name + "_" + reduce_type)); - std::vector out_shape(A->shape.begin(), A->shape.begin() + A->shape.size() - last_reduce_dim_num); + // compute ouput shape. + std::vector out_shape(A->shape.begin(), A->shape.begin() + shape_size_without_reduce_dim); + for (int idx = 0; idx < last_reduce_dim_num && keep_dim; ++idx) { + out_shape.push_back(Expr(1)); + } + // if reduce on all dimension, the out_shape = {1}. + if (out_shape.size() == 0) { + out_shape.push_back(Expr(1)); + } auto out = Compute( out_shape, [=](const std::vector& indexs) -> Expr { - std::vector tmp_indexs(indexs); + std::vector tmp_indexs(indexs.begin(), indexs.begin() + shape_size_without_reduce_dim); tmp_indexs.push_back(Expr(0)); return tmp_out(tmp_indexs); }, @@ -238,34 +252,150 @@ std::vector WarpReduce(const ir::Tensor& A, return {out, tmp_out}; } -/** - * @brief find the max of array elements over the last dimension - * - * @param A The input Tensor - * @param output_name The name of the output Tensor - */ -std::vector WarpReduceMax(const ir::Tensor& A, int last_reduce_dim_num, const std::string& output_name) { - return WarpReduce(A, last_reduce_dim_num, "cinn_warp_reduce_max", output_name); +std::vector WarpReduceMax(const ir::Tensor& A, + const int last_reduce_dim_num, + const bool keep_dim, + const std::string& output_name) { + return WarpReduce(A, last_reduce_dim_num, keep_dim, "cinn_warp_reduce_max", output_name); } -/** - * @brief compute the sum of array elements over the last dimension - * - * @param A The input Tensor - * @param output_name The name of the output Tensor - */ -std::vector WarpReduceSum(const ir::Tensor& A, int last_reduce_dim_num, const std::string& output_name) { - return WarpReduce(A, last_reduce_dim_num, "cinn_warp_reduce_sum", output_name); +std::vector WarpReduceSum(const ir::Tensor& A, + const int last_reduce_dim_num, + const bool keep_dim, + const std::string& output_name) { + return WarpReduce(A, last_reduce_dim_num, keep_dim, "cinn_warp_reduce_sum", output_name); +} + +std::vector WarpReduceAvg(const ir::Tensor& A, + const int last_reduce_dim_num, + const bool keep_dim, + const std::string& output_name) { + return WarpReduce(A, last_reduce_dim_num, keep_dim, "cinn_warp_reduce_avg", output_name); +} + +std::vector BlockReduceSumInternal(const ir::Tensor& A, + const int last_reduce_dim_num, + const bool keep_dim, + const std::string& output_name) { + // compute shape size without last reduce dimension. + int shape_size_without_reduce_dim = A->shape.size() - last_reduce_dim_num; + + // compute reduce dimension size. + Expr reduce_width(1); + for (int idx = shape_size_without_reduce_dim; idx < A->shape.size(); ++idx) { + reduce_width = reduce_width * A->shape[idx].as_int32(); + } + + // compute tmp output shape. + std::vector tmp_shape(A->shape.begin(), A->shape.begin() + shape_size_without_reduce_dim); + tmp_shape.push_back(reduce_width); + + // compute the reduce dimension stride. 
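+ // Strides are row-major over the reduce dims, e.g. reducing the last two dims {h, w} gives strides {w, 1}, so a flattened index i maps back to (i / w, i % w).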
+ std::vector last_reduce_stride(last_reduce_dim_num, Expr(1)); + for (int idx = A->shape.size(), index = last_reduce_stride.size() - 2; index >= 0; --index) { + last_reduce_stride[index] = last_reduce_stride[index + 1] * A->shape[--idx]; + } + + auto tmp_out = Compute( + tmp_shape, + [=](const std::vector& indexs) -> Expr { + // comput index map from output to input. + auto last_index = indexs.back(); + std::vector input_indexs(indexs.begin(), indexs.begin() + indexs.size() - 1); + for (int idx = 0; idx < last_reduce_dim_num; ++idx) { + input_indexs.push_back(last_index / last_reduce_stride[idx]); + last_index = last_index % last_reduce_stride[idx]; + } + + // checkout input_indexs size equals input shape + CHECK_EQ(input_indexs.size(), A->shape.size()); + return lang::CallExtern("cinn_block_reduce_sum_internal", {A(input_indexs)}); + }, + UniqName(output_name + "_tmp")); + + // compute output shape. + std::vector out_shape(A->shape.begin(), A->shape.begin() + shape_size_without_reduce_dim); + for (int idx = 0; idx < last_reduce_dim_num && keep_dim; ++idx) { + out_shape.push_back(Expr(1)); + } + + // if reduce on all dimension, the out_shape = {1}. + if (out_shape.size() == 0) { + out_shape.push_back(Expr(1)); + } + auto out = Compute( + out_shape, + [=](const std::vector& indexs) -> Expr { + std::vector tmp_indexs(indexs.begin(), indexs.begin() + shape_size_without_reduce_dim); + tmp_indexs.push_back(Expr(0)); + return tmp_out(tmp_indexs); + }, + UniqName(output_name)); + + return {out, tmp_out}; } /** - * @brief compute the average of array elements over the last dimension + * @brief compute the sum of array elements over the last dimension with block reduce * - * @param A The input Tensor - * @param output_name The name of the output Tensor + * @param A The input Tensor. + * @param last_reduce_dim_num the number of last reduce dimension. + * @param keep_dim keep the output tensor shape size as input. + * @param output_name The name of the output Tensor. */ -std::vector WarpReduceAvg(const ir::Tensor& A, int last_reduce_dim_num, const std::string& output_name) { - return WarpReduce(A, last_reduce_dim_num, "cinn_warp_reduce_avg", output_name); +std::vector BlockReduceSum(const ir::Tensor& A, + const int last_reduce_dim_num, + const int block_size, + const bool keep_dim, + const std::string& output_name) { + // compute shape size without last reduce dimension. + int shape_size_without_reduce_dim = A->shape.size() - last_reduce_dim_num; + + // compute reduce dimension size. + Expr reduce_width(1); + for (int idx = shape_size_without_reduce_dim; idx < A->shape.size(); ++idx) { + reduce_width = reduce_width * A->shape[idx].as_int32(); + } + + // compute tmp output tensor shape + std::vector tmp_shape(A->shape.begin(), A->shape.begin() + shape_size_without_reduce_dim); + tmp_shape.push_back(Expr(block_size)); + auto tmp_out = Compute( + tmp_shape, + [=](const std::vector& indexs) -> Expr { + std::vector tmp_indexs(indexs.begin(), indexs.begin() + shape_size_without_reduce_dim); + for (int idx = 0; idx < last_reduce_dim_num; ++idx) { + tmp_indexs.push_back(Expr(0)); + } + // checkout input shape size equals tmp indexs size. + CHECK_EQ(A->shape.size(), tmp_indexs.size()); + // compute offset. + Expr offset = common::IndiceToAbsOffset(A->shape, tmp_indexs); + // call block reduce sum + return lang::CallExtern("cinn_block_reduce_sum", {A, offset, reduce_width}); + }, + UniqName(output_name + "_tmp")); + + // compute output tensor shape. 
+ std::vector out_shape(A->shape.begin(), A->shape.begin() + shape_size_without_reduce_dim); + for (int idx = 0; idx < last_reduce_dim_num && keep_dim; ++idx) { + out_shape.push_back(Expr(1)); + } + // if reduce on all dimension, the out_shape = {1}. + if (out_shape.size() == 0) { + out_shape.push_back(Expr(1)); + } + auto out = Compute( + out_shape, + [=](const std::vector& indexs) -> Expr { + // compute input index + std::vector tmp_indexs(indexs.begin(), indexs.begin() + shape_size_without_reduce_dim); + tmp_indexs.push_back(Expr(0)); + return tmp_out(tmp_indexs); + }, + UniqName(output_name)); + + return {out, tmp_out}; } } // namespace pe diff --git a/cinn/hlir/pe/reduction.h b/cinn/hlir/pe/reduction.h index acf14f561f..72a1208905 100644 --- a/cinn/hlir/pe/reduction.h +++ b/cinn/hlir/pe/reduction.h @@ -102,36 +102,70 @@ ir::Tensor ReduceMin(const ir::Tensor& A, /** * @brief find the max of array elements over the last dimension * - * @param A The input Tensor - * @param last_reduce_dim_num the number of last reduce dimension - * @param output_name The name of the output Tensor + * @param A The input Tensor. + * @param last_reduce_dim_num the number of last reduce dimension. + * @param keep_dim keep the output tensor shape size as input. + * @param output_name The name of the output Tensor. */ std::vector WarpReduceMax(const ir::Tensor& A, - int last_reduce_dim_num, + const int last_reduce_dim_num, + const bool keep_dim = false, const std::string& output_name = "T_Warp_Reduce_Max_out"); /** * @brief compute the sum of array elements over the last dimension * - * @param A The input Tensor - * @param last_reduce_dim_num the number of last reduce dimension - * @param output_name The name of the output Tensor + * @param A The input Tensor. + * @param last_reduce_dim_num the number of last reduce dimension. + * @param keep_dim keep the output tensor shape size as input. + * @param output_name The name of the output Tensor. */ std::vector WarpReduceSum(const ir::Tensor& A, - int last_reduce_dim_num, + const int last_reduce_dim_num, + const bool keep_dim = false, const std::string& output_name = "T_Warp_Reduce_Sum_out"); /** * @brief compute the average of array elements over the last dimension * - * @param A The input Tensor - * @param last_reduce_dim_num the number of last reduce dimension - * @param output_name The name of the output Tensor + * @param A The input Tensor. + * @param last_reduce_dim_num the number of last reduce dimension. + * @param keep_dim keep the output tensor shape size as input. + * @param output_name The name of the output Tensor. */ std::vector WarpReduceAvg(const ir::Tensor& A, - int last_reduce_dim_num, + const int last_reduce_dim_num, + const bool keep_dim = false, const std::string& output_name = "T_Warp_Reduce_Avg_out"); +/** + * @brief compute the sum of array elements over the last dimension with block reduce. + * 'BlockReduceSumInternal' is used as the internal compute of reduce sum, do not use it directly. + * + * @param A The input Tensor. + * @param last_reduce_dim_num the number of last reduce dimension. + * @param keep_dim keep the output tensor shape size as input. + * @param output_name The name of the output Tensor. + */ +std::vector BlockReduceSumInternal(const ir::Tensor& A, + const int last_reduce_dim_num, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Sum_Internal_out"); + +/** + * @brief compute the sum of array elements over the last dimension with block reduce + * + * @param A The input Tensor. 
+ * @param last_reduce_dim_num the number of last reduce dimension. + * @param keep_dim keep the output tensor shape size as input. + * @param output_name The name of the output Tensor. + */ +std::vector BlockReduceSum(const ir::Tensor& A, + const int last_reduce_dim_num, + const int block_size, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Sum_out"); + } // namespace pe } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/pe/schedule.cc b/cinn/hlir/pe/schedule.cc index 4ac2213fb9..7fd1981486 100644 --- a/cinn/hlir/pe/schedule.cc +++ b/cinn/hlir/pe/schedule.cc @@ -343,40 +343,58 @@ int GetBlockBindAxis(const std::vector &shape, const int thread_axis) return block_axis; } -void CudaScheduleReduce(poly::StageMap stages, ir::Tensor output, const common::Target &target) { - // find the dimension to bind threadIdx.x. - auto thread_axis = GetThreadBindAxis(output->shape); - // use the max dimension to bind blockIdx.x - auto block_axis = GetBlockBindAxis(output->shape, thread_axis); +void CudaScheduleReduce(poly::StageMap stages, + ir::Tensor output, + int last_dimension_num, + const common::Target &target) { + int parallel_thread_num = 1; + for (int idx = output->shape.size() - 1; idx >= static_cast(output->shape.size()) - last_dimension_num; --idx) { + parallel_thread_num *= output->shape[idx].as_int32(); + } - if (block_axis < thread_axis) { - stages[output]->Bind(block_axis, "blockIdx.x"); + int index = output->shape.size() - last_dimension_num; + for (int idx = output->shape.size() - last_dimension_num; idx < static_cast(output->shape.size()) - 1; ++idx) { + stages[output]->Fuse(index, index + 1); } - if (output->shape[thread_axis].as_int32() > 512) { - stages[output]->Split(thread_axis, 512); - if (block_axis == thread_axis) { - stages[output]->Bind(thread_axis, "blockIdx.x"); - } - stages[output]->Bind(thread_axis + 1, "threadIdx.x"); + int max_block_size = 1024; + if (parallel_thread_num > max_block_size) { + stages[output]->Split(index, max_block_size); + stages[output]->Bind(index + 1, "threadIdx.x"); } else { - stages[output]->Bind(thread_axis, "threadIdx.x"); + stages[output]->Bind(index, "threadIdx.x"); + } + + for (int idx = 0; idx < index - 1; ++idx) { + stages[output]->Fuse(0, 1); + } + + if (index > 0) { + stages[output]->Bind(0, "blockIdx.x"); } } void CudaScheduleWarpReduce(poly::StageMap stages, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target) { int sum_out_dim = 1; - for (int idx = 0; idx < out->shape.size() - 1; ++idx) { + for (int idx = 0; idx < static_cast(tmp_out->shape.size()) - 2; ++idx) { stages[out]->Fuse(0, 1); stages[tmp_out]->Fuse(0, 1); sum_out_dim *= out->shape[idx].as_int32(); } sum_out_dim *= out->shape.back().as_int32(); - if (sum_out_dim < 16) { + // out_shape = {1} tmp_shape = {32} + if (tmp_out->shape.size() == 1) { + stages[out]->Split(0, 1); + stages[tmp_out]->Split(0, tmp_out->shape[0].as_int32()); + } + + if (sum_out_dim <= 16) { stages[out]->Bind(0, "threadIdx.y"); + + stages[tmp_out]->Bind(0, "threadIdx.y"); stages[tmp_out]->Bind(1, "threadIdx.x"); - stages[tmp_out]->ComputeAt2(stages[out], 0); + stages[tmp_out]->SimpleComputeAt(stages[out], 0); stages[tmp_out]->SetBuffer("local"); } else { stages[out]->Split(0, 8); @@ -385,11 +403,64 @@ void CudaScheduleWarpReduce(poly::StageMap stages, ir::Tensor tmp_out, ir::Tenso stages[tmp_out]->Split(0, 8); stages[tmp_out]->Bind(2, "threadIdx.x"); - stages[tmp_out]->ComputeAt2(stages[out], 1); + stages[tmp_out]->SimpleComputeAt(stages[out], 1); 
stages[tmp_out]->SetBuffer("local"); } } +void CudaScheduleBlockReduceInternal(poly::StageMap stages, + ir::Tensor tmp_out, + ir::Tensor out, + const common::Target &target) { + for (int idx = 0; idx < static_cast(tmp_out->shape.size()) - 2; ++idx) { + stages[tmp_out]->Fuse(0, 1); + stages[out]->Fuse(0, 1); + } + + if (tmp_out->shape.size() == 1) { + stages[out]->Split(0, 1); + stages[tmp_out]->Split(0, tmp_out->shape[0].as_int32()); + } + + stages[out]->Bind(0, "blockIdx.x"); + + stages[tmp_out]->Bind(0, "blockIdx.x"); + stages[tmp_out]->Bind(1, "threadIdx.x"); + stages[tmp_out]->SimpleComputeAt(stages[out], 0); + stages[tmp_out]->SetBuffer("local"); +} + +void CudaScheduleBlockReduce(poly::StageMap stages, + ir::Tensor reduce_tmp_out, + ir::Tensor tmp_out, + ir::Tensor out, + const common::Target &target) { + int output_shape_size_without_reduce = tmp_out->shape.size() - 1; + // fuse last parallel dimension + for (int idx = 0; idx < reduce_tmp_out->shape.size() - tmp_out->shape.size(); ++idx) { + stages[reduce_tmp_out]->Fuse(output_shape_size_without_reduce, output_shape_size_without_reduce + 1); + } + + // fuse parallel dimension + for (int idx = 0; idx < output_shape_size_without_reduce - 1; ++idx) { + stages[out]->Fuse(0, 1); + stages[tmp_out]->Fuse(0, 1); + stages[reduce_tmp_out]->Fuse(0, 1); + } + + stages[reduce_tmp_out]->Bind(0, "blockIdx.x"); + stages[reduce_tmp_out]->Bind(1, "threadIdx.x"); + stages[reduce_tmp_out]->SetBuffer("local"); + stages[reduce_tmp_out]->SimpleComputeAt(stages[tmp_out], 0); + + stages[tmp_out]->Bind(0, "blockIdx.x"); + stages[tmp_out]->Bind(1, "threadIdx.x"); + stages[tmp_out]->SetBuffer("local"); + stages[tmp_out]->SimpleComputeAt(stages[out], 0); + + stages[out]->Bind(0, "blockIdx.x"); +} + void SoftmaxScheduleCPU(poly::StageMap stage, const ir::Tensor &output, const ir::Tensor &temp, int axis) { if (axis == -1) { axis += output->shape.size(); diff --git a/cinn/hlir/pe/schedule.h b/cinn/hlir/pe/schedule.h index 9e1a814b2c..792c8d1c60 100644 --- a/cinn/hlir/pe/schedule.h +++ b/cinn/hlir/pe/schedule.h @@ -156,10 +156,18 @@ void CudaScheduleMul(poly::StageMap stages, const std::vector &output_shape, const common::Target &target); -void CudaScheduleReduce(poly::StageMap stages, ir::Tensor output, const common::Target &target); +void CudaScheduleReduce(poly::StageMap stages, ir::Tensor output, int last_dimension_num, const common::Target &target); void CudaScheduleWarpReduce(poly::StageMap stages, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target); +void CudaScheduleBlockReduceInternal(poly::StageMap stages, + ir::Tensor tmp_out, + ir::Tensor out, + const common::Target &target); + +void CudaScheduleBlockReduce( + poly::StageMap stages, ir::Tensor reduce_tmp_out, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target); + void CudaScheduleDepthwiseConv(poly::StageMap stages, ir::Tensor &output, const common::Target &target); void CudaScheduleConv(poly::StageMap stages, diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh old mode 100755 new mode 100644 index 62608b7e18..cdd86b1866 --- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -2,7 +2,7 @@ * \file This file contains all the intrinsics available to be used in CUDA code generated by CodeGen. 
*/ -#define FN(x) cinn_nvgpu_ ## x ## _fp32 +#define FN(x) cinn_nvgpu_##x##_fp32 // NOTE Due to function override, we don't need to use type (such as '_fp32') as the suffix of function's name. __device__ inline float FN(sin)(float x) { return sin(x); } __device__ inline float FN(cos)(float x) { return cos(x); } @@ -37,6 +37,18 @@ __device__ inline float FN(min)(float a, float b) { return min(a, b); } #undef FN +__device__ inline float cinn_warp_shuffle_sum_interal(const float value) { + float sumv = value; + unsigned int mask = __activemask(); + sumv += __shfl_down_sync(mask, sumv, 16, 32); + sumv += __shfl_down_sync(mask, sumv, 8, 32); + sumv += __shfl_down_sync(mask, sumv, 4, 32); + sumv += __shfl_down_sync(mask, sumv, 2, 32); + sumv += __shfl_down_sync(mask, sumv, 1, 32); + sumv = __shfl_sync(mask, sumv, 0, 32); + return sumv; +} + __device__ inline float cinn_warp_reduce_max(const float *buf, int offset, int extend) { float maxv = -3.402823e+38f; for (int i = threadIdx.x; i < extend; i += 32) { @@ -58,15 +70,8 @@ __device__ inline float cinn_warp_reduce_avg(const float *buf, int offset, int e for (int i = threadIdx.x; i < extend; i += 32) { sumv += buf[offset + i] / (float)extend; } - unsigned int mask; - mask = __activemask(); - sumv += __shfl_down_sync(mask, sumv, 16, 32); - sumv += __shfl_down_sync(mask, sumv, 8, 32); - sumv += __shfl_down_sync(mask, sumv, 4, 32); - sumv += __shfl_down_sync(mask, sumv, 2, 32); - sumv += __shfl_down_sync(mask, sumv, 1, 32); - sumv = __shfl_sync(mask, sumv , 0, 32); - return sumv; + + return cinn_warp_shuffle_sum_interal(sumv); } __device__ inline float cinn_warp_reduce_sum(const float *buf, int offset, int extend) { @@ -74,16 +79,39 @@ __device__ inline float cinn_warp_reduce_sum(const float *buf, int offset, int e for (int i = threadIdx.x; i < extend; i += 32) { sumv += buf[offset + i]; } - unsigned int mask; - mask = __activemask(); - sumv += __shfl_down_sync(mask, sumv, 16, 32); - sumv += __shfl_down_sync(mask, sumv, 8, 32); - sumv += __shfl_down_sync(mask, sumv, 4, 32); - sumv += __shfl_down_sync(mask, sumv, 2, 32); - sumv += __shfl_down_sync(mask, sumv, 1, 32); - sumv = __shfl_sync(mask, sumv , 0, 32); - return sumv; + return cinn_warp_shuffle_sum_interal(sumv); } +__device__ inline float cinn_block_reduce_sum_internal(const float buf) { + int warp_id = threadIdx.x / 32; + __shared__ float tmp[32]; + if (warp_id == 0) { + tmp[threadIdx.x] = 0.0f; + } + float sum = cinn_warp_shuffle_sum_interal(buf); + if (blockDim.x <= 32) { + return sum; + } + __syncthreads(); + if (threadIdx.x % 32 == 0) { + tmp[warp_id] = sum; + } + __syncthreads(); + if (warp_id == 0) { + sum = tmp[threadIdx.x]; + sum = cinn_warp_shuffle_sum_interal(sum); + if (threadIdx.x == 0) { + tmp[0] = sum; + } + } + __syncthreads(); + return tmp[0]; +} - +__device__ inline float cinn_block_reduce_sum(const float *buf, int offset, int extend) { + float sumv = 0; + for (int i = threadIdx.x; i < extend; i += blockDim.x) { + sumv += buf[offset + i]; + } + return cinn_block_reduce_sum_internal(sumv); +} diff --git a/cinn/runtime/cuda/cuda_intrinsics.cc b/cinn/runtime/cuda/cuda_intrinsics.cc index cd25c5642a..b0f0cef2a8 100644 --- a/cinn/runtime/cuda/cuda_intrinsics.cc +++ b/cinn/runtime/cuda/cuda_intrinsics.cc @@ -84,6 +84,18 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { .AddInputType() .End(); + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_block_reduce_sum_internal, target) + .SetRetType() + .AddInputType() + .End(); + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_block_reduce_sum, target) + 
.SetRetType() + .AddInputType() + .AddInputType() + .AddInputType() + .End(); + return true; }
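For reference, here is a minimal standalone sketch, not part of this patch, of the block-wide sum pattern that the new cinn_block_reduce_sum / cinn_block_reduce_sum_internal intrinsics follow: each thread strides over the reduce width, each warp combines its lanes with shuffle instructions, and the per-warp partials are then combined through shared memory. The file name block_reduce_sketch.cu and the identifiers warp_sum, block_sum, and row_sum are illustrative assumptions, not CINN symbols.

// block_reduce_sketch.cu (illustrative only): nvcc block_reduce_sketch.cu -o block_reduce_sketch
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__device__ float warp_sum(float v) {
  unsigned mask = __activemask();
  // tree reduction within a warp, then broadcast lane 0's total to every lane
  for (int offset = 16; offset > 0; offset >>= 1) v += __shfl_down_sync(mask, v, offset, 32);
  return __shfl_sync(mask, v, 0, 32);
}

__device__ float block_sum(float v) {
  __shared__ float partial[32];  // one partial sum per warp
  int warp = threadIdx.x / 32, lane = threadIdx.x % 32;
  v = warp_sum(v);
  if (blockDim.x <= 32) return v;  // a single warp already holds the total
  if (lane == 0) partial[warp] = v;
  __syncthreads();
  if (warp == 0) {
    v = lane < blockDim.x / 32 ? partial[lane] : 0.f;
    v = warp_sum(v);
    if (lane == 0) partial[0] = v;
  }
  __syncthreads();
  return partial[0];
}

// One block per output row; `extend` is the flattened reduce width (h * w for NCHW).
__global__ void row_sum(const float *x, float *y, int extend) {
  float acc = 0.f;
  for (int i = threadIdx.x; i < extend; i += blockDim.x) acc += x[blockIdx.x * extend + i];
  float total = block_sum(acc);
  if (threadIdx.x == 0) y[blockIdx.x] = total;
}

int main() {
  const int rows = 16, extend = 4096;
  std::vector<float> host(rows * extend, 1.0f);
  float *dx = nullptr, *dy = nullptr;
  cudaMalloc(&dx, host.size() * sizeof(float));
  cudaMalloc(&dy, rows * sizeof(float));
  cudaMemcpy(dx, host.data(), host.size() * sizeof(float), cudaMemcpyHostToDevice);
  row_sum<<<rows, 1024>>>(dx, dy, extend);
  std::vector<float> out(rows);
  cudaMemcpy(out.data(), dy, rows * sizeof(float), cudaMemcpyDeviceToHost);
  printf("row 0 sum = %.1f (expect %d)\n", out[0], extend);
  cudaFree(dx);
  cudaFree(dy);
  return 0;
}

The launch with blockDim.x = 1024 mirrors the max_block_size of 1024 used by CudaScheduleReduce in this patch.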