Skip to content

Commit

Permalink
[onert] Introduce UseDefGenerator (#13374)
Browse files Browse the repository at this point in the history
This commit introduces UseDefGenerator that generates UseDefChain for each opeartion from a TrainableGraph.
   - Introduce UseDefGenerator
   - Add a visit method for backwarding Loss op
   - Add helper methods
   - insertUse : insert an opeartion index as use into UseDefChain for an operand
   - insertDef : insert an operation index as def into UseDefChain for an operand
   - insertBackPropDef : insert use similarly insertUse, but only insert it if operand is not a constant
   - initForForwardingNodes : insert UseDefChains for all forwarding operation nodes.
   - initForBackwardNodes : insert UseDefChains for all backwarding operation nodes.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jul 17, 2024
1 parent 715c2b1 commit dbbeceb
Show file tree
Hide file tree
Showing 2 changed files with 274 additions and 0 deletions.
187 changes: 187 additions & 0 deletions runtime/onert/core/src/ir/train/UseDefGenerator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "UseDefGenerator.h"

#include "ir/train/TrainableGraph.h"
#include "ir/train/Index.h"
#include "../verifier/Verifier.h"

#include <cassert>
#include <memory>

// TODO Reduce duplicate code

namespace onert
{
namespace ir
{
namespace train
{

UseDefGenerator::UseDefGenerator(const TrainableGraph &tgraph)
: _tgraph{tgraph}, _node_to_idx{}, _training_usedefs{}
{
const auto order = _tgraph.topolSortOperations();
for (const auto &index : order)
{
const auto &node = _tgraph.operation(index);
assert(_node_to_idx.find(&node) == _node_to_idx.end());
_node_to_idx[&node] = index;
}

// Check whether loss exists
assert(std::any_of(order.begin(), order.end(),
[&](const auto &index) {
return _tgraph.operation(index).opcode() == ir::OpCode::Loss;
}) &&
"Loss does not exist");
}

UseDefChains UseDefGenerator::operator()()
{
const auto &graph = _tgraph.graph();
assert(ir::verifier::EdgeChecker().verify(graph));

_training_usedefs.clear();
graph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) {
// Initialize as emtpy UseDefChain
const auto empty_usedef_chain = UseDefChain{operand};
_training_usedefs.emplace(TrainingOperandIndex{idx, true}, empty_usedef_chain);
_training_usedefs.emplace(TrainingOperandIndex{idx, false}, empty_usedef_chain);
});

initForForwardingNodes();

initForBackwardingNodes();

return _training_usedefs;
}

void UseDefGenerator::visit(const train::operation::Loss &node)
{
assert(_node_to_idx.find(&node) != _node_to_idx.end());
const auto &op_index = _node_to_idx.at(&node);
const auto backwarding_op_index = TrainingOperationIndex{op_index, false};

for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
{
// Insert use of forwarding inputs
const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
insertUse(in_forwarding_index, backwarding_op_index);
}

// Set def of backwarding(backprop) y_pred
const auto &y_pred_index = node.getInputs().at(train::operation::Loss::Input::Y_PRED);
assert(!_tgraph.operands().at(y_pred_index).isConstant());
const auto y_pred_outgoing_index = TrainingOperandIndex{y_pred_index, false};
insertBackPropDef(y_pred_outgoing_index, backwarding_op_index);

// Set def of backwarding(backprop) y_true
const auto &y_true_index = node.getInputs().at(train::operation::Loss::Input::Y_TRUE);
assert(!_tgraph.operands().at(y_true_index).isConstant());
const auto y_true_outgoing_index = TrainingOperandIndex{y_true_index, false};
insertBackPropDef(y_true_outgoing_index, backwarding_op_index);

// Remove use of backwarding output
const auto &out_index = node.getOutputs().at(0);
const auto incoming_index = TrainingOperandIndex{out_index, false};
auto &usedef_chain = _training_usedefs.at(incoming_index);
usedef_chain.removeTrainingUse(backwarding_op_index);
}

void UseDefGenerator::insertUse(const TrainingOperandIndex &operand_index,
const TrainingOperationIndex &op_index)
{
assert(_training_usedefs.find(operand_index) != _training_usedefs.end());
auto &usedef_chain = _training_usedefs.at(operand_index);
usedef_chain.insertTrainingUse(op_index);
}

void UseDefGenerator::insertDef(const TrainingOperandIndex &operand_index,
const TrainingOperationIndex &op_index)
{
assert(operand_index.valid());

assert(_training_usedefs.find(operand_index) != _training_usedefs.end());
auto &usedef_chain = _training_usedefs.at(operand_index);
usedef_chain.insertTrainingDef(op_index);
}

void UseDefGenerator::insertBackPropDef(const TrainingOperandIndex &operand_index,
const TrainingOperationIndex &op_index)
{
// NOTE There is no need to set def of constant backwarding(backprop) inputs
// because it won't be back-propagated.
if (!_tgraph.operands().at(operand_index.index()).isConstant())
{
insertDef(operand_index, op_index);
}
}

void UseDefGenerator::initForForwardingNodes()
{
// Initialize training def-uses of forwarding operands for only forwarding nodes
// (i.e. forwarding nodes that do not have any backwarding node)
_tgraph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) {
// Append forwarding def-uses as it is
const bool is_forward = true;
const auto forwarding_operand_index = TrainingOperandIndex{idx, is_forward};

const auto def = operand.getDef();
if (def.valid())
{
insertDef(forwarding_operand_index, TrainingOperationIndex{def, is_forward});
auto &usedef_chain = _training_usedefs.at(forwarding_operand_index);
usedef_chain.insertTrainingDef(TrainingOperationIndex{def, is_forward});
}

assert(_training_usedefs.at(forwarding_operand_index).getTrainingUses().size() == 0);
const auto uses = operand.getUses();
for (const auto &use : uses)
insertUse(forwarding_operand_index, TrainingOperationIndex{use, is_forward});
});
}

