From 44d8dde71e003c69f9921fba493c9a6949c4ce3d Mon Sep 17 00:00:00 2001 From: SeungHui Youn <61981457+zetwhite@users.noreply.github.com> Date: Thu, 10 Oct 2024 17:27:55 +0900 Subject: [PATCH] [luci/pass] Trim RemoveUnnecessaryTransposesNetPass (#14187) This PR trims RemoveUnnecessaryTransposesNetPass. - enhance some comements - rename a varaible - replace 'if(not cond) return false' with macro ONE-DCO-1.0-Signed-off-by: seunghui youn --- .../src/RemoveUnnecessaryTransposeNetPass.cpp | 40 ++++++++----------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp index eafb9f3f1aa..5166bdde78b 100644 --- a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp +++ b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp @@ -191,6 +191,8 @@ void TaggedShapeAnalyzer::analyze_transpose(const std::vector &perm) * * @return False, if it fails to update _shape * + * @note It only support analyzing reshape that combines N consecutive dims into one dims. + * * @example Let's assume new_shape={1, 448, 49} is given to [before] _shape. * * [before] _shape : @@ -233,16 +235,16 @@ bool TaggedShapeAnalyzer::analyze_reshape(const std::vector &new_shape) // Create 'new_tagged_shape' based on 'new_shape' TagShape new_tagged_shape; - uint32_t new_shape_idx = 0; - while (new_shape_idx < new_shape.size()) + uint32_t target_idx = 0; + while (target_idx < new_shape.size()) { - auto target_dim = new_shape[new_shape_idx]; + auto target_dim = new_shape[target_idx]; // Ignore dim == 1 if (target_dim == 1) { new_tagged_shape.emplace_back(1); - new_shape_idx++; + target_idx++; continue; } @@ -265,7 +267,7 @@ bool TaggedShapeAnalyzer::analyze_reshape(const std::vector &new_shape) } new_tagged_shape.push_back(dim); - new_shape_idx++; + target_idx++; move_to_next_range(); } _shape = new_tagged_shape; @@ -293,8 +295,8 @@ bool TaggedShapeAnalyzer::verify_tag() const * @brief Initialize the class members and check under conditions * * Condtiions that have to be met for analyzer - * c2: The 'perm' of tranpose should be a CircleConst* type - * c3: All extracted shapes (named as '*_shape_v' in member variable) should be known + * c1: The 'perm' of tranpose should be a CircleConst* type + * c2: All extracted shapes (named as '*_shape_v' in member variable) should be known * * @return True, if all conditions are satisfied and class members are initialized successfully * False, otherwise @@ -311,7 +313,7 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, const auto front_perm = dynamic_cast(_front_transpose->perm()); const auto back_perm = dynamic_cast(_back_transpose->perm()); - // check c2 + // check c1 CHECK_OR_FALSE(front_perm != nullptr); CHECK_OR_FALSE(back_perm != nullptr); @@ -329,7 +331,7 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, return true; }; - // check c3 + // check c2 CHECK_OR_FALSE(all_known(_in_shape_v)); CHECK_OR_FALSE(all_known(_front_shape_v)); CHECK_OR_FALSE(all_known(_mid_shape_v)); @@ -467,24 +469,16 @@ bool remove_unnecessary_transpose(luci::CircleTranspose *node) { // find 'front_transpose - mid_reshape - back_transpose' pattern const auto back_transpose = node; + const auto mid_reshape = dynamic_cast(back_transpose->a()); - { - if (mid_reshape == nullptr) - return false; - } + CHECK_OR_FALSE(mid_reshape != nullptr); + const auto front_transpose = dynamic_cast(mid_reshape->tensor()); - { - if (not front_transpose) - return false; - } + CHECK_OR_FALSE(front_transpose != nullptr); TaggedShapeAnalyzer analyzer; - - if (not analyzer.init(front_transpose, mid_reshape, back_transpose)) - return false; - - if (not analyzer.can_remove_transposes()) - return false; + CHECK_OR_FALSE(analyzer.init(front_transpose, mid_reshape, back_transpose)); + CHECK_OR_FALSE(analyzer.can_remove_transposes()); // repalce with new_node luci::CircleReshape *new_node = create_reshape_node(