From 157fd42b63c8c51fb235c948b082016114cb68a9 Mon Sep 17 00:00:00 2001 From: SaeHie Park Date: Wed, 2 Oct 2024 04:03:45 +0000 Subject: [PATCH] [luci/svc] Handle 0 in shape for Reshape Op shape inference This will fix to handle 0 in shape for Reshape Op shape inference. ONE-DCO-1.0-Signed-off-by: SaeHie Park --- .../luci/service/src/Nodes/CircleReshape.cpp | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp index 28eb6303735..553e1eabd5d 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp @@ -22,6 +22,8 @@ #include +#include + namespace { @@ -88,11 +90,29 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) { - shape_by_input.dim(axis) = const_shape_node->at(axis); if (const_shape_node->at(axis) < 0) { shape_by_input.dim(axis).unset(); } + else if (const_shape_node->at(axis) == 0) + { + const auto node_tensor = loco::must_cast(node->tensor()); + // set dim value to input + if (node_tensor->shape_status() == luci::ShapeStatus::VALID && axis < node_tensor->rank()) + shape_by_input.dim(axis) = node_tensor->dim(axis); + else + { + // stop to check if this case exist for debugging + INTERNAL_EXN("Check Reshape shape with 0"); + } + } + else + { + shape_by_input.dim(axis).set(const_shape_node->at(axis)); + } + // check valid or stop for debugging + LUCI_ASSERT(shape_by_input.dim(axis).value() > 0 || !shape_by_input.dim(axis).known(), + "Reshape infer shape is invalid."); } } else @@ -143,7 +163,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) { for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) { - const uint32_t dim_value = output_shape.dim(dim_index).value(); + uint32_t dim_value = output_shape.dim(dim_index).value(); if (not output_shape.dim(dim_index).known()) { LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); @@ -151,6 +171,18 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) } else { + if (!dim_value) + { + // refer https://github.com/Samsung/ONE/issues/14074#issuecomment-2370795003 + // set dim value to follow input + if (dim_index < input_shape.rank()) + dim_value = input_shape.dim(dim_index).value(); + else + { + // stop to check if this case exist for debugging + INTERNAL_EXN("Check Reshape shape with 0"); + } + } output_element_count *= dim_value; } }