diff --git a/paddle/fluid/framework/new_executor/interpreter/static_build.cc b/paddle/fluid/framework/new_executor/interpreter/static_build.cc
index 3751ee0a03db44..e8e5a1ef29aedf 100644
--- a/paddle/fluid/framework/new_executor/interpreter/static_build.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/static_build.cc
@@ -54,7 +54,6 @@ std::set<std::string> OpsCanSkipedFakeAllocInStaticBuild = {
     "nop"};
 
 std::set<std::string> StaticBuildBlackList = {
-    "batch_norm" /*: to handle reserve_space output*/,
     "cinn_instruction_run" /*: to handle subgraph infermeta*/,
     "cinn_launch" /*: to handle subgraph infermeta*/,
     "run_program" /*: to handle scope output*/,
@@ -206,6 +205,14 @@ bool TensorShouldBeFakeInitialized(const OperatorBase& op,
     }
   }
 
+  if (op_type == "batch_norm" && parameter_name == "ReserveSpace") {
+    if (dynamic_cast<const OperatorWithKernel*>(&op)->kernel_type()->place_ ==
+        phi::CPUPlace()) {
+      VLOG(2) << "Skip fake initialization for: " << parameter_name;
+      return false;
+    }
+  }
+
   if (op_type == "coalesce_tensor" && parameter_name == "Output") {
     VLOG(2) << "Skip fake initialization for: " << parameter_name;
     return false;
@@ -250,6 +257,12 @@ bool TensorShouldBeFakeInitialized(const OperatorBase& op,
     }
   }
 
+  if ((op_type == "flatten" || op_type == "flatten_contiguous_range") &&
+      parameter_name == "XShape") {
+    VLOG(2) << "Skip fake initialization for: " << parameter_name;
+    return false;
+  }
+
   if (op_type == "segment_pool" && parameter_name == "SummedIds") {
     return op.Attr<std::string>("pooltype") == "MEAN" &&
            dynamic_cast<const OperatorWithKernel*>(&op)
@@ -856,6 +869,8 @@ void FakeInitializeOutputsForFunctionKernel(
       dtype = InferDTypeFromAttr(op, runtime_ctx, "dtype");
     } else if (op_type == "bincount" || op_type == "reduce_sum_grad") {
       dtype = GetInputDType(runtime_ctx, "X");
+    } else if (op_type == "dequantize_linear") {
+      dtype = GetInputDType(runtime_ctx, "Scale");
     } else if (op_type == "lamb") {
       bool multi_precision = op.Attr<bool>("multi_precision");
       dtype = GetInputDType(runtime_ctx, "Moment1");
diff --git a/paddle/fluid/operators/fused/fused_bn_add_activation_op.cu b/paddle/fluid/operators/fused/fused_bn_add_activation_op.cu
deleted file mode 100644
index 1fa7ff1826b071..00000000000000
--- a/paddle/fluid/operators/fused/fused_bn_add_activation_op.cu
+++ /dev/null
@@ -1,387 +0,0 @@
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
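Note on the static_build.cc hunks above: they add three rules — batch_norm's ReserveSpace is produced only by the cuDNN kernel and so must not be fake-initialized for CPU kernels, the XShape output of flatten/flatten_contiguous_range is shape metadata with no real buffer, and dequantize_linear's output dtype follows its Scale input. A minimal illustrative condensation of the first two rules, assuming a hypothetical helper name and signature (this is not the actual Paddle function):

// Illustrative sketch only; names and signature are hypothetical.
#include <string>

bool ShouldFakeInit(const std::string& op_type,
                    const std::string& output_name,
                    bool runs_on_gpu) {
  if (op_type == "batch_norm" && output_name == "ReserveSpace") {
    return runs_on_gpu;  // ReserveSpace is only written by the cuDNN path
  }
  if ((op_type == "flatten" || op_type == "flatten_contiguous_range") &&
      output_name == "XShape") {
    return false;  // XShape carries shape metadata only
  }
  return true;  // default: fake-initialize so dtype/place are known statically
}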
- -#include -#include -#include -#include - -#include "paddle/fluid/framework/data_layout.h" -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/fused/fused_bn_add_activation_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/core/flags.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/norm_utils.h" - -PHI_DECLARE_bool(cudnn_batchnorm_spatial_persistent); - -namespace paddle { -namespace operators { -template -using CudnnDataType = platform::CudnnDataType; -template -using BatchNormParamType = typename CudnnDataType::BatchNormParamType; - -template -class FusedBatchNormAddActKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { -#if CUDNN_VERSION < 7401 - PADDLE_THROW(phi::errors::Unimplemented( - "The fused_bn_add_activation operator is not supported on GPU " - "when CUDNN version < 7.4.1")); -#endif - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(ctx.GetPlace()), - true, - platform::errors::PreconditionNotMet("It must use CUDAPlace.")); - auto &dev_ctx = ctx.template device_context(); - double epsilon = static_cast(ctx.Attr("epsilon")); - float momentum = ctx.Attr("momentum"); - std::string act_type = ctx.Attr("act_type"); - - if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { - LOG(ERROR) << "Provided epsilon is smaller than " - << "CUDNN_BN_MIN_EPSILON. Setting it to " - << "CUDNN_BN_MIN_EPSILON instead."; - } - epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); - - // Get the size for each dimension. - // NHWC [batch_size, in_height, in_width, in_channels] - const auto *x = ctx.Input("X"); - const auto *z = ctx.Input("Z"); - const auto &in_dims = x->dims(); - - const auto *scale = ctx.Input("Scale"); - const auto *bias = ctx.Input("Bias"); - - auto *mean_out = ctx.Output("MeanOut"); - auto *variance_out = ctx.Output("VarianceOut"); - dev_ctx.Alloc>( - mean_out, mean_out->numel() * sizeof(BatchNormParamType)); - dev_ctx.Alloc>( - variance_out, variance_out->numel() * sizeof(BatchNormParamType)); - - auto *saved_mean = ctx.Output("SavedMean"); - auto *saved_variance = ctx.Output("SavedVariance"); - dev_ctx.Alloc>( - saved_mean, saved_mean->numel() * sizeof(BatchNormParamType)); - dev_ctx.Alloc>( - saved_variance, - saved_variance->numel() * sizeof(BatchNormParamType)); - - auto *y = ctx.Output("Y"); - dev_ctx.Alloc(y, y->numel() * sizeof(T)); - - int N, C, H, W, D; - const DataLayout data_layout = DataLayout::kNHWC; - phi::funcs::ExtractNCWHD(in_dims, data_layout, &N, &C, &H, &W, &D); - - // ------------------- cudnn descriptors --------------------- - auto handle = dev_ctx.cudnn_handle(); - cudnnTensorDescriptor_t data_desc_; - cudnnTensorDescriptor_t bn_param_desc_; - cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); - - std::vector dims = {N, C, H, W, D}; - std::vector strides = {H * W * D * C, 1, W * D * C, D * C, C}; - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - data_desc_, - CudnnDataType::type, - in_dims.size() > 3 ? 
in_dims.size() : 4, - dims.data(), - strides.data())); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnDeriveBNTensorDescriptor( - bn_param_desc_, data_desc_, mode_)); - - double this_factor = 1. - momentum; - cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; - platform::ScopedActivationDescriptor scope_act_desc; - cudnnActivationDescriptor_t activation_desc_ = - scope_act_desc.descriptor(act_type); - size_t workspace_size = 0; - size_t reserve_space_size = 0; - void *reserve_space_ptr = nullptr; - void *workspace_ptr = nullptr; - phi::DenseTensor workspace_tensor; - // Create reserve space and workspace for batch norm. - // Create tensor for each batchnorm op, it will be used in the - // backward. Thus this tensor shouldn't be temp. - auto *reserve_space = ctx.Output("ReserveSpace"); - PADDLE_ENFORCE_NOT_NULL( - reserve_space, - platform::errors::NotFound( - "The argument ReserveSpace of batch_norm op is not found.")); - - // --------------- cudnn batchnorm workspace --------------- - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload:: - cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( - /*handle=*/handle, - /*mode=*/mode_, - /*bnOps=*/bnOps_, - /*xDesc=*/data_desc_, - /*zDesc=*/data_desc_, - /*yDesc=*/data_desc_, - /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, - /*activationDesc=*/activation_desc_, - /*sizeInBytes=*/&workspace_size)); - - // -------------- cudnn batchnorm reserve space -------------- - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetBatchNormalizationTrainingExReserveSpaceSize( - /*handle=*/handle, - /*mode=*/mode_, - /*bnOps=*/bnOps_, - /*activationDesc=*/activation_desc_, - /*xDesc=*/data_desc_, - /*sizeInBytes=*/&reserve_space_size)); - - reserve_space->Resize({static_cast( - (reserve_space_size + phi::SizeOf(x->dtype()) - 1) / - phi::SizeOf(x->dtype()))}); - reserve_space_ptr = - dev_ctx.Alloc(reserve_space, reserve_space->numel() * sizeof(T)); - workspace_tensor.Resize( - {static_cast((workspace_size + phi::SizeOf(x->dtype()) - 1) / - phi::SizeOf(x->dtype()))}); - workspace_ptr = dev_ctx.Alloc(&workspace_tensor, - workspace_tensor.numel() * sizeof(T)); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnBatchNormalizationForwardTrainingEx( - handle, - mode_, - bnOps_, - CudnnDataType::kOne(), - CudnnDataType::kZero(), - data_desc_, - x->template data(), - data_desc_, - z->template data(), - data_desc_, - y->template data(), - bn_param_desc_, - scale->template data>(), - bias->template data>(), - this_factor, - dev_ctx.template Alloc>( - mean_out, mean_out->numel() * sizeof(BatchNormParamType)), - dev_ctx.template Alloc>( - variance_out, - variance_out->numel() * sizeof(BatchNormParamType)), - epsilon, - dev_ctx.template Alloc>( - saved_mean, - saved_mean->numel() * sizeof(BatchNormParamType)), - dev_ctx.template Alloc>( - saved_variance, - saved_variance->numel() * sizeof(BatchNormParamType)), - activation_desc_, - workspace_ptr, - workspace_size, - reserve_space_ptr, - reserve_space_size)); - - // clean when exit. 
- PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); - } -}; - -template -class FusedBatchNormAddActGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { -#if CUDNN_VERSION < 7401 - PADDLE_THROW(phi::errors::Unimplemented( - "The fused_bn_add_activation operator is not supported on GPU " - "when CUDNN version < 7.4.1")); -#endif - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(ctx.GetPlace()), - true, - platform::errors::PreconditionNotMet("It must use CUDAPlace.")); - double epsilon = static_cast(ctx.Attr("epsilon")); - std::string act_type = ctx.Attr("act_type"); - - const auto *x = ctx.Input("X"); - const auto *y = ctx.Input("Y"); - const auto *d_y = ctx.Input(framework::GradVarName("Y")); - const auto *scale = ctx.Input("Scale"); - const auto *bias = ctx.Input("Bias"); - const auto *reserve_space = ctx.Input("ReserveSpace"); - - auto &dev_ctx = ctx.template device_context(); - const auto &in_dims = x->dims(); - - int N, C, H, W, D; - const DataLayout data_layout = DataLayout::kNHWC; - phi::funcs::ExtractNCWHD(in_dims, data_layout, &N, &C, &H, &W, &D); - - // init output - auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_z = ctx.Output(framework::GradVarName("Z")); - auto *d_scale = - ctx.Output(framework::GradVarName("Scale")); - auto *d_bias = ctx.Output(framework::GradVarName("Bias")); - - d_x->mutable_data(ctx.GetPlace()); - d_z->mutable_data(ctx.GetPlace()); - PADDLE_ENFORCE_EQ( - d_scale && d_bias, - true, - platform::errors::PreconditionNotMet( - "Both the scale grad and the bias grad must not be null.")); - d_scale->mutable_data>(ctx.GetPlace()); - d_bias->mutable_data>(ctx.GetPlace()); - PADDLE_ENFORCE_EQ(scale->dims().size(), - 1UL, - platform::errors::PreconditionNotMet( - "The scale only has one dimension.")); - PADDLE_ENFORCE_EQ( - scale->dims()[0], - C, - platform::errors::PreconditionNotMet( - "The size of scale is equal to the channel of Input(X).")); - - std::vector dims = {N, C, H, W, D}; - std::vector strides = {H * W * C * D, 1, W * D * C, D * C, C}; - // ------------------- cudnn descriptors --------------------- - cudnnTensorDescriptor_t data_desc_; - cudnnTensorDescriptor_t bn_param_desc_; - cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); - if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { - LOG(ERROR) << "Provided epsilon is smaller than " - << "CUDNN_BN_MIN_EPSILON. Setting it to " - << "CUDNN_BN_MIN_EPSILON instead."; - } - epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - data_desc_, - CudnnDataType::type, - in_dims.size() > 3 ? 
in_dims.size() : 4, - dims.data(), - strides.data())); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnDeriveBNTensorDescriptor( - bn_param_desc_, data_desc_, mode_)); - - const auto *saved_mean = ctx.Input("SavedMean"); - const auto *saved_var = ctx.Input("SavedVariance"); - const auto *saved_mean_data = - saved_mean->template data>(); - const auto *saved_var_data = - saved_var->template data>(); - - size_t workspace_size = 0; - void *workspace_ptr = nullptr; - phi::DenseTensor workspace_tensor; - auto reserve_space_size = reserve_space->memory_size(); - cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; - platform::ScopedActivationDescriptor scope_act_desc; - cudnnActivationDescriptor_t activation_desc_ = - scope_act_desc.descriptor(act_type); - // --------------- cudnn batchnorm workspace --------------- - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetBatchNormalizationBackwardExWorkspaceSize( - /*handle=*/dev_ctx.cudnn_handle(), - /*mode=*/mode_, - /*bnOps=*/bnOps_, - /*xDesc=*/data_desc_, - /*yDesc=*/data_desc_, - /*dyDesc=*/data_desc_, - /*dzDesc=*/data_desc_, - /*dxDesc=*/data_desc_, - /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, - /*activationDesc=*/activation_desc_, - /*sizeInBytes=*/&workspace_size)); - - workspace_ptr = workspace_tensor.mutable_data( - ctx.GetPlace(), x->dtype(), workspace_size); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnBatchNormalizationBackwardEx( - /*handle=*/dev_ctx.cudnn_handle(), - /*mode=*/mode_, - /*bnOps=*/bnOps_, - /*alphaDataDiff=*/CudnnDataType::kOne(), - /*betaDataDiff=*/CudnnDataType::kZero(), - /*alphaParamDiff=*/CudnnDataType::kOne(), - /*betaParamDiff=*/CudnnDataType::kZero(), - /*xDesc=*/data_desc_, - /*xData=*/x->template data(), - /*yDesc=*/data_desc_, - /*yData=*/y->template data(), - /*dyDesc=*/data_desc_, - /*dyData=*/d_y->template data(), - /*dzDesc=*/data_desc_, - /*dzData=*/d_z->template data(), - /*dxDesc=*/data_desc_, - /*dxData=*/d_x->template data(), - /*dBnScaleBiasDesc=*/bn_param_desc_, - /*bnScaleData=*/scale->template data>(), - /*bnBiasData=*/bias->template data>(), - /*dBnScaleData=*/d_scale->template data>(), - /*dBnBiasData=*/d_bias->template data>(), - /*epsilon=*/epsilon, - /*savedMean=*/saved_mean_data, - /*savedInvVariance=*/saved_var_data, - /*activationDesmc=*/activation_desc_, - /*workspace=*/workspace_ptr, - /*workSpaceSizeInBytes=*/workspace_size, - /*reserveSpace=*/const_cast(reserve_space->template data()), - /*reserveSpaceSizeInBytes=*/reserve_space_size)); - - // clean when exit. 
- PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -PD_REGISTER_STRUCT_KERNEL(fused_bn_add_activation, - GPU, - ALL_LAYOUT, - ops::FusedBatchNormAddActKernel, - plat::float16) {} -PD_REGISTER_STRUCT_KERNEL(fused_bn_add_activation_grad, - GPU, - ALL_LAYOUT, - ops::FusedBatchNormAddActGradKernel, - plat::float16) {} diff --git a/paddle/fluid/operators/fused/fused_bn_add_activation_op.h b/paddle/fluid/operators/fused/fused_bn_add_activation_op.h index 215ccfdde5e026..82967b043d89e8 100644 --- a/paddle/fluid/operators/fused/fused_bn_add_activation_op.h +++ b/paddle/fluid/operators/fused/fused_bn_add_activation_op.h @@ -89,17 +89,5 @@ class FusedBatchNormAddActOpInferVarType } }; -template -class FusedBatchNormAddActKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override; -}; - -template -class FusedBatchNormAddActGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override; -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc index e2c2eb7768e1bd..c0ef288b5134bf 100644 --- a/paddle/fluid/operators/quantize_linear_op.cc +++ b/paddle/fluid/operators/quantize_linear_op.cc @@ -239,11 +239,3 @@ REGISTER_OPERATOR( ops::QuantizeLinearOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); - -PD_REGISTER_STRUCT_KERNEL(dequantize_linear, - CPU, - ALL_LAYOUT, - ops::DeQuantizeLinearKernel, - float, - int8_t, - double) {} diff --git a/paddle/fluid/operators/quantize_linear_op.cu b/paddle/fluid/operators/quantize_linear_op.cu index f0d6523d054c29..8bcbc1107e9d13 100644 --- a/paddle/fluid/operators/quantize_linear_op.cu +++ b/paddle/fluid/operators/quantize_linear_op.cu @@ -123,15 +123,6 @@ template struct ChannelDequantizeFunctorV2; namespace ops = paddle::operators; -PD_REGISTER_STRUCT_KERNEL(dequantize_linear, - GPU, - ALL_LAYOUT, - ops::DeQuantizeLinearKernel, - float, - float16, - int8_t, - double) {} - PD_REGISTER_STRUCT_KERNEL(quantize_linear, GPU, ALL_LAYOUT, diff --git a/paddle/fluid/operators/quantize_linear_op.h b/paddle/fluid/operators/quantize_linear_op.h index 276d1507a4aef8..d6c3b3d2e50ae8 100644 --- a/paddle/fluid/operators/quantize_linear_op.h +++ b/paddle/fluid/operators/quantize_linear_op.h @@ -130,74 +130,5 @@ class QuantizeLinearKernel : public framework::OpKernel { } }; -template -class DeQuantizeLinearKernel : public framework::OpKernel { - public: - template - void ComputeImpl(const framework::ExecutionContext& context) const { - auto& dev_ctx = context.template device_context(); - auto* in = context.Input("X"); - - auto in_tmp = phi::Cast( - static_cast::TYPE&>(dev_ctx), - *in, - phi::CppTypeToDataType::Type()); - - auto* scale = context.Input("Scale"); - auto* out = context.Output("Y"); - int bit_length = context.Attr("bit_length"); - auto quant_axis = context.Attr("quant_axis"); - dev_ctx.template Alloc(out, out->numel() * sizeof(D)); - bool only_observer = context.Attr("only_observer"); - - if (only_observer) { - framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); - return; - } - - if (quant_axis < 0) { - float max_range = (std::pow(2, bit_length - 
1) - 1); - DequantizeFunctor()( - dev_ctx, &in_tmp, scale, static_cast(max_range), out); - } else { - PADDLE_ENFORCE_EQ( - scale->numel(), - in_tmp.dims()[quant_axis], - platform::errors::PreconditionNotMet( - "The number of first scale values must be the same with " - "quant_axis dimension value of Input(X) when the `scale` has " - "only one element, but %ld != %ld here.", - scale->numel(), - in_tmp.dims()[quant_axis])); - int max_range = (std::pow(2, bit_length - 1) - 1); - - ChannelDequantizeFunctorV2()( - dev_ctx, &in_tmp, scale, static_cast(max_range), quant_axis, out); - } - } - - void Compute(const framework::ExecutionContext& context) const override { - auto* scale = context.Input("Scale"); - switch (scale->dtype()) { - case phi::DataType::FLOAT64: - ComputeImpl(context); - break; - case phi::DataType::FLOAT32: - ComputeImpl(context); - break; - case phi::DataType::FLOAT16: - ComputeImpl(context); - break; - default: - PADDLE_THROW(platform::errors::Unimplemented( - "In DeQuantizeLinearKernel, " - "data type %d for scale/output is not supported ", - scale->dtype())); - break; - } - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/phi/kernels/cpu/quantize_linear_kernel.cc b/paddle/phi/kernels/cpu/quantize_linear_kernel.cc new file mode 100644 index 00000000000000..a7f3954407a526 --- /dev/null +++ b/paddle/phi/kernels/cpu/quantize_linear_kernel.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "paddle/phi/kernels/quantize_linear_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/impl/quantize_linear_impl.h" + +namespace phi { + +template +struct DequantizeFunctor { + void operator()(const phi::CPUContext& dev_ctx, + const phi::DenseTensor* in, + const phi::DenseTensor* scale, + T max_range, + phi::DenseTensor* out) { + auto in_e = phi::EigenVector::Flatten(*in); + const T* scale_factor = scale->data(); + auto out_e = phi::EigenVector::Flatten(*out); + + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = in_e * scale_factor[0] / max_range; + } +}; + +template +struct ChannelDequantizeFunctorV2 { + void operator()(const phi::CPUContext& dev_ctx, + const phi::DenseTensor* in, + const phi::DenseTensor* scale, + T max_range, + const int quant_axis, + phi::DenseTensor* out) { + // Dequant op is before quantized op + // Dequantize the weight of quantized op + auto in_dims = in->dims(); + const int64_t channel = in_dims[quant_axis]; + const T* scale_factor = scale->data(); + if (quant_axis == 0) { + for (int64_t i = 0; i < channel; i++) { + T s = scale_factor[i]; + phi::DenseTensor one_channel_in = in->Slice(i, i + 1); + phi::DenseTensor one_channel_out = out->Slice(i, i + 1); + auto in_e = phi::EigenVector::Flatten(one_channel_in); + auto out_e = phi::EigenVector::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = in_e * s / max_range; + } + } else if (quant_axis == 1) { + int64_t out_iter = 1; + for (int i = 0; i < quant_axis; i++) { + out_iter *= in_dims[i]; + } + int64_t step_i = in->numel() / out_iter; + int64_t step_j = in->numel() / (out_iter * channel); + auto* in_data = in->data(); + auto* out_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); + for (int64_t i = 0; i < out_iter; i++) { + for (int64_t j = 0; j < channel; j++) { + auto* cur_in = in_data + i * step_i + j * step_j; + auto* cur_out = out_data + i * step_i + j * step_j; + T s = scale_factor[j]; + for (int64_t k = 0; k < step_j; k++) { + *cur_out = (*cur_in) * s / max_range; + ++cur_in; + ++cur_out; + } + } + } + } + } +}; + +template struct DequantizeFunctor; +template struct DequantizeFunctor; +template struct DequantizeFunctor; +template struct ChannelDequantizeFunctorV2; +template struct ChannelDequantizeFunctorV2; +template struct ChannelDequantizeFunctorV2; + +} // namespace phi + +PD_REGISTER_KERNEL(dequantize_linear, + CPU, + ALL_LAYOUT, + phi::DeQuantizeLinearKernel, + float, + int8_t, + double) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/fused_bn_add_activation_grad_kernel.h b/paddle/phi/kernels/fused_bn_add_activation_grad_kernel.h new file mode 100644 index 00000000000000..c98a5f69ae0d6a --- /dev/null +++ b/paddle/phi/kernels/fused_bn_add_activation_grad_kernel.h @@ -0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void FusedBatchNormAddActGradKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &y_grad, + const DenseTensor &scale, + const DenseTensor &bias, + const DenseTensor &saved_mean, + const DenseTensor &saved_variance, + const DenseTensor &reserve_space, + float momentum, + float epsilon, + const std::string &act_type, + DenseTensor *x_grad, + DenseTensor *z_grad, + DenseTensor *scale_grad, + DenseTensor *bias_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/fused_bn_add_activation_kernel.h b/paddle/phi/kernels/fused_bn_add_activation_kernel.h new file mode 100644 index 00000000000000..9d4f468a261ee6 --- /dev/null +++ b/paddle/phi/kernels/fused_bn_add_activation_kernel.h @@ -0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void FusedBatchNormAddActKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &z, + const DenseTensor &scale, + const DenseTensor &bias, + const DenseTensor &mean, + const DenseTensor &variance, + float momentum, + float epsilon, + const std::string &act_type, + DenseTensor *y, + DenseTensor *mean_out, + DenseTensor *variance_out, + DenseTensor *saved_mean, + DenseTensor *saved_variance, + DenseTensor *reserve_space); + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu new file mode 100644 index 00000000000000..e19b468b54a355 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu @@ -0,0 +1,223 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/activation_functor.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/norm_utils.h" +#include "paddle/phi/kernels/fused_bn_add_activation_grad_kernel.h" + +PHI_DECLARE_bool(cudnn_batchnorm_spatial_persistent); + +namespace phi { +namespace fusion { + +template +using CudnnDataType = phi::backends::gpu::CudnnDataType; +template +using BatchNormParamType = typename CudnnDataType::BatchNormParamType; + +template +void FusedBatchNormAddActGradKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &y_grad, + const DenseTensor &scale, + const DenseTensor &bias, + const DenseTensor &saved_mean, + const DenseTensor &saved_variance, + const DenseTensor &reserve_space, + float momentum, + float epsilon, + const std::string &act_type, + DenseTensor *x_grad, + DenseTensor *z_grad, + DenseTensor *scale_grad, + DenseTensor *bias_grad) { +#if CUDNN_VERSION < 7401 + PADDLE_THROW(phi::errors::Unimplemented( + "The fused_bn_add_activation operator is not supported on GPU " + "when CUDNN version < 7.4.1")); +#endif + bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU; + PADDLE_ENFORCE_EQ(is_gpu_place, + true, + phi::errors::PreconditionNotMet("It must use CUDAPlace.")); + double epsilon1 = static_cast(epsilon); + + const auto *x_ptr = &x; + const auto *y_ptr = &y; + const auto *d_y = &y_grad; + const auto *scale_ptr = &scale; + const auto *bias_ptr = &bias; + const auto *reserve_space_ptr = &reserve_space; + + const auto &in_dims = x_ptr->dims(); + + int N, C, H, W, D; + const DataLayout data_layout = DataLayout::kNHWC; + phi::funcs::ExtractNCWHD(in_dims, data_layout, &N, &C, &H, &W, &D); + + // init output + auto *d_x = x_grad; + auto *d_z = z_grad; + auto *d_scale = scale_grad; + auto *d_bias = bias_grad; + + dev_ctx.template Alloc(d_x); + dev_ctx.template Alloc(d_z); + + PADDLE_ENFORCE_EQ( + d_scale && d_bias, + true, + phi::errors::PreconditionNotMet( + "Both the scale grad and the bias grad must not be null.")); + + dev_ctx.template Alloc>(d_scale); + dev_ctx.template Alloc>(d_bias); + + PADDLE_ENFORCE_EQ( + scale_ptr->dims().size(), + 1UL, + phi::errors::PreconditionNotMet("The scale only has one dimension.")); + PADDLE_ENFORCE_EQ( + scale_ptr->dims()[0], + C, + phi::errors::PreconditionNotMet( + "The size of scale is equal to the channel of Input(X).")); + + std::vector dims = {N, C, H, W, D}; + std::vector strides = {H * W * C * D, 1, W * D * C, D * C, C}; + // ------------------- cudnn descriptors --------------------- + cudnnTensorDescriptor_t data_desc_; + cudnnTensorDescriptor_t bn_param_desc_; + cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; + + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&data_desc_)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); + if (epsilon1 <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { + LOG(ERROR) << "Provided epsilon is smaller than " + << "CUDNN_BN_MIN_EPSILON. 
Setting it to " + << "CUDNN_BN_MIN_EPSILON instead."; + } + epsilon1 = std::max(epsilon1, CUDNN_BN_MIN_EPSILON); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor( + data_desc_, + CudnnDataType::type, + in_dims.size() > 3 ? in_dims.size() : 4, + dims.data(), + strides.data())); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDeriveBNTensorDescriptor( + bn_param_desc_, data_desc_, mode_)); + + const auto *saved_mean_ptr = &saved_mean; + const auto *saved_var_ptr = &saved_variance; + const auto *saved_mean_data = + saved_mean_ptr->template data>(); + const auto *saved_var_data = + saved_var_ptr->template data>(); + + size_t workspace_size = 0; + void *workspace_ptr = nullptr; + phi::DenseTensor workspace_tensor; + auto reserve_space_size = reserve_space_ptr->memory_size(); + cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; + phi::backends::gpu::ScopedActivationDescriptor scope_act_desc; + cudnnActivationDescriptor_t activation_desc_ = + scope_act_desc.descriptor(act_type); + // --------------- cudnn batchnorm workspace --------------- + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnGetBatchNormalizationBackwardExWorkspaceSize( + /*handle=*/dev_ctx.cudnn_handle(), + /*mode=*/mode_, + /*bnOps=*/bnOps_, + /*xDesc=*/data_desc_, + /*yDesc=*/data_desc_, + /*dyDesc=*/data_desc_, + /*dzDesc=*/data_desc_, + /*dxDesc=*/data_desc_, + /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, + /*activationDesc=*/activation_desc_, + /*sizeInBytes=*/&workspace_size)); + + workspace_tensor.Resize({static_cast(workspace_size)}); + workspace_ptr = dev_ctx.template Alloc(&workspace_tensor); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBatchNormalizationBackwardEx( + /*handle=*/dev_ctx.cudnn_handle(), + /*mode=*/mode_, + /*bnOps=*/bnOps_, + /*alphaDataDiff=*/CudnnDataType::kOne(), + /*betaDataDiff=*/CudnnDataType::kZero(), + /*alphaParamDiff=*/CudnnDataType::kOne(), + /*betaParamDiff=*/CudnnDataType::kZero(), + /*xDesc=*/data_desc_, + /*xData=*/x_ptr->template data(), + /*yDesc=*/data_desc_, + /*yData=*/y_ptr->template data(), + /*dyDesc=*/data_desc_, + /*dyData=*/d_y->template data(), + /*dzDesc=*/data_desc_, + /*dzData=*/d_z->template data(), + /*dxDesc=*/data_desc_, + /*dxData=*/d_x->template data(), + /*dBnScaleBiasDesc=*/bn_param_desc_, + /*bnScaleData=*/scale_ptr->template data>(), + /*bnBiasData=*/bias_ptr->template data>(), + /*dBnScaleData=*/d_scale->template data>(), + /*dBnBiasData=*/d_bias->template data>(), + /*epsilon=*/epsilon1, + /*savedMean=*/saved_mean_data, + /*savedInvVariance=*/saved_var_data, + /*activationDesmc=*/activation_desc_, + /*workspace=*/workspace_ptr, + /*workSpaceSizeInBytes=*/workspace_size, + /*reserveSpace=*/const_cast(reserve_space_ptr->template data()), + /*reserveSpaceSizeInBytes=*/reserve_space_size)); + + // clean when exit. 
+ PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(data_desc_)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_bn_add_activation_grad, + GPU, + ALL_LAYOUT, + phi::fusion::FusedBatchNormAddActGradKernel, + phi::dtype::float16) { + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); +} diff --git a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu new file mode 100644 index 00000000000000..7b5b4119cf9705 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu @@ -0,0 +1,227 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/activation_functor.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/norm_utils.h" +#include "paddle/phi/kernels/fused_bn_add_activation_kernel.h" + +PHI_DECLARE_bool(cudnn_batchnorm_spatial_persistent); + +namespace phi { +namespace fusion { + +template +using CudnnDataType = phi::backends::gpu::CudnnDataType; +template +using BatchNormParamType = typename CudnnDataType::BatchNormParamType; + +template +void FusedBatchNormAddActKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &z, + const DenseTensor &scale, + const DenseTensor &bias, + const DenseTensor &mean, + const DenseTensor &variance, + float momentum, + float epsilon, + const std::string &act_type, + DenseTensor *y, + DenseTensor *mean_out, + DenseTensor *variance_out, + DenseTensor *saved_mean, + DenseTensor *saved_variance, + DenseTensor *reserve_space) { +#if CUDNN_VERSION < 7401 + PADDLE_THROW(phi::errors::Unimplemented( + "The fused_bn_add_activation operator is not supported on GPU " + "when CUDNN version < 7.4.1")); +#endif + bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU; + PADDLE_ENFORCE_EQ(is_gpu_place, + true, + phi::errors::PreconditionNotMet("It must use CUDAPlace.")); + + double epsilon1 = static_cast(epsilon); + if (epsilon1 <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { + LOG(ERROR) << "Provided epsilon is smaller than " + << "CUDNN_BN_MIN_EPSILON. Setting it to " + << "CUDNN_BN_MIN_EPSILON instead."; + } + epsilon1 = std::max(static_cast(epsilon1), CUDNN_BN_MIN_EPSILON); + + // Get the size for each dimension. 
+ // NHWC [batch_size, in_height, in_width, in_channels] + const auto &in_dims = x.dims(); + + dev_ctx.template Alloc>( + mean_out, mean_out->numel() * sizeof(BatchNormParamType)); + dev_ctx.template Alloc>( + variance_out, variance_out->numel() * sizeof(BatchNormParamType)); + + dev_ctx.template Alloc>( + saved_mean, saved_mean->numel() * sizeof(BatchNormParamType)); + dev_ctx.template Alloc>( + saved_variance, saved_variance->numel() * sizeof(BatchNormParamType)); + + dev_ctx.template Alloc(y, y->numel() * sizeof(T)); + + int N, C, H, W, D; + const DataLayout data_layout = DataLayout::kNHWC; + phi::funcs::ExtractNCWHD(in_dims, data_layout, &N, &C, &H, &W, &D); + + // ------------------- cudnn descriptors --------------------- + auto handle = dev_ctx.cudnn_handle(); + cudnnTensorDescriptor_t data_desc_; + cudnnTensorDescriptor_t bn_param_desc_; + cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; + + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&data_desc_)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); + + std::vector dims = {N, C, H, W, D}; + std::vector strides = {H * W * D * C, 1, W * D * C, D * C, C}; + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor( + data_desc_, + CudnnDataType::type, + in_dims.size() > 3 ? in_dims.size() : 4, + dims.data(), + strides.data())); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDeriveBNTensorDescriptor( + bn_param_desc_, data_desc_, mode_)); + + double this_factor = 1. - momentum; + cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; + phi::backends::gpu::ScopedActivationDescriptor scope_act_desc; + cudnnActivationDescriptor_t activation_desc_ = + scope_act_desc.descriptor(act_type); + size_t workspace_size = 0; + size_t reserve_space_size = 0; + void *reserve_space_ptr = nullptr; + void *workspace_ptr = nullptr; + phi::DenseTensor workspace_tensor; + // Create reserve space and workspace for batch norm. + // Create tensor for each batchnorm op, it will be used in the + // backward. Thus this tensor shouldn't be temp. 
+ PADDLE_ENFORCE_NOT_NULL( + reserve_space, + phi::errors::NotFound( + "The argument ReserveSpace of batch_norm op is not found.")); + + // --------------- cudnn batchnorm workspace --------------- + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( + /*handle=*/handle, + /*mode=*/mode_, + /*bnOps=*/bnOps_, + /*xDesc=*/data_desc_, + /*zDesc=*/data_desc_, + /*yDesc=*/data_desc_, + /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, + /*activationDesc=*/activation_desc_, + /*sizeInBytes=*/&workspace_size)); + + // -------------- cudnn batchnorm reserve space -------------- + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + /*handle=*/handle, + /*mode=*/mode_, + /*bnOps=*/bnOps_, + /*activationDesc=*/activation_desc_, + /*xDesc=*/data_desc_, + /*sizeInBytes=*/&reserve_space_size)); + + reserve_space->Resize( + {static_cast((reserve_space_size + phi::SizeOf(x.dtype()) - 1) / + phi::SizeOf(x.dtype()))}); + reserve_space_ptr = dev_ctx.template Alloc( + reserve_space, reserve_space->numel() * sizeof(T)); + workspace_tensor.Resize({static_cast( + (workspace_size + phi::SizeOf(x.dtype()) - 1) / phi::SizeOf(x.dtype()))}); + workspace_ptr = dev_ctx.template Alloc( + &workspace_tensor, workspace_tensor.numel() * sizeof(T)); + + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnBatchNormalizationForwardTrainingEx( + handle, + mode_, + bnOps_, + CudnnDataType::kOne(), + CudnnDataType::kZero(), + data_desc_, + x.template data(), + data_desc_, + z.template data(), + data_desc_, + y->template data(), + bn_param_desc_, + scale.template data>(), + bias.template data>(), + this_factor, + dev_ctx.template Alloc>( + mean_out, mean_out->numel() * sizeof(BatchNormParamType)), + dev_ctx.template Alloc>( + variance_out, + variance_out->numel() * sizeof(BatchNormParamType)), + epsilon1, + dev_ctx.template Alloc>( + saved_mean, saved_mean->numel() * sizeof(BatchNormParamType)), + dev_ctx.template Alloc>( + saved_variance, + saved_variance->numel() * sizeof(BatchNormParamType)), + activation_desc_, + workspace_ptr, + workspace_size, + reserve_space_ptr, + reserve_space_size)); + + // clean when exit. 
+ PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(data_desc_)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_bn_add_activation, + GPU, + ALL_LAYOUT, + phi::fusion::FusedBatchNormAddActKernel, + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); +} diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index ad276ec6f1812b..3b73935699babb 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -1255,6 +1255,9 @@ PD_REGISTER_KERNEL(batch_norm, kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); } +#if CUDNN_VERSION_MIN(7, 4, 1) + kernel->OutputAt(5).SetDataType(phi::DataType::UINT8); +#endif } #else PD_REGISTER_KERNEL(batch_norm, @@ -1274,6 +1277,9 @@ PD_REGISTER_KERNEL(batch_norm, kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); } +#if CUDNN_VERSION_MIN(7, 4, 1) + kernel->OutputAt(5).SetDataType(phi::DataType::UINT8); +#endif } #endif diff --git a/paddle/phi/kernels/gpu/quantize_linear_kernel.cu b/paddle/phi/kernels/gpu/quantize_linear_kernel.cu new file mode 100644 index 00000000000000..11c043e76f464e --- /dev/null +++ b/paddle/phi/kernels/gpu/quantize_linear_kernel.cu @@ -0,0 +1,130 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "paddle/phi/kernels/quantize_linear_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/quantize_linear_impl.h" + +namespace phi { + +template +__global__ void KeDequantize( + const T* in, const T* scale, T max_range, int64_t num, T* out) { + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + out[i] = in[i] * scale[0] / max_range; + } +} + +template +__global__ void DequantizeOneScaleQuantAxisN(const T* in, + const T* scale, + const T max_range, + const int64_t num, + const int n_scales, + const int quant_stride, + T* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + T s = scale[(i / quant_stride) % n_scales]; + out[i] = in[i] * s / max_range; + } +} + +template +struct ChannelDequantizeFunctorV2 { + void operator()(const phi::GPUContext& dev_ctx, + const phi::DenseTensor* in, + const phi::DenseTensor* scale, + T max_range, + const int quant_axis, + phi::DenseTensor* out) { + auto in_dims = in->dims(); + const T* in_data = in->data(); + T* out_data = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + int64_t num = in->numel(); + const T* scale_factor = scale->data(); + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + + int quant_stride = 1; + for (int i = quant_axis + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; + } + + DequantizeOneScaleQuantAxisN + <<>>(in_data, + scale_factor, + max_range, + num, + in_dims[quant_axis], + quant_stride, + out_data); + } +}; + +template +struct DequantizeFunctor { + void operator()(const phi::GPUContext& dev_ctx, + const phi::DenseTensor* in, + const phi::DenseTensor* scale, + T max_range, + phi::DenseTensor* out) { + const T* in_data = in->data(); + const T* scale_factor = scale->data(); + T* out_data = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + + int64_t num = in->numel(); + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + KeDequantize<<>>( + in_data, scale_factor, max_range, num, out_data); + } +}; + +template struct DequantizeFunctor; +template struct DequantizeFunctor; +template struct DequantizeFunctor; +template struct ChannelDequantizeFunctorV2; +template struct ChannelDequantizeFunctorV2; +template struct ChannelDequantizeFunctorV2; +} // namespace phi + +PD_REGISTER_KERNEL(dequantize_linear, + GPU, + ALL_LAYOUT, + phi::DeQuantizeLinearKernel, + float, + int8_t, + double, + phi::dtype::float16) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/impl/quantize_linear_impl.h b/paddle/phi/kernels/impl/quantize_linear_impl.h new file mode 100644 index 00000000000000..9f86fd07447ee5 --- /dev/null +++ b/paddle/phi/kernels/impl/quantize_linear_impl.h @@ -0,0 +1,127 @@ 
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/phi/kernels/quantize_linear_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/cast_kernel.h" + +namespace phi { + +template +struct DequantizeFunctor { + void operator()(const Context& dev_ctx, + const phi::DenseTensor* in, + const phi::DenseTensor* scale, + T max_range, + phi::DenseTensor* out); +}; + +template +struct ChannelDequantizeFunctorV2 { + void operator()(const Context& dev_ctx, + const phi::DenseTensor* in, + const phi::DenseTensor** scales, + const int scale_num, + T max_range, + const int quant_axis, + phi::DenseTensor* out); +}; + +template +void DeQuantizeLinearImpl(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + int quant_axis, + int bit_length, + bool only_observer, + DenseTensor* out) { + auto* in = &x; + + auto in_tmp = phi::Cast(dev_ctx, *in, phi::CppTypeToDataType::Type()); + + dev_ctx.template Alloc(out, out->numel() * sizeof(D)); + + if (only_observer) { + phi::Copy(dev_ctx, *in, dev_ctx.GetPlace(), false, out); + return; + } + + if (quant_axis < 0) { + float max_range = (std::pow(2, bit_length - 1) - 1); + DequantizeFunctor()( + dev_ctx, &in_tmp, &scale, static_cast(max_range), out); + } else { + PADDLE_ENFORCE_EQ( + scale.numel(), + in_tmp.dims()[quant_axis], + phi::errors::PreconditionNotMet( + "The number of first scale values must be the same with " + "quant_axis dimension value of Input(X) when the `scale` has " + "only one element, but %ld != %ld here.", + scale.numel(), + in_tmp.dims()[quant_axis])); + int max_range = (std::pow(2, bit_length - 1) - 1); + + ChannelDequantizeFunctorV2()( + dev_ctx, &in_tmp, &scale, static_cast(max_range), quant_axis, out); + } +} + +template +void DeQuantizeLinearKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + const DenseTensor& zero_point, + const paddle::optional& in_accum, + const paddle::optional& in_state, + int quant_axis, + int bit_length, + int round_type, + bool is_test, + bool only_observer, + DenseTensor* out, + DenseTensor* out_state, + DenseTensor* out_accum, + DenseTensor* out_scale) { + switch (scale.dtype()) { + case phi::DataType::FLOAT64: + DeQuantizeLinearImpl( + dev_ctx, x, scale, quant_axis, bit_length, only_observer, out); + break; + case phi::DataType::FLOAT32: + DeQuantizeLinearImpl( + dev_ctx, x, scale, quant_axis, bit_length, only_observer, out); + break; + case phi::DataType::FLOAT16: + DeQuantizeLinearImpl( + dev_ctx, x, scale, quant_axis, bit_length, only_observer, out); + break; + default: + PADDLE_THROW(phi::errors::Unimplemented( + "In DeQuantizeLinearKernel, " + "data type %d for scale/output is not supported ", + scale.dtype())); + break; + } +} + +} // namespace phi diff --git 
a/paddle/phi/kernels/quantize_linear_kernel.h b/paddle/phi/kernels/quantize_linear_kernel.h new file mode 100644 index 00000000000000..c10a67f51e6030 --- /dev/null +++ b/paddle/phi/kernels/quantize_linear_kernel.h @@ -0,0 +1,40 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void DeQuantizeLinearKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + const DenseTensor& zero_point, + const paddle::optional& in_accum, + const paddle::optional& in_state, + int quant_axis, + int bit_length, + int round_type, + bool is_test, + bool only_observer, + DenseTensor* out, + DenseTensor* out_state, + DenseTensor* out_accum, + DenseTensor* out_scale); + +} // namespace phi diff --git a/paddle/phi/kernels/xpu/batch_norm_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_kernel.cc index b95dda1fed13d1..e2f2d28182b67d 100644 --- a/paddle/phi/kernels/xpu/batch_norm_kernel.cc +++ b/paddle/phi/kernels/xpu/batch_norm_kernel.cc @@ -140,4 +140,9 @@ PD_REGISTER_KERNEL(batch_norm, ALL_LAYOUT, phi::BatchNormKernel, float, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); +} diff --git a/paddle/phi/ops/compat/fused_bn_add_activation_sig.cc b/paddle/phi/ops/compat/fused_bn_add_activation_sig.cc new file mode 100644 index 00000000000000..a9adffca84700b --- /dev/null +++ b/paddle/phi/ops/compat/fused_bn_add_activation_sig.cc @@ -0,0 +1,52 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature FusedBatchNormAddActOpArgumentMapping( + const ArgumentMappingContext& ctx UNUSED) { + return KernelSignature("fused_bn_add_activation", + {"X", "Z", "Scale", "Bias", "Mean", "Variance"}, + {"momentum", "epsilon", "act_type"}, + {"Y", + "MeanOut", + "VarianceOut", + "SavedMean", + "SavedVariance", + "ReserveSpace"}); +} + +KernelSignature FusedBatchNormAddActGradOpArgumentMapping( + const ArgumentMappingContext& ctx UNUSED) { + return KernelSignature("fused_bn_add_activation_grad", + {"X", + "Y", + "Y@GRAD", + "Scale", + "Bias", + "SavedMean", + "SavedVariance", + "ReserveSpace"}, + {"momentum", "epsilon", "act_type"}, + {"X@GRAD", "Z@GRAD", "Scale@GRAD", "Bias@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(fused_bn_add_activation, + phi::FusedBatchNormAddActOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(fused_bn_add_activation_grad, + phi::FusedBatchNormAddActGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/quantize_linear_sig.cc b/paddle/phi/ops/compat/quantize_linear_sig.cc new file mode 100644 index 00000000000000..75e523bf55367d --- /dev/null +++ b/paddle/phi/ops/compat/quantize_linear_sig.cc @@ -0,0 +1,31 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature DeQuantizeLinearOpArgumentMapping(
+    const ArgumentMappingContext& ctx UNUSED) {
+  return KernelSignature(
+      "dequantize_linear",
+      {"X", "Scale", "ZeroPoint", "InAccum", "InState"},
+      {"quant_axis", "bit_length", "round_type", "is_test", "only_observer"},
+      {"Y", "OutState", "OutAccum", "OutScale"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(dequantize_linear,
+                           phi::DeQuantizeLinearOpArgumentMapping);
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index 7aa25386076e54..2548ca079f02cd 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -1283,6 +1283,7 @@ set(STATIC_BUILD_TESTS
     test_adamw_op
     test_arg_min_max_op
     test_assign_pos_op
+    test_batch_norm_op
     test_bucketize_api
     test_bincount_op
     test_c_embedding_op
@@ -1290,6 +1291,7 @@ set(STATIC_BUILD_TESTS
     test_decoupled_py_reader
     test_eig_op
     test_eigh_op
+    test_fake_dequantize_op
     test_fake_quantize_op
     test_fetch_lod_tensor_array
     test_ftrl_op
diff --git a/test/legacy_test/test_fake_dequantize_op.py b/test/legacy_test/test_fake_dequantize_op.py
index ee2f7f7b0820ab..9fc5f3500844f1 100644
--- a/test/legacy_test/test_fake_dequantize_op.py
+++ b/test/legacy_test/test_fake_dequantize_op.py
@@ -247,7 +247,7 @@ def setUp(self):
         self.outputs = {'Y': ydq}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_dygraph=False)
 
 
 class TestChannelWiseDequantizeOp1(TestChannelWiseDequantizeOp):
@@ -281,7 +281,7 @@ def setUp(self):
         self.outputs = {'Y': ydq}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_dygraph=False)
 
 
 class TestDequantizeOpDouble(TestDequantizeOp):
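A closing note on the dtype plumbing: as quantize_linear_impl.h above shows, dequantize_linear casts X to the dtype of Scale and dispatches on it, which is why the CPU/GPU registrations leave output 0 as UNDEFINED and why static_build.cc now infers the output dtype from Scale. The check_dygraph=False changes in the tests above restrict OpTest verification to the static-graph path exercised by the new STATIC_BUILD_TESTS entries. A hypothetical helper summarizing the dtype rule (illustration only, not Paddle API):

// Y's dtype follows Scale's dtype (float64/float32/float16); other dtypes are
// rejected by the kernel with an Unimplemented error.
#include "paddle/phi/common/data_type.h"

phi::DataType InferDequantizeOutDType(phi::DataType scale_dtype) {
  switch (scale_dtype) {
    case phi::DataType::FLOAT64:
    case phi::DataType::FLOAT32:
    case phi::DataType::FLOAT16:
      return scale_dtype;
    default:
      return phi::DataType::UNDEFINED;
  }
}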