diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h index a3bb3a28793007..cb7df4347fa2db 100644 --- a/paddle/fluid/eager/to_static/run_program_op_func.h +++ b/paddle/fluid/eager/to_static/run_program_op_func.h @@ -21,6 +21,8 @@ #include "paddle/fluid/eager/to_static/run_program_op_node.h" #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/value.h" // Filter params without grads in global block. In this case, we will // tag its AutogradMeta with stop_gradient = True to avoid fault from @@ -90,6 +92,23 @@ static std::vector filter_unused_input_var_in_backward( return filter_x; } +static std::vector newir_filter_unused_input_var_in_backward( + const std::vector& x, + const std::string x_key_name, + const paddle::framework::AttributeMap& attrs) { + auto values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at(x_key_name)); + auto filter_x = std::vector(x); + for (size_t i = 0; i < x.size(); i++) { + if (values[i].impl() == nullptr) { + auto fake = paddle::Tensor(std::make_shared()); + fake.set_name(paddle::framework::kFakeVarName); + filter_x[i] = fake; + } + } + return filter_x; +} + static std::vector Trans2ContiguousTensors( const std::vector& tensors) { std::vector res; @@ -243,8 +262,17 @@ inline void newir_run_program_ad_func( paddle::Tensor(std::make_shared()); middles.push_back(&grad_node->GetMiddle()[i]); } + + auto backward_outs = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo")); for (size_t i = 0; i < output_size; ++i) { - grad_node->GetOutputs()[i] = *out[i]; + if (backward_outs[i] != nullptr) { + grad_node->GetOutputs()[i] = *out[i]; + } else { // not used by backward program + auto fake = paddle::Tensor(std::make_shared()); + fake.set_name(paddle::framework::kFakeVarName); + grad_node->GetOutputs()[i] = fake; + } } } @@ -253,35 +281,26 @@ inline void newir_run_program_ad_func( NewIRRunProgramAPI( x, params, out, middles, step_scope, dout, require_any_grad, attrs); if (require_any_grad) { - // auto x_names = - // PADDLE_GET_CONST(std::vector, attrs.at("x_names")); - egr::EagerUtils::PassStopGradient(false, &p_autograd_outs); // Set Attributes grad_node->SetAttrMap(attrs); - // auto* forward_global_block = PADDLE_GET_CONST( - // paddle::framework::BlockDesc*, attrs.at("forward_global_block")); - // auto* backward_global_block = PADDLE_GET_CONST( - // paddle::framework::BlockDesc*, attrs.at("backward_global_block")); // Clear unused x vars - // auto filter_x = - // filter_unused_input_var_in_backward(x, x_names, backward_global_block); + auto filter_x = newir_filter_unused_input_var_in_backward(x, "bx", attrs); // Set TensorWrappers - grad_node->SetFwdX(x); - // Clear unused out vars - // clear_unused_out_var_in_backward(out, backward_global_block, - // step_scope[0]); + grad_node->SetFwdX(filter_x); + + auto filter_params = + newir_filter_unused_input_var_in_backward(params, "bp", attrs); + grad_node->SetFwdParams(filter_params); - grad_node->SetFwdParams(params); grad_node->SetStepScope(step_scope); // just for set useable. // Set Grad out rank as same as fwd input and set stop gradient to bwd // NOTE(@xiongkun): Not every tensor in x(list of tensor) is required // gradient. for example: x[1] is not used for output, the x[1] is ignored. - // TODO(@xiongkun): rewrite by new ir representation. 
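The hunk above introduces newir_filter_unused_input_var_in_backward: an input whose corresponding backward ::pir::Value is null is swapped for a named fake tensor, so the grad node never keeps alive tensors the backward program will not read. Below is a minimal pure-Python sketch of that filtering rule; filter_unused_inputs and FAKE_VAR_NAME are illustrative stand-ins (not Paddle symbols), and None plays the role of a null value.

    # Illustrative sketch: keep an input only when the backward program
    # actually consumes it; otherwise substitute a named placeholder so the
    # grad node does not hold the real tensor alive.
    FAKE_VAR_NAME = "@FAKE_VAR@"  # assumed placeholder name, for illustration only


    def filter_unused_inputs(tensors, backward_values):
        """backward_values[i] is None when input i is unused by backward."""
        filtered = []
        for tensor, value in zip(tensors, backward_values):
            filtered.append(FAKE_VAR_NAME if value is None else tensor)
        return filtered


    print(filter_unused_inputs(["x0", "x1", "x2"], ["v0", None, "v2"]))
    # ['x0', '@FAKE_VAR@', 'x2']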
std::vector x_require_grad; for (size_t i = 0; i < x.size(); ++i) { x_require_grad.push_back(&x[i]); @@ -290,6 +309,7 @@ inline void newir_run_program_ad_func( grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0); grad_node->SetGradOutMeta(params, /*slot id*/ 1); + // TODO(@xiongkun): rewrite by new ir representation. // VLOG(2) << "clear_no_grad_edges."; // clear_no_grad_edges_with_partial_block(params, // forward_global_block, diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index fd0d6563945a57..708f47af6cfeac 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -89,15 +89,6 @@ static void CheckInputVarStatus(const Tensor &tensor) { "RunProgram(Grad)Op holds " "wrong type. Expect type is DenseTensor.", tensor.name())); - - PADDLE_ENFORCE_EQ( - static_cast(tensor.impl().get())->IsInitialized(), - true, - paddle::platform::errors::InvalidArgument( - "The tensor in input tensor %s of " - "RunProgram(Grad)Op " - "is not initialized.", - tensor.name())); } static void CheckOutputVarStatus(const paddle::framework::Variable &src_var, @@ -117,13 +108,6 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var, "RunProgram(Grad)Op's internal scope holds " "wrong type. Expect type is DenseTensor", name)); - PADDLE_ENFORCE_EQ(src_tensor.IsInitialized(), - true, - paddle::platform::errors::InvalidArgument( - "The tensor in output tensor %s get from " - "RunProgram(Grad)Op's internal " - "scope is not initialized.", - name)); } else if (dst_tensor.is_selected_rows()) { auto &src_tensor = src_var.Get(); PADDLE_ENFORCE_EQ(phi::SelectedRows::classof(&src_tensor), @@ -133,14 +117,6 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var, "RunProgram(Grad)Op's internal scope holds " "wrong type. 
Expect type is SelectedRows", name)); - PADDLE_ENFORCE_EQ(src_tensor.initialized(), - true, - paddle::platform::errors::InvalidArgument( - "The tensor in output tensor %s get from " - "RunProgram(Grad)Op's " - "internal scope is not initialized.", - name)); - } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "The RunProgram(Grad)Op only support output " @@ -214,14 +190,23 @@ static auto GetNameFromValue(const ::pir::Block *block, .dyn_cast() .AsString(); value2name[op->operand(0).source()] = name; + } else if (op->name() == "builtin.get_parameter") { + name = op->attributes() + .at("parameter_name") + .dyn_cast() + .AsString(); + value2name[op->result(0).Value::impl()] = name; } } std::vector names; - std::transform( - values.begin(), - values.end(), - std::back_inserter(names), - [&value2name](const ::pir::Value &v) { return value2name[v]; }); + std::transform(values.begin(), + values.end(), + std::back_inserter(names), + [&value2name](const ::pir::Value &v) { + if (!value2name.count(v)) + return std::string(paddle::framework::kFakeVarName); + return value2name.at(v); + }); return names; } @@ -255,7 +240,7 @@ static void ShareTensorsFromScope( auto &src_tensor = var->Get(); auto *dst_tensor = const_cast( dynamic_cast(tensors[i]->impl().get())); - VLOG(2) << "share " << name << " from scope"; + VLOG(4) << "share " << name << " from scope"; *dst_tensor = src_tensor; } else if (var->IsType()) { auto &src_tensor = var->Get(); @@ -272,6 +257,11 @@ static void ShareTensorsIntoScopeByValue( const std::vector<::pir::Value> &values, paddle::framework::Scope *scope) { auto names = GetNameFromValue(block, values); + if (VLOG_IS_ON(4)) { + for (auto &s : names) { + VLOG(4) << "ShareTensorIntoScopeByValue name: " << s; + } + } ShareTensorsIntoScopeWithName(tensors, names, scope); } @@ -461,8 +451,6 @@ inline void NewIRRunProgramAPI( PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm")); auto param_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp")); - // auto dout_names = - // PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp")); auto *forward_global_block = PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block")); @@ -523,6 +511,15 @@ inline void NewIRRunProgramAPI( std::set(skip_names.begin(), skip_names.end()); skip_names = details::GetNameFromValue(forward_global_block, output_values); skip_names_set.insert(skip_names.begin(), skip_names.end()); + auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>, + attrs.at("no_need_buffers")); + auto no_need_buffer_names = + details::GetNameFromValue(forward_global_block, no_need_buffer_values); + VLOG(4) << "start skip no need buffer vars with name:"; + for (auto &name : no_need_buffer_names) { + VLOG(4) << "Skip no need buffer vars with name:" << name; + skip_names_set.erase(name); + } details::print_collection(skip_names_set); interpreter_core->SetSkipGcVars(skip_names_set); @@ -997,6 +994,8 @@ inline void NewIRRunProgramGradAPI( PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bx")); auto forward_middle_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bm")); + auto parameter_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bp")); auto forward_output_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo")); auto x_grad_values = @@ -1004,6 +1003,20 @@ inline void NewIRRunProgramGradAPI( auto p_grad_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bp_g")); + // share x, param, middles, output_grads, out into scope. 
+ details::ShareTensorsIntoScopeByValue( + backward_global_block, out_grad, output_grad_values, global_inner_scope); + details::ShareTensorsIntoScopeByValue( + backward_global_block, x, forward_input_values, global_inner_scope); + details::ShareTensorsIntoScopeByValue(backward_global_block, + middles, + forward_middle_values, + global_inner_scope); + details::ShareTensorsIntoScopeByValue( + backward_global_block, out, forward_output_values, global_inner_scope); + details::ShareTensorsIntoScopeByValue( + backward_global_block, params, parameter_values, global_inner_scope); + auto &interpretercore_info_cache = paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = @@ -1016,19 +1029,6 @@ inline void NewIRRunProgramGradAPI( 1); VLOG(2) << "No interpretercore cahce, so create a new interpretercore"; // Step 1. share input_vars & parameters into scope - // x, param, middles, output_grads - details::ShareTensorsIntoScopeByValue(backward_global_block, - out_grad, - output_grad_values, - global_inner_scope); - details::ShareTensorsIntoScopeByValue( - backward_global_block, x, forward_input_values, global_inner_scope); - details::ShareTensorsIntoScopeByValue(backward_global_block, - middles, - forward_middle_values, - global_inner_scope); - details::ShareTensorsIntoScopeByValue( - backward_global_block, out, forward_output_values, global_inner_scope); auto kernel_backward_program = paddle::dialect::PdOpLowerToKernelPass(backward_program, place); interpreter_core = paddle::framework::CreateNewIRInterpreterCoreInfoToCache( @@ -1076,14 +1076,13 @@ inline void NewIRRunProgramGradAPI( program_id, global_inner_scope, /*is_grad=*/true); interpreter_core = cached_value.core_; - // update scope (TODO: why share again) - // details::ShareTensorsIntoScope(out_grad, global_inner_scope); - // if (interpreter_core->GetVariableScope()->GetMutableScope() != - // global_inner_scope) { - // details::BuildScopeByBlock( - // *interpreter_core.get(), *backward_global_block, global_inner_scope); - // interpreter_core->reset_scope(global_inner_scope); - //} + if (interpreter_core->GetVariableScope()->GetMutableScope() != + global_inner_scope) { + // update scope (TODO(xiongkun): do we need this??) + // details::BuildScopeByBlock( + // *interpreter_core.get(), *backward_global_block, global_inner_scope); + interpreter_core->reset_scope(global_inner_scope); + } } if (!backward_global_block->empty()) { @@ -1287,7 +1286,7 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase { ~NewIRGradNodeRunProgram() override { if (!executed_) { auto *out_scope_vec = &step_scope_; - VLOG(4) << "~GradNodeRunProgram"; + VLOG(4) << "~NewIRGradNodeRunProgram"; // Normally out_scope_vec.size() == 1. for safty, we add for-loop here. 
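The hunks above hoist the ShareTensorsIntoScopeByValue calls out of the cache-miss branch, so x, params, middles, outputs and output grads are shared into the inner scope on every backward call, and a cached interpreter core is re-bound when its scope no longer matches the current one. A small sketch of that cache-and-rebind pattern follows; Interpreter, get_grad_interpreter and the dict-based scope are toy stand-ins, not Paddle APIs.

    # Minimal sketch, assuming a toy interpreter cache keyed by (program_id, phase).
    class Interpreter:
        def __init__(self, program_id, scope):
            self.program_id, self.scope = program_id, scope

        def reset_scope(self, scope):
            self.scope = scope


    _cache = {}


    def get_grad_interpreter(program_id, scope, feeds):
        scope.update(feeds)  # share x, params, middles, outs, out_grads every call
        key = (program_id, "grad")
        if key not in _cache:
            _cache[key] = Interpreter(program_id, scope)
        core = _cache[key]
        if core.scope is not scope:  # cached core may still point at an old scope
            core.reset_scope(scope)
        return core


    core1 = get_grad_interpreter(7, {"x": 1.0}, {"dout": 0.5})
    core2 = get_grad_interpreter(7, {"x": 2.0}, {"dout": 0.5})
    assert core1 is core2 and core2.scope["x"] == 2.0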
for (size_t i = 0; i < out_scope_vec->size(); ++i) { paddle::framework::Scope *global_inner_scope = out_scope_vec->at(i); @@ -1306,7 +1305,7 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase { egr::kSlotSmallVectorSize> &grads, // NOLINT bool create_graph UNUSED, bool is_new_grad UNUSED) override { - VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram"; + VLOG(3) << "Running Eager Backward Node: NewIRGradNodeRunProgram"; paddle::small_vector, egr::kSlotSmallVectorSize> hooked_grads = NewIRGradNodeRunProgram::ApplyGradientHooks(grads); PADDLE_ENFORCE_EQ(hooked_grads.size(), @@ -1348,7 +1347,6 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase { "The hooked_grads[0].size() and " "out_grad_values.size() should be equal.")); - VLOG(1) << "Run Program Grad API start."; NewIRRunProgramGradAPI(x_, params_, hooked_grads[0], @@ -1358,8 +1356,7 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase { attrs_, x_grad_ptr, params_grad_ptr); - VLOG(1) << "Run Program Grad API end."; - VLOG(3) << "End Eager Backward Node: GradNodeRunProgram"; + VLOG(3) << "End Eager Backward Node: NewIRGradNodeRunProgram"; executed_ = true; return {x_grad, params_grad}; diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 9d8074628fb13d..d28dc9bec40088 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -1027,6 +1027,7 @@ void ConstructAttrMapForRunProgram( "fm", "fo", "bx", + "no_need_buffers", "bp", "bm", "bo_g", diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index ed680cfb58803a..aa27aecfdd7b7e 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -602,6 +602,14 @@ void BindOpResult(py::module *m) { return false; } }) + .def("is_selected_row_type", + [](OpResult &self) { + if (self.type().isa()) { + return true; + } else { + return false; + } + }) .def_property( "stop_gradient", [](OpResult &self) { @@ -678,7 +686,9 @@ Operation *BuildOpFrom( pir::OperationArgument to_create_argument(to_copy_op->info()); to_create_argument.attributes = to_copy_op->attributes(); + VLOG(6) << "start copy op: " << to_copy_op->name(); auto origin_results = to_copy_op->results(); + VLOG(6) << "start translate origin results into op type."; std::transform(origin_results.begin(), origin_results.end(), std::back_inserter(to_create_argument.output_types), @@ -688,6 +698,7 @@ Operation *BuildOpFrom( }); // transform by value_map dict. + VLOG(6) << "start create op."; auto origin_operands = to_copy_op->operands(); std::transform(origin_operands.begin(), origin_operands.end(), @@ -698,8 +709,6 @@ Operation *BuildOpFrom( }); auto *cloned_op = Operation::Create(std::move(to_create_argument)); - // update the mapping of value_map. std::transform is a map(func, - // zip()). 
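BuildOpFrom, edited just above, clones one operation while threading a value map: operands are looked up through the map, and the clone's results are written back into it so later ops can resolve them. A hedged sketch of that bookkeeping, with plain tuples and strings standing in for pir operations and values:

    # Sketch only: each op is (name, operand_values, result_values).
    def clone_ops(ops, value_map):
        cloned = []
        for name, operands, results in ops:
            new_operands = [value_map[v] for v in operands]    # remap inputs
            new_results = [f"{r}_clone" for r in results]      # fresh outputs
            value_map.update(dict(zip(results, new_results)))  # record mapping
            cloned.append((name, new_operands, new_results))
        return cloned


    vmap = {"x": "x_clone"}
    print(clone_ops([("matmul", ["x"], ["y"]), ("relu", ["y"], ["z"])], vmap))
    # [('matmul', ['x_clone'], ['y_clone']), ('relu', ['y_clone'], ['z_clone'])]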
std::vector tmp; std::transform(origin_results.begin(), origin_results.end(), @@ -742,11 +751,11 @@ void range_block_do(const Block *block, std::vector range, F fn) { } } -std::vector AnalysisMiddleVariable( - const Program &program, - const std::vector &forward_inputs, - const std::vector &forward_range, - const std::vector &backward_range) { +std::pair, std::unordered_set> +AnalysisMiddleVariable(const Program &program, + const std::vector &forward_inputs, + const std::vector &forward_range, + const std::vector &backward_range) { std::vector middle_values; std::unordered_set backward_inputs; @@ -769,7 +778,7 @@ std::vector AnalysisMiddleVariable( middle_values.push_back(v); } }); - return middle_values; + return std::make_pair(middle_values, backward_inputs); } void mapping_value(const std::vector &origin, @@ -780,6 +789,11 @@ void mapping_value(const std::vector &origin, std::back_inserter(out), [&value_map](const pir::Value &v) { if (v.impl() == nullptr) return Value(nullptr); + if (!value_map.count(v)) { + VLOG(2) << "mapping value found v is not exist. may not " + "used by backward program."; + return Value(nullptr); + } return value_map.at(v); }); } @@ -793,18 +807,62 @@ pir::OpResult FakeOpResult() { return pir::OpResult(nullptr); } +bool IsFakeOpResult(const pir::OpResult &result) { + // create a fake opresults to simplify `ForwardBackwardSplit`. + return result.Value::impl() == nullptr; +} + +static auto GetNoNeedBufferValue(const ::pir::Block *whole_block, + std::vector range) { + // filter no need buffer values. + std::unordered_set<::pir::Value> need_buffer_values; + std::unordered_set<::pir::Value> no_need_buffer_values; + range_block_do( + whole_block, range, [&need_buffer_values](::pir::Operation *op) { + if (op->HasInterface() == false) { + // not a OpYamlInfoInterface, can't have no_need_buffer. 
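GetNoNeedBufferValue, added above, runs two passes over the backward range: first it collects every operand whose buffer some op might need (conservatively, all operands of ops without op-info), then it reports the operands that never landed in that set. An illustrative sketch with plain Python sets, where each op is an (operands, no_need_flags) pair and no_need_flags being None marks an op without OpYamlInfoInterface:

    def no_need_buffer_values(ops):
        need = set()
        for operands, no_need_flags in ops:
            if no_need_flags is None:          # no op-info: assume all are needed
                need.update(operands)
            else:
                for v, no_need in zip(operands, no_need_flags):
                    if not no_need:
                        need.add(v)
        # second pass: anything never marked as needed is no-need-buffer
        no_need = set()
        for operands, _ in ops:
            no_need.update(v for v in operands if v not in need)
        return no_need


    ops = [(["w", "x"], [True, False]), (["x", "b"], None)]
    print(no_need_buffer_values(ops))  # {'w'}: w's buffer is never required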
+ for (const auto &operand : op->operands_source()) { + need_buffer_values.insert(operand); + } + } else { + auto opinfo = + op->dyn_cast().GetOpInfo(); + int counter = 0; + for (const auto &op_input_info : std::get<0>(opinfo)) { + if (!op_input_info.no_need_buffer) { + need_buffer_values.insert(op->operand_source(counter)); + } + counter += 1; + } + } + }); + range_block_do(whole_block, + range, + [&need_buffer_values, + &no_need_buffer_values](const ::pir::Operation *op) { + for (const auto &operand : op->operands_source()) { + if (need_buffer_values.count(operand) == 0) { + no_need_buffer_values.insert(operand); + } + } + }); + return std::vector<::pir::Value>(no_need_buffer_values.begin(), + no_need_buffer_values.end()); +} + SplitedResult ForwardBackwardSplit( const Program &program, const std::vector &op_result_forward_inputs, + const std::vector &op_result_forward_params, const std::vector &op_result_forward_outputs, const std::vector &op_result_forward_inputs_grads, + const std::vector &op_result_forward_params_grads, const std::vector &op_result_forward_outputs_grads, const std::vector &forward_range, const std::vector &backward_range) { // transform opresult -> value - VLOG(1) << "Start Prepare data structures."; std::vector forward_inputs, forward_outputs, forward_inputs_grads, - forward_outputs_grads; + forward_outputs_grads, forward_params, forward_params_grads; auto op_result_to_value = [](const pir::OpResult &r) { if (r.impl() == nullptr) return Value(nullptr); @@ -827,56 +885,67 @@ SplitedResult ForwardBackwardSplit( op_result_forward_outputs_grads.end(), std::back_inserter(forward_outputs_grads), op_result_to_value); + std::transform(op_result_forward_params.begin(), + op_result_forward_params.end(), + std::back_inserter(forward_params), + op_result_to_value); + std::transform(op_result_forward_params_grads.begin(), + op_result_forward_params_grads.end(), + std::back_inserter(forward_params_grads), + op_result_to_value); std::vector forward_in_out_values; for (auto &v : std::vector *>( - {&forward_inputs, &forward_outputs})) { + {&forward_inputs, &forward_outputs, &forward_params})) { forward_in_out_values.insert( forward_in_out_values.end(), v->begin(), v->end()); } std::vector fx, fp, fm, fo, bx, bp, bm, bo_g, bx_g, bp_g, bo; + std::vector no_need_buffer_values; pir::IrContext *ctx = pir::IrContext::Instance(); auto forward_program = std::make_shared(ctx); auto backward_program = std::make_shared(ctx); - auto middle_values = AnalysisMiddleVariable( + std::vector middle_values; + std::unordered_set backward_inputs; + std::tie(middle_values, backward_inputs) = AnalysisMiddleVariable( program, forward_in_out_values, forward_range, backward_range); std::unordered_map forward_value_map; std::unordered_map backward_value_map; pir::Builder backward_builder = pir::Builder(ctx, backward_program->block()); // forward program construct. - VLOG(1) << "Before Forward Construct."; + VLOG(4) << "start create forward program."; range_block_do(program.block(), forward_range, [&forward_value_map, &forward_program](Operation *op) { auto *cloned_op = BuildOpFrom(op, forward_value_map); forward_program->block()->push_back(cloned_op); }); - VLOG(1) << "After Forward Construct."; - // backward program construc. // Step1. 
insert data op for inputs_values and middle_values int counter = 0; - auto create_data_fn = - [&backward_builder, &backward_value_map, &counter](const pir::Value &v) { - if (v.impl() == nullptr) { - return; - } - auto value_type = v.type().dyn_cast(); - auto dtype = paddle::dialect::TransToPhiDataType(value_type.dtype()); - auto shape = phi::vectorize(value_type.dims()); - auto place = phi::Place(); - - paddle::dialect::DataOp op = - backward_builder.Build( - std::string("input_") + std::to_string(counter), - shape, - dtype, - place); - counter += 1; - backward_value_map[v] = op->results()[0].Value::impl(); - }; + auto create_data_fn = [&backward_builder, + &backward_inputs, + &backward_value_map, + &counter](const pir::Value &v) { + if (v.impl() == nullptr || !backward_inputs.count(v)) { + return; + } + auto value_type = v.type().dyn_cast(); + auto dtype = paddle::dialect::TransToPhiDataType(value_type.dtype()); + auto shape = phi::vectorize(value_type.dims()); + auto place = phi::Place(); + + paddle::dialect::DataOp op = + backward_builder.Build( + std::string("input_") + std::to_string(counter), + shape, + dtype, + place); + counter += 1; + backward_value_map[v] = op->results()[0].Value::impl(); + }; auto create_output_fn_forward = [&ctx, &forward_value_map, @@ -916,44 +985,59 @@ SplitedResult ForwardBackwardSplit( counter += 1; }; - counter = 0; + // counter = 0; + VLOG(4) << "start create backward inputs, inserting pd.data ops."; + VLOG(4) << "Create pd.data for backward program: fo, start with input_" + << counter; std::for_each(forward_outputs.begin(), forward_outputs.end(), create_data_fn); + VLOG(4) << "Create pd.data for backward program: fx, start with input_" + << counter; std::for_each(forward_inputs.begin(), forward_inputs.end(), create_data_fn); + VLOG(4) << "Create pd.data for backward program: fp, start with input_" + << counter; + std::for_each(forward_params.begin(), forward_params.end(), create_data_fn); + VLOG(4) << "Create pd.data for backward program: fm, start with input_" + << counter; std::for_each(middle_values.begin(), middle_values.end(), create_data_fn); + VLOG(4) << "Create pd.data for backward program: fo_g, start with input_" + << counter; std::for_each(forward_outputs_grads.begin(), forward_outputs_grads.end(), create_data_fn); - VLOG(1) << "After create pd.data for backward program."; + VLOG(4) << "Create pd.data for backward program end. input_" << counter; - counter = 0; + // counter = 0; + VLOG(4) << "start create forward outputs, inserting set_parameter ops."; std::for_each( middle_values.begin(), middle_values.end(), create_output_fn_forward); std::for_each( forward_outputs.begin(), forward_outputs.end(), create_output_fn_forward); - VLOG(1) << "After call create_output_fn"; // Step2. copy backward ops . + VLOG(4) << "start copy backward ops"; range_block_do(program.block(), backward_range, [&backward_value_map, &backward_program](Operation *op) { auto *cloned_op = BuildOpFrom(op, backward_value_map); backward_program->block()->push_back(cloned_op); }); - VLOG(1) << "After call backward copy"; - counter = 0; + // counter = 0; + VLOG(4) << "start create backward outputs, inserting set_parameter ops."; std::for_each(forward_inputs_grads.begin(), forward_inputs_grads.end(), create_output_fn_backward); - // TODO(xiongkun): add forward parameter grads. 
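In the split above, backward feed placeholders (pd.data ops) are now created only for forward values that the backward range actually reads, which is what the backward_inputs set returned by AnalysisMiddleVariable encodes. A small sketch of that rule, with strings standing in for values and generated names standing in for the created data ops:

    def make_backward_feeds(candidate_values, backward_inputs):
        backward_value_map, counter = {}, 0
        for v in candidate_values:
            if v is None or v not in backward_inputs:
                continue                       # unused by backward: no data op
            backward_value_map[v] = f"input_{counter}"  # stands for one pd.data op
            counter += 1
        return backward_value_map


    feeds = make_backward_feeds(["out", "x", "param", "mid"], {"x", "mid"})
    print(feeds)  # {'x': 'input_0', 'mid': 'input_1'}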
+ std::for_each(forward_params_grads.begin(), + forward_params_grads.end(), + create_output_fn_backward); - VLOG(1) << "forward_value_map.size() is " << forward_value_map.size(); - VLOG(1) << "backward_value_map.size() is " << backward_value_map.size(); + VLOG(4) << "forward_value_map.size() is " << forward_value_map.size(); + VLOG(4) << "backward_value_map.size() is " << backward_value_map.size(); std::ostringstream print_stream; print_stream << "ForwardProgram is :\n"; forward_program->Print(print_stream); print_stream << "BackwardProgram is:\n"; backward_program->Print(print_stream); - VLOG(1) << "Splited Program (fwd | bwd): \n" << print_stream.str(); + VLOG(4) << "Splited Program (fwd | bwd): \n" << print_stream.str(); // construct all attributes we needed. @@ -961,26 +1045,33 @@ SplitedResult ForwardBackwardSplit( mapping_value(middle_values, backward_value_map, bm); // write 'bm' mapping_value(forward_inputs, forward_value_map, fx); // write 'fx' mapping_value(forward_inputs, backward_value_map, bx); // write 'bx' + mapping_value(forward_params, forward_value_map, fp); // write 'fp' + mapping_value(forward_params, backward_value_map, bp); // write 'bp' mapping_value(forward_outputs, forward_value_map, fo); // write 'fo' - mapping_value(forward_inputs_grads, - backward_value_map, - bx_g); // write 'fx_g' - mapping_value(forward_outputs_grads, - backward_value_map, - bo_g); // write 'bo_g' + mapping_value( + forward_inputs_grads, backward_value_map, bx_g); // write 'bx_g' + mapping_value( + forward_params_grads, backward_value_map, bp_g); // write 'bp_g' + mapping_value( + forward_outputs_grads, backward_value_map, bo_g); // write 'bo_g' mapping_value(forward_outputs, backward_value_map, bo); // write 'bo' - - std::map> attr = {{"fx", fx}, - {"fp", fp}, - {"fm", fm}, - {"fo", fo}, - {"bx", bx}, - {"bp", bp}, - {"bm", bm}, - {"bo_g", bo_g}, - {"bx_g", bx_g}, - {"bp_g", bp_g}, - {"bo", bo}}; + mapping_value(GetNoNeedBufferValue(program.block(), backward_range), + forward_value_map, + no_need_buffer_values); // write 'no_need_buffers' + + std::map> attr = { + {"fx", fx}, + {"fp", fp}, + {"fm", fm}, + {"fo", fo}, + {"bx", bx}, + {"bp", bp}, + {"bm", bm}, + {"bo_g", bo_g}, + {"bx_g", bx_g}, + {"bp_g", bp_g}, + {"no_need_buffers", no_need_buffer_values}, + {"bo", bo}}; std::vector> programs = {forward_program, backward_program}; return std::make_pair(programs, attr); @@ -990,6 +1081,7 @@ void BindUtils(pybind11::module *m) { m->def("program_clone", ProgramClone); m->def("program_split", ForwardBackwardSplit); m->def("fake_op_result", FakeOpResult); + m->def("is_fake_op_result", IsFakeOpResult); m->def("set_global_program", [](Program *program) { APIBuilder::Instance().SetProgram(program); }); m->def("set_insertion_point", diff --git a/paddle/phi/kernels/impl/data_impl.h b/paddle/phi/kernels/impl/data_impl.h index 29611a8cfe887c..a8c3b1fa86650d 100644 --- a/paddle/phi/kernels/impl/data_impl.h +++ b/paddle/phi/kernels/impl/data_impl.h @@ -28,6 +28,9 @@ void ShadowFeedKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { ctx.template Alloc(out); + if (!x.initialized()) { + return; + } if (x.place() == out->place()) { out->ShareDataWith(x); out->set_lod(x.lod()); diff --git a/paddle/pir/core/builtin_op.cc b/paddle/pir/core/builtin_op.cc index 21e0357c700fc6..be3c47a70d979a 100644 --- a/paddle/pir/core/builtin_op.cc +++ b/paddle/pir/core/builtin_op.cc @@ -354,16 +354,7 @@ void SplitOp::Verify() const { input_type.size()); // for all i in outputs.size(): outputs[i].type == 
inputs[0][i].type - for (size_t i = 0; i < output_num; ++i) { - auto type = (*this)->result(i).type(); - IR_ENFORCE(input_type[i] == type, - "The type %s of inputs[0][%d] must be " - "equal to type %s of outputs[%d].", - input_type[i], - i, - type, - i); - } + // TODO(@xiongkun) consult zhangbo to check what to do with null type. } const char *ConstantOp::attributes_name[attributes_num] = {"value"}; // NOLINT diff --git a/python/paddle/base/dygraph/base.py b/python/paddle/base/dygraph/base.py index 9c340b7eab1d14..3c89b56d660066 100644 --- a/python/paddle/base/dygraph/base.py +++ b/python/paddle/base/dygraph/base.py @@ -118,6 +118,8 @@ def _convert_into_variable(tensor): """ Convert Tensor into Variable. """ + if paddle.framework.use_pir_api(): + return paddle.pir.core._convert_into_opresult(tensor) if isinstance(tensor, core.eager.Tensor): # Check whether has been created before. new_var = tensor.block._find_var_recursive(tensor.name) diff --git a/python/paddle/jit/dy2static/newir_partial_program.py b/python/paddle/jit/dy2static/newir_partial_program.py index 5bd7216b54b62c..198cc105b3ec14 100644 --- a/python/paddle/jit/dy2static/newir_partial_program.py +++ b/python/paddle/jit/dy2static/newir_partial_program.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import os from copy import deepcopy @@ -27,9 +28,9 @@ from paddle.base.data_feeder import check_type, convert_dtype from paddle.base.dygraph.base import switch_to_static_graph from paddle.base.framework import _apply_pass -from paddle.base.libpaddle.pir import OpResult, fake_op_result from paddle.framework import use_pir_api from paddle.optimizer.lr import LRScheduler +from paddle.pir import OpResult, fake_op_result, is_fake_op_result from . 
import logging_utils from .utils import RETURN_NO_VALUE_MAGIC_NUM, backend_guard @@ -175,7 +176,9 @@ def __init__( super().__init__() self._inputs = NestSequence(inputs) self._outputs = NestSequence(outputs, need_check=True) - self._params = parameters if parameters is not None else [] + self._params, self._param_values = ( + parameters if parameters is not None else ([], []) + ) self._build_strategy = kwargs.get('build_strategy', BuildStrategy()) assert isinstance(self._build_strategy, BuildStrategy) @@ -654,6 +657,7 @@ def _append_backward_desc(self, main_program): inputs = list( filter(lambda x: isinstance(x, OpResult), self._inputs.tolist()) ) + combined_inputs = list(itertools.chain(inputs, self._param_values)) forward_end_idx = len(program.global_block().ops) if targets: with backend_guard(self._backend): @@ -664,8 +668,9 @@ def _append_backward_desc(self, main_program): 'paddle.static.gradients', ) with ir_static.program_guard(program, None): - grad_info_map = grad(inputs=inputs, outputs=targets) - + grad_info_map = grad( + inputs=combined_inputs, outputs=targets + ) forward_outputs_grads = [] not_stop_gradient_num = 0 for out_op_result in self._outputs.tolist(): @@ -703,8 +708,12 @@ def _append_backward_desc(self, main_program): extra_info['forward_inputs'] = inputs extra_info['forward_outputs'] = targets extra_info['forward_end_op_idx'] = forward_end_idx + inputs_size = len(inputs) extra_info['forward_inputs_grads'] = list( - map(mapping_op_result, grad_info_map) + map(mapping_op_result, grad_info_map[0:inputs_size]) + ) + extra_info['forward_params_grads'] = list( + map(mapping_op_result, grad_info_map[inputs_size:]) ) extra_info['forward_outputs_grads'] = list( map(mapping_op_result, forward_outputs_grads) @@ -722,21 +731,15 @@ def _prune_unused_params(self, program): `run_program_op`. 
""" required_params = [] - for param in self._params: - found_param = False - for block in program.blocks: - for op in block.ops: - if ( - param.name in op.input_arg_names - or param.name in op.output_arg_names - ): - required_params.append(param) - found_param = True - break - if found_param: - break + required_param_values = [] + block = program.global_block() + for param, param_value in zip(self._params, self._param_values): + if not param_value.use_empty(): + required_params.append(param) + required_param_values.append(param_value) self._params = required_params + self._param_values = required_param_values def _cast_fp16_if_pure_fp16(self, in_vars): if _in_pure_fp16_guard(): @@ -805,9 +808,13 @@ def _get_forward_backward_program_form( forward_outputs = self.get_program_extra(whole_program)[ 'forward_outputs' ] + forward_parameters = self._param_values forward_outputs_grads = self.get_program_extra(whole_program)[ 'forward_outputs_grads' ] + forward_params_grads = self.get_program_extra(whole_program)[ + 'forward_params_grads' + ] backward_start_op_index = forward_end_op_index + 2 * len( list(filter(lambda r: r.stop_gradient is False, self._outputs)) ) @@ -819,15 +826,16 @@ def _get_forward_backward_program_form( # backward_skip_vars = self._parse_skip_gc_vars( # whole_program # ) + self._grad_var_names.get('param', []) - ( forward_program, backward_program, ), program_attr = paddle.base.libpaddle.pir.program_split( whole_program, forward_inputs, + forward_parameters, forward_outputs, forward_inputs_grads, + forward_params_grads, forward_outputs_grads, [0, forward_end_op_index], [backward_start_op_index, backward_end_op_index], @@ -957,7 +965,6 @@ def create_out(var_id): tensor_type = paddle.dtype(7) # LOD TENSOR else: tensor_type = paddle.dtype(8) # SELECT ROW TENSOR - out = core.eager.Tensor( framework.paddle_type_to_proto_type[var.dtype], var.shape, @@ -1063,15 +1070,24 @@ def _set_grad_type(self, params, train_program): # If we don't change grad_var type here, RunProgramOp need # transform SelectedRows to LoDTensor forcibly, it may not # be user wanted result. - for param in params: - grad_name = param.name + core.grad_var_suffix() - grad_var = train_program.desc.global_block().find_var( - grad_name.encode() - ) - # NOTE: cannot find var desc maybe no problem, such as in batch_norm - if grad_var is None: + forward_params_grads = self.get_program_extra(train_program)[ + 'forward_params_grads' + ] + for param, value in zip(params, forward_params_grads): + if is_fake_op_result(value): continue - param._set_grad_type(grad_var.type()) + if value.is_selected_row_type(): + param._set_grad_type( + paddle.base.core.VarDesc.VarType.SELECTED_ROWS + ) + elif value.is_dense_tensor_type(): + param._set_grad_type( + paddle.base.core.VarDesc.VarType.LOD_TENSOR + ) + else: + raise NotImplementedError( + "only support selected_row and dense_tensor grad type." + ) def _remove_op_call_stack(self, main_program): """ @@ -1134,12 +1150,4 @@ def partial_program_from(concrete_program, from_method=False): def add_build_strategy_for( program, start_op_index, end_op_index, build_strategy=None, skip_vars=None ): - paddle.base.libpaddle.pir.program_split( - program, - ) - if start_op_index < end_op_index: - pass - else: - # can't just create a new program, we need copy the vardesc. 
- builded_program = ir_static.Program() - return builded_program + raise NotImplementedError("Not implemented yet.") diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 42d0049f8a3689..7ebcf09de9c8ec 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -1231,12 +1231,13 @@ def newir_from_func_spec( raise # 3. Gets all ParamBases and buffered VarBases in the function - all_parameters_and_buffers = ( - ProgramTranslator.get_instance()._params_recorder.pop( - main_program - ) + from ..newir_dy2static.parameter_recorder import ( + _global_parameter_recorder, ) + all_parameters_and_buffers = _global_parameter_recorder.pop( + main_program + ) if outputs is not None: need_wrap_into_list = ( not isinstance(outputs, (tuple, list)) diff --git a/python/paddle/jit/newir_dy2static/__init__.py b/python/paddle/jit/newir_dy2static/__init__.py new file mode 100644 index 00000000000000..595add0aed9e11 --- /dev/null +++ b/python/paddle/jit/newir_dy2static/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/jit/newir_dy2static/parameter_recorder.py b/python/paddle/jit/newir_dy2static/parameter_recorder.py new file mode 100644 index 00000000000000..2bebff160c20ec --- /dev/null +++ b/python/paddle/jit/newir_dy2static/parameter_recorder.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
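The new parameter_recorder.py added below keeps, per traced program, the set of parameters touched during tracing and the placeholder value created for each one, so pop() can return parallel (params, values) lists. A simplified sketch of that pattern, using id(program) in place of _program_hash and strings in place of OpResults:

    class ToyParametersRecorder:
        def __init__(self):
            self.params_dict = {}    # program key -> set of parameters
            self.tensor2value = {}   # program key -> {id(param): placeholder value}

        def get(self, program, param):
            key = id(program)
            params = self.params_dict.setdefault(key, set())
            mapping = self.tensor2value.setdefault(key, {})
            if id(param) not in mapping:
                mapping[id(param)] = f"param_value_{len(mapping)}"
                params.add(param)
            return mapping[id(param)]

        def pop(self, program):
            key = id(program)
            params = list(self.params_dict.pop(key, []))
            mapping = self.tensor2value.pop(key, {})
            return params, [mapping[id(p)] for p in params]


    rec, prog = ToyParametersRecorder(), object()
    assert rec.get(prog, "linear.weight") == rec.get(prog, "linear.weight")
    print(rec.pop(prog))  # (['linear.weight'], ['param_value_0'])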
+
+import paddle
+from paddle.base import framework
+
+from ..dy2static.program_translator import _program_hash, synchronized
+
+
+class ParametersRecorder:
+    def __init__(self):
+        self.params_dict = {}
+        self.tensor2opresult = {}
+
+    @synchronized
+    def get(self, program, tensor):
+        from paddle.pir.core import create_parameter, vartype_to_datatype
+
+        """use the default_program as key, append tensor the parameter list."""
+        key = _program_hash(program)
+        if key not in self.params_dict:
+            self.params_dict[key] = set()
+            self.tensor2opresult[key] = {}
+
+        params = self.params_dict[key]
+        mappings = self.tensor2opresult[key]
+        if id(tensor) not in mappings:
+            non_used_initializer = paddle.nn.initializer.Constant(0.0)
+            op_result = create_parameter(
+                dtype=vartype_to_datatype[tensor.dtype],
+                shape=tensor.shape,
+                type=tensor.type,
+                initializer=non_used_initializer,
+            )
+            if isinstance(tensor, framework.EagerParamBase):
+                params.add(tensor)
+            mappings[id(tensor)] = op_result
+        return mappings[id(tensor)]
+
+    def pop(self, program):
+        hash_id = _program_hash(program)
+        params = self.params_dict.get(hash_id)
+        if params is None:
+            return [], []
+        params_values = [
+            self.tensor2opresult[hash_id][id(x)] for x in list(params)
+        ]
+        del self.params_dict[hash_id]
+        del self.tensor2opresult[hash_id]
+        return list(params), list(params_values)
+
+
+_global_parameter_recorder = ParametersRecorder()
diff --git a/python/paddle/pir/__init__.py b/python/paddle/pir/__init__.py
index 07588983d64e4e..39b8c71ca5a2f4 100644
--- a/python/paddle/pir/__init__.py
+++ b/python/paddle/pir/__init__.py
@@ -19,6 +19,8 @@
     Value,
     OpOperand,
     OpResult,
+    fake_op_result,
+    is_fake_op_result,
     Type,
 )  # noqa: F401
 from paddle.base.libpaddle.pir import (
diff --git a/python/paddle/pir/core.py b/python/paddle/pir/core.py
index 51a661186cf4f8..2fcf73cd10fa8f 100644
--- a/python/paddle/pir/core.py
+++ b/python/paddle/pir/core.py
@@ -290,3 +290,33 @@ def create_parameter(
     param.is_persistable = True
 
     return param
+
+
+def _convert_into_opresult(tensor):
+    """
+    Convert Tensor into OpResult.
+    """
+    import paddle
+    from paddle.base import core, framework
+    from paddle.jit.newir_dy2static.parameter_recorder import (
+        _global_parameter_recorder,
+    )
+
+    if isinstance(tensor, core.eager.Tensor):
+        # Check whether has been created before.
+        new_var = tensor.block._find_var_recursive(tensor.name)
+        is_persistable = True
+        if new_var is not None:
+            assert isinstance(new_var, framework.Variable)
+        elif isinstance(tensor, framework.EagerParamBase):
+            # Convert EagerParamBase into Parameter with same attributes in dy2stat.
+            new_var = _global_parameter_recorder.get(
+                paddle.pir.core.default_main_program(), tensor
+            )
+        else:
+            # TODO(xiongkun): add this logic, we should call paddle.data() to create a non-parameter variable.
+            raise NotImplementedError("Not implemented, for buffers.")
+        # add param into parameter recorder to collect all the params used in this program.
+ return new_var + else: + return tensor diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 443a7a8c051c9c..aa1f1898b0186c 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1984,8 +1984,10 @@ def split(x, num_or_sections, axis=0, name=None): dim = (len(input.shape) + dim) if dim < 0 else dim if isinstance(num_or_sections, int): + dim = dim if dim >= 0 else dim + len(input.shape) return _C_ops.split_with_num(input, num_or_sections, dim) else: + dim = dim if dim >= 0 else dim + len(input.shape) return _C_ops.split(input, num_or_sections, dim) else: diff --git a/setup.py b/setup.py index a507f3271a9d6c..221e0a0770e062 100644 --- a/setup.py +++ b/setup.py @@ -1424,6 +1424,7 @@ def get_setup_parameters(): 'paddle.framework', 'paddle.jit', 'paddle.jit.dy2static', + 'paddle.jit.newir_dy2static', 'paddle.inference', 'paddle.inference.contrib', 'paddle.inference.contrib.utils', diff --git a/test/ir/new_ir/test_new_ir_to_static.py b/test/ir/new_ir/test_new_ir_to_static.py index aadffa2cd08076..5516b3bca04c1a 100644 --- a/test/ir/new_ir/test_new_ir_to_static.py +++ b/test/ir/new_ir/test_new_ir_to_static.py @@ -64,6 +64,27 @@ def func(x): ) +class TestDy2staticNewIR2(unittest.TestCase): + def test_basic_layer(self): + class SimpleNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linear = paddle.nn.Linear(10, 10) + + def forward(self, x): + return self.linear(x) + + net = SimpleNet() + x = paddle.randn((10, 10)) + x.stop_gradient = False + ans = net(x) + net = paddle.jit.to_static(net) + out = net(x) + np.testing.assert_allclose( + out.numpy(), ans.numpy(), rtol=1e-05, atol=1e-8 + ) + + class TestDy2staticNewIR3(unittest.TestCase): def test_complex_layer(self): def output_pure_func(x, y): @@ -73,11 +94,7 @@ def output_pure_func(x, y): return paddle.add(outx, outy), outy def run_function(to_static=True): - import paddle - - # 设置随机种子 paddle.seed(2023) - # 生成随机数 x = paddle.randn((10, 10)) y = paddle.randn((10, 10)) x.stop_gradient = False @@ -96,5 +113,87 @@ def run_function(to_static=True): ) +class TestLossFor10Steps(unittest.TestCase): + def test_loss_for_10_steps(self): + # Dy2static RunProgramOp support nn.Layer's forward and backward training. + class SimpleNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linear = paddle.nn.Linear(10, 10) + + def forward(self, x): + return self.linear(x) + + def train_step(to_static=True): + paddle.seed(2023) + x = paddle.randn((10, 10), dtype='float32') + y = paddle.randn((10, 10), dtype='float32') + loss_fn = paddle.nn.loss.MSELoss() + net = SimpleNet() + optimizer = paddle.optimizer.SGD( + learning_rate=0.1, parameters=net.parameters() + ) + if to_static: + net = paddle.jit.to_static(net) + losses = [] + for step in range(100): + y_pred = net(x) + loss = loss_fn(y_pred, y) + loss.backward() + optimizer.step() + optimizer.clear_grad() + losses.append(loss.numpy()) + return losses + + expected_losses = train_step(True) + losses = train_step(False) + np.testing.assert_allclose( + losses, expected_losses, rtol=1e-05, atol=1e-8 + ) + + +class TestDy2staticNewIR5(unittest.TestCase): + def test_run(self): + # Dy2static RunProgramOp support nn.Layer's forward and backward training. 
+        class SimpleNet(paddle.nn.Layer):
+            def __init__(self):
+                super().__init__()
+                self.linear = paddle.nn.Linear(10, 10)
+
+            def forward(self, x, y):
+                if y is True:
+                    return self.linear(x)
+                else:
+                    m = self.linear(x)
+                    return m * m
+
+        def train_step(to_static=True):
+            paddle.seed(2023)
+            x = paddle.randn((10, 10), dtype='float32')
+            y = paddle.randn((10, 10), dtype='float32')
+            loss_fn = paddle.nn.loss.MSELoss()
+            net = SimpleNet()
+            optimizer = paddle.optimizer.SGD(
+                learning_rate=0.1, parameters=net.parameters()
+            )
+            if to_static:
+                net = paddle.jit.to_static(net)
+            losses = []
+            for step in range(100):
+                y_pred = net(x, step % 2 == 1)
+                loss = loss_fn(y_pred, y)
+                loss.backward()
+                optimizer.step()
+                optimizer.clear_grad()
+                losses.append(loss.numpy())
+            return losses
+
+        expected_losses = train_step(True)
+        losses = train_step(False)
+        np.testing.assert_allclose(
+            losses, expected_losses, rtol=1e-05, atol=1e-8
+        )
+
+
 if __name__ == "__main__":
     unittest.main()