void UseDefGenerator::initForBackwardingNodes()
{
const auto backward_order = _tgraph.essentialBackwardOrder();
// Initialize training uses of forwarding operands and def-uses of backwarding operands for
// backwarding nodes (i.e. backwarding nodes that do not have any forwarding node)
for (const auto &op_index : backward_order)
{
const auto &node = _tgraph.operation(op_index);

// Insert use of backwarding operands(only output)
{
if (node.getOutputs().size() > 1)
throw std::runtime_error(
"UseDefGenerator does not support multiple outputs of training operation");

const auto &output = node.getOutputs().at(0);
const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
const auto incoming_index = TrainingOperandIndex{output, false};
insertUse(incoming_index, backwarding_op_index);
}

// Insert uses of forwarding operands and insert defs of backwarding operands
node.accept(*this);
}
}

} // namespace train
} // namespace ir
} // namespace onert
87 changes: 87 additions & 0 deletions runtime/onert/core/src/ir/train/UseDefGenerator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_IR_TRAIN_USEDEFINITIALIZER_H__
#define __ONERT_IR_TRAIN_USEDEFINITIALIZER_H__

#include "ir/train/TrainableOperationVisitor.h"

#include "ir/train/UseDefChains.h"
#include "ir/train/Operations.Include.h"

namespace onert
{
namespace ir
{
namespace train
{
class TrainableGraph;
} // namespace train
} // namespace ir
} // namespace onert

namespace onert
{
namespace ir
{
namespace train
{

struct UseDefGeneratorBase : public TrainableOperationVisitor
{
virtual ~UseDefGeneratorBase() = default;

protected:
#define OP(InternalName) \
virtual void visit(const operation::InternalName &) override \
{ \
throw std::runtime_error("UseDefGenerator: NYI for operation '" #InternalName "'"); \
}
#include "ir/train/Operations.lst"
#undef OP
};

class UseDefGenerator : public UseDefGeneratorBase
{
public:
UseDefGenerator(void) = delete;
UseDefGenerator(const TrainableGraph &tgraph);

public:
UseDefChains operator()();

public:
void visit(const train::operation::Loss &node) override;

private:
void insertUse(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);
void insertDef(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);
void insertBackPropDef(const TrainingOperandIndex &operand_index,
const TrainingOperationIndex &op_index);
void initForForwardingNodes();
void initForBackwardingNodes();

private:
const TrainableGraph &_tgraph;
std::unordered_map<const ITrainableOperation *, OperationIndex> _node_to_idx;
UseDefChains _training_usedefs;
};

} // namespace train
} // namespace ir
} // namespace onert

#endif // __ONERT_IR_TRAIN_USEDEFINITIALIZER_H__

0 comments on commit dbbeceb

Please sign in to comment.