Skip to content

Commit

Permalink
[luci/pass] Trim RemoveUnnecessaryTransposesNetPass (#14187)
Browse files Browse the repository at this point in the history
This PR trims RemoveUnnecessaryTransposesNetPass.
- enhance some comements
- rename a varaible
- replace 'if(not cond) return false' with macro

ONE-DCO-1.0-Signed-off-by: seunghui youn <[email protected]>
  • Loading branch information
zetwhite authored Oct 10, 2024
1 parent 8f5bcfb commit 44d8dde
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ void TaggedShapeAnalyzer::analyze_transpose(const std::vector<int32_t> &perm)
*
* @return False, if it fails to update _shape
*
* @note It only support analyzing reshape that combines N consecutive dims into one dims.
*
* @example Let's assume new_shape={1, 448, 49} is given to [before] _shape.
*
* [before] _shape :
Expand Down Expand Up @@ -233,16 +235,16 @@ bool TaggedShapeAnalyzer::analyze_reshape(const std::vector<int32_t> &new_shape)
// Create 'new_tagged_shape' based on 'new_shape'
TagShape new_tagged_shape;

uint32_t new_shape_idx = 0;
while (new_shape_idx < new_shape.size())
uint32_t target_idx = 0;
while (target_idx < new_shape.size())
{
auto target_dim = new_shape[new_shape_idx];
auto target_dim = new_shape[target_idx];

// Ignore dim == 1
if (target_dim == 1)
{
new_tagged_shape.emplace_back(1);
new_shape_idx++;
target_idx++;
continue;
}

Expand All @@ -265,7 +267,7 @@ bool TaggedShapeAnalyzer::analyze_reshape(const std::vector<int32_t> &new_shape)
}
new_tagged_shape.push_back(dim);

new_shape_idx++;
target_idx++;
move_to_next_range();
}
_shape = new_tagged_shape;
Expand Down Expand Up @@ -293,8 +295,8 @@ bool TaggedShapeAnalyzer::verify_tag() const
* @brief Initialize the class members and check under conditions
*
* Condtiions that have to be met for analyzer
* c2: The 'perm' of tranpose should be a CircleConst* type
* c3: All extracted shapes (named as '*_shape_v' in member variable) should be known
* c1: The 'perm' of tranpose should be a CircleConst* type
* c2: All extracted shapes (named as '*_shape_v' in member variable) should be known
*
* @return True, if all conditions are satisfied and class members are initialized successfully
* False, otherwise
Expand All @@ -311,7 +313,7 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose,
const auto front_perm = dynamic_cast<luci::CircleConst *>(_front_transpose->perm());
const auto back_perm = dynamic_cast<luci::CircleConst *>(_back_transpose->perm());

// check c2
// check c1
CHECK_OR_FALSE(front_perm != nullptr);
CHECK_OR_FALSE(back_perm != nullptr);

Expand All @@ -329,7 +331,7 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose,
return true;
};

// check c3
// check c2
CHECK_OR_FALSE(all_known(_in_shape_v));
CHECK_OR_FALSE(all_known(_front_shape_v));
CHECK_OR_FALSE(all_known(_mid_shape_v));
Expand Down Expand Up @@ -467,24 +469,16 @@ bool remove_unnecessary_transpose(luci::CircleTranspose *node)
{
// find 'front_transpose - mid_reshape - back_transpose' pattern
const auto back_transpose = node;

const auto mid_reshape = dynamic_cast<luci::CircleReshape *>(back_transpose->a());
{
if (mid_reshape == nullptr)
return false;
}
CHECK_OR_FALSE(mid_reshape != nullptr);

const auto front_transpose = dynamic_cast<luci::CircleTranspose *>(mid_reshape->tensor());
{
if (not front_transpose)
return false;
}
CHECK_OR_FALSE(front_transpose != nullptr);

TaggedShapeAnalyzer analyzer;

if (not analyzer.init(front_transpose, mid_reshape, back_transpose))
return false;

if (not analyzer.can_remove_transposes())
return false;
CHECK_OR_FALSE(analyzer.init(front_transpose, mid_reshape, back_transpose));
CHECK_OR_FALSE(analyzer.can_remove_transposes());

// repalce with new_node
luci::CircleReshape *new_node = create_reshape_node<loco::DataType::S32>(
Expand Down

0 comments on commit 44d8dde

Please sign in to comment.