[onert] Introduce UseDefGenerator (#13374)
This commit introduces UseDefGenerator, which generates a UseDefChain for each operation from a TrainableGraph.
- Introduce UseDefGenerator
- Add a visit method for the backwarding Loss op
- Add helper methods
  - insertUse : insert an operation index as a use into the UseDefChain for an operand
  - insertDef : insert an operation index as a def into the UseDefChain for an operand
  - insertBackPropDef : insert a def like insertDef does, but only if the operand is not a constant
  - initForForwardingNodes : insert UseDefChains for all forwarding operation nodes
  - initForBackwardingNodes : insert UseDefChains for all backwarding operation nodes

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
Showing 2 changed files with 274 additions and 0 deletions.
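As a rough usage sketch, a pass that holds a TrainableGraph could generate the chains as below (the wrapper function buildTrainingUseDefs is hypothetical; only UseDefGenerator's constructor and call operator come from this commit):

#include "UseDefGenerator.h"

namespace
{

// Hypothetical caller; UseDefGenerator seeds empty UseDefChains for every
// operand and then fills them for forwarding and backwarding nodes.
onert::ir::train::UseDefChains
buildTrainingUseDefs(const onert::ir::train::TrainableGraph &tgraph)
{
  onert::ir::train::UseDefGenerator generate_usedefs{tgraph};
  return generate_usedefs(); // UseDefChains keyed by TrainingOperandIndex
}

} // namespace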
UseDefGenerator.cc (new file, +187 lines)
/*
 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "UseDefGenerator.h"

#include "ir/train/TrainableGraph.h"
#include "ir/train/Index.h"
#include "../verifier/Verifier.h"

#include <algorithm> // for std::any_of used in the constructor
#include <cassert>
#include <memory>

// TODO Reduce duplicate code

namespace onert
{
namespace ir
{
namespace train
{

UseDefGenerator::UseDefGenerator(const TrainableGraph &tgraph)
  : _tgraph{tgraph}, _node_to_idx{}, _training_usedefs{}
{
  const auto order = _tgraph.topolSortOperations();
  for (const auto &index : order)
  {
    const auto &node = _tgraph.operation(index);
    assert(_node_to_idx.find(&node) == _node_to_idx.end());
    _node_to_idx[&node] = index;
  }

  // Check whether loss exists
  assert(std::any_of(order.begin(), order.end(),
                     [&](const auto &index) {
                       return _tgraph.operation(index).opcode() == ir::OpCode::Loss;
                     }) &&
         "Loss does not exist");
}

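// Note: the topological order above is used only to enumerate every operation
// once; _node_to_idx lets the visit methods recover an operation's
// OperationIndex from the node reference they receive.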
UseDefChains UseDefGenerator::operator()()
{
  const auto &graph = _tgraph.graph();
  assert(ir::verifier::EdgeChecker().verify(graph));

  _training_usedefs.clear();
  graph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) {
    // Initialize as empty UseDefChain
    const auto empty_usedef_chain = UseDefChain{operand};
    _training_usedefs.emplace(TrainingOperandIndex{idx, true}, empty_usedef_chain);
    _training_usedefs.emplace(TrainingOperandIndex{idx, false}, empty_usedef_chain);
  });

  initForForwardingNodes();

  initForBackwardingNodes();

  return _training_usedefs;
}

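// Every operand is tracked twice in _training_usedefs: as
// TrainingOperandIndex{idx, true} for the forward pass and as
// TrainingOperandIndex{idx, false} for its back-propagated counterpart.
// The Loss node below sits on the boundary between the two passes: its
// backwarding form reads the forwarding y_pred/y_true and defines their
// backprop operands.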
void UseDefGenerator::visit(const train::operation::Loss &node)
{
  assert(_node_to_idx.find(&node) != _node_to_idx.end());
  const auto &op_index = _node_to_idx.at(&node);
  const auto backwarding_op_index = TrainingOperationIndex{op_index, false};

  for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
  {
    // Insert use of forwarding inputs
    const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
    insertUse(in_forwarding_index, backwarding_op_index);
  }

  // Set def of backwarding(backprop) y_pred
  const auto &y_pred_index = node.getInputs().at(train::operation::Loss::Input::Y_PRED);
  assert(!_tgraph.operands().at(y_pred_index).isConstant());
  const auto y_pred_outgoing_index = TrainingOperandIndex{y_pred_index, false};
  insertBackPropDef(y_pred_outgoing_index, backwarding_op_index);

  // Set def of backwarding(backprop) y_true
  const auto &y_true_index = node.getInputs().at(train::operation::Loss::Input::Y_TRUE);
  assert(!_tgraph.operands().at(y_true_index).isConstant());
  const auto y_true_outgoing_index = TrainingOperandIndex{y_true_index, false};
  insertBackPropDef(y_true_outgoing_index, backwarding_op_index);

  // Remove use of backwarding output
  const auto &out_index = node.getOutputs().at(0);
  const auto incoming_index = TrainingOperandIndex{out_index, false};
  auto &usedef_chain = _training_usedefs.at(incoming_index);
  usedef_chain.removeTrainingUse(backwarding_op_index);
}

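// The helpers below assume operator() has already seeded an empty UseDefChain
// for every TrainingOperandIndex; their asserts enforce that precondition.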
void UseDefGenerator::insertUse(const TrainingOperandIndex &operand_index,
                                const TrainingOperationIndex &op_index)
{
  assert(_training_usedefs.find(operand_index) != _training_usedefs.end());
  auto &usedef_chain = _training_usedefs.at(operand_index);
  usedef_chain.insertTrainingUse(op_index);
}

void UseDefGenerator::insertDef(const TrainingOperandIndex &operand_index,
                                const TrainingOperationIndex &op_index)
{
  assert(operand_index.valid());

  assert(_training_usedefs.find(operand_index) != _training_usedefs.end());
  auto &usedef_chain = _training_usedefs.at(operand_index);
  usedef_chain.insertTrainingDef(op_index);
}

void UseDefGenerator::insertBackPropDef(const TrainingOperandIndex &operand_index,
                                        const TrainingOperationIndex &op_index)
{
  // NOTE There is no need to set def of constant backwarding(backprop) inputs
  //      because it won't be back-propagated.
  if (!_tgraph.operands().at(operand_index.index()).isConstant())
  {
    insertDef(operand_index, op_index);
  }
}

void UseDefGenerator::initForForwardingNodes()
{
  // Initialize training def-uses of forwarding operands for only forwarding nodes
  // (i.e. forwarding nodes that do not have any backwarding node)
  _tgraph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) {
    // Append forwarding def-uses as-is
    const bool is_forward = true;
    const auto forwarding_operand_index = TrainingOperandIndex{idx, is_forward};

    const auto def = operand.getDef();
    if (def.valid())
    {
      // insertDef already records the training def in the UseDefChain
      insertDef(forwarding_operand_index, TrainingOperationIndex{def, is_forward});
    }

    assert(_training_usedefs.at(forwarding_operand_index).getTrainingUses().size() == 0);
    const auto uses = operand.getUses();
    for (const auto &use : uses)
      insertUse(forwarding_operand_index, TrainingOperationIndex{use, is_forward});
  });
}

void UseDefGenerator::initForBackwardingNodes()
{
  const auto backward_order = _tgraph.essentialBackwardOrder();
  // Initialize training uses of forwarding operands and def-uses of backwarding operands
  // for backwarding nodes (i.e. backwarding nodes that do not have any forwarding node)
  for (const auto &op_index : backward_order)
  {
    const auto &node = _tgraph.operation(op_index);

    // Insert use of backwarding operands (output only)
    {
      if (node.getOutputs().size() > 1)
        throw std::runtime_error(
          "UseDefGenerator does not support multiple outputs of training operation");

      const auto &output = node.getOutputs().at(0);
      const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
      const auto incoming_index = TrainingOperandIndex{output, false};
      insertUse(incoming_index, backwarding_op_index);
    }

    // Insert uses of forwarding operands and insert defs of backwarding operands
    node.accept(*this);
  }
}

} // namespace train
} // namespace ir
} // namespace onert
UseDefGenerator.h (new file, +87 lines)
/*
 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __ONERT_IR_TRAIN_USEDEFGENERATOR_H__
#define __ONERT_IR_TRAIN_USEDEFGENERATOR_H__

#include "ir/train/TrainableOperationVisitor.h"

#include "ir/train/UseDefChains.h"
#include "ir/train/Operations.Include.h"

#include <stdexcept>     // for std::runtime_error thrown by the NYI defaults
#include <unordered_map> // for the _node_to_idx member

namespace onert
{
namespace ir
{
namespace train
{
class TrainableGraph;
} // namespace train
} // namespace ir
} // namespace onert

namespace onert
{
namespace ir
{
namespace train
{

struct UseDefGeneratorBase : public TrainableOperationVisitor
{
  virtual ~UseDefGeneratorBase() = default;

protected:
#define OP(InternalName)                                                                \
  virtual void visit(const operation::InternalName &) override                         \
  {                                                                                     \
    throw std::runtime_error("UseDefGenerator: NYI for operation '" #InternalName "'"); \
  }
#include "ir/train/Operations.lst"
#undef OP
};

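// For illustration (hypothetical example; the real operation names come from
// ir/train/Operations.lst): an entry `OP(FullyConnected)` would expand to
//
//   virtual void visit(const operation::FullyConnected &) override
//   {
//     throw std::runtime_error("UseDefGenerator: NYI for operation 'FullyConnected'");
//   }
//
// so every trainable operation gets a loud NYI default, and UseDefGenerator
// only overrides the ops it actually supports, as it does for Loss below.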
class UseDefGenerator : public UseDefGeneratorBase
{
public:
  UseDefGenerator(void) = delete;
  UseDefGenerator(const TrainableGraph &tgraph);

public:
  UseDefChains operator()();

public:
  void visit(const train::operation::Loss &node) override;

private:
  void insertUse(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);
  void insertDef(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);
  void insertBackPropDef(const TrainingOperandIndex &operand_index,
                         const TrainingOperationIndex &op_index);
  void initForForwardingNodes();
  void initForBackwardingNodes();

private:
  const TrainableGraph &_tgraph;
  std::unordered_map<const ITrainableOperation *, OperationIndex> _node_to_idx;
  UseDefChains _training_usedefs;
};

} // namespace train
} // namespace ir
} // namespace onert

#endif // __ONERT_IR_TRAIN_USEDEFGENERATOR_H__