Skip to content

Commit

Permalink
remove batch norm from StaticBuildBlackList (PaddlePaddle#57510)
Browse files Browse the repository at this point in the history
* remove batch norm from StaticBuildBlackList

* turn off the flag

* fix batch_norm register kernel info

* remove dequantize_linear to phi

* fix build error

* add sig file

* update test and cmakelist

* fix test_split_program error in static_build mode

* move fuse_bn_add_act to phi

* fix error in test_cudnn_bn_add_relu

* fix shape error

* fix date

* close the static_build flag
  • Loading branch information
AndSonder authored and Frida-a committed Oct 14, 2023
1 parent 084a65d commit 813d567
Show file tree
Hide file tree
Showing 20 changed files with 1,049 additions and 489 deletions.
17 changes: 16 additions & 1 deletion paddle/fluid/framework/new_executor/interpreter/static_build.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ std::set<std::string> OpsCanSkipedFakeAllocInStaticBuild = {
"nop"};

std::set<std::string> StaticBuildBlackList = {
"batch_norm" /*: to handle reserve_space output*/,
"cinn_instruction_run" /*: to handle subgraph infermeta*/,
"cinn_launch" /*: to handle subgraph infermeta*/,
"run_program" /*: to handle scope output*/,
Expand Down Expand Up @@ -206,6 +205,14 @@ bool TensorShouldBeFakeInitialized(const OperatorBase& op,
}
}

if (op_type == "batch_norm" && parameter_name == "ReserveSpace") {
if (dynamic_cast<const OperatorWithKernel*>(&op)->kernel_type()->place_ ==
phi::CPUPlace()) {
VLOG(2) << "Skip fake initialization for: " << parameter_name;
return false;
}
}

if (op_type == "coalesce_tensor" && parameter_name == "Output") {
VLOG(2) << "Skip fake initialization for: " << parameter_name;
return false;
Expand Down Expand Up @@ -250,6 +257,12 @@ bool TensorShouldBeFakeInitialized(const OperatorBase& op,
}
}

if ((op_type == "flatten" || op_type == "flatten_contiguous_range") &&
parameter_name == "XShape") {
VLOG(2) << "Skip fake initialization for: " << parameter_name;
return false;
}

if (op_type == "segment_pool" && parameter_name == "SummedIds") {
return op.Attr<std::string>("pooltype") == "MEAN" &&
dynamic_cast<const OperatorWithKernel*>(&op)
Expand Down Expand Up @@ -856,6 +869,8 @@ void FakeInitializeOutputsForFunctionKernel(
dtype = InferDTypeFromAttr(op, runtime_ctx, "dtype");
} else if (op_type == "bincount" || op_type == "reduce_sum_grad") {
dtype = GetInputDType(runtime_ctx, "X");
} else if (op_type == "dequantize_linear") {
dtype = GetInputDType(runtime_ctx, "Scale");
} else if (op_type == "lamb") {
bool multi_precision = op.Attr<bool>("multi_precision");
dtype = GetInputDType(runtime_ctx, "Moment1");
Expand Down
Loading

0 comments on commit 813d567

Please sign in to comment.