From 2b2395aabebb2bd81884c1b6ac2ef190e5c18ae1 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Thu, 23 Jan 2025 21:05:20 +0900 Subject: [PATCH] [luci/pass] Add unittest to ExpandBroadcastConstPass (#14581) Let's add more unittest to ExpandBroadcastConstPass with axis=1 (non-lastdim). ONE-DCO-Signed-off-by: Dayoung Lee --- .../src/ExpandBroadcastConstPass.test.cpp | 111 +++++++++++++++++- 1 file changed, 110 insertions(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp b/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp index 1eb985f44e8..a33604c01ed 100644 --- a/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp +++ b/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp @@ -16,6 +16,7 @@ #include "luci/Pass/ExpandBroadcastConstPass.h" #include "PassTestGraphs.h" +#include "helpers/ArrayIndex.h" #include @@ -114,6 +115,24 @@ TEST_F(ExpandBroadcastRank2ConstTest, remove_broadcast) } } +TEST_F(ExpandBroadcastRank2ConstTest, broadcast_impossible_NEG) +{ + _y->shape({N, D + 1}); + _y->size(N * (D + 1)); + + luci::ExpandBroadcastConstPass pass; + ASSERT_FALSE(pass.run(&_g)); +} + +TEST_F(ExpandBroadcastRank2ConstTest, broadcast_diff_rank_NEG) +{ + _y->shape({N}); + _y->size(N); + + luci::ExpandBroadcastConstPass pass; + ASSERT_FALSE(pass.run(&_g)); +} + /**************************************************************************** * TESTS FOR RANK 4 ****************************************************************************/ @@ -181,7 +200,17 @@ class ExpandBroadcastRank4ConstTest1 : public ExpandBroadcastConstRank4Graph, pu } }; -// TODO: Add more tests for Rank4 with different broadcasting dimensions +class ExpandBroadcastRank4ConstTest2 : public ExpandBroadcastConstRank4Graph, public ::testing::Test +{ +public: + ExpandBroadcastRank4ConstTest2() + { + _y->dtype(loco::DataType::FLOAT32); + _y->shape({N, 1, W, D}); + _y->size(N * 1 * W * D); + } +}; + } // namespace TEST_F(ExpandBroadcastRank4ConstTest1, name) @@ -252,3 +281,83 @@ TEST_F(ExpandBroadcastRank4ConstTest1, broadcast_impossible_NEG) luci::ExpandBroadcastConstPass pass; ASSERT_FALSE(pass.run(&_g)); } + +TEST_F(ExpandBroadcastRank4ConstTest1, broadcast_diff_rank_NEG) +{ + _y->shape({N, H, W}); + _y->size(N * H * W); + + luci::ExpandBroadcastConstPass pass; + ASSERT_FALSE(pass.run(&_g)); +} + +TEST_F(ExpandBroadcastRank4ConstTest2, remove_broadcast) +{ + for (uint32_t i = 0; i < N * W * D; ++i) + _y->at(i) = static_cast(i); + + luci::ExpandBroadcastConstPass pass; + ASSERT_TRUE(pass.run(&_g)); + + auto broadcasted_const = dynamic_cast(_add->y()); + ASSERT_NE(broadcasted_const, nullptr); + + EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32); + EXPECT_EQ(broadcasted_const->dim(0).value(), N); + EXPECT_EQ(broadcasted_const->dim(1).value(), H); + EXPECT_EQ(broadcasted_const->dim(2).value(), W); + EXPECT_EQ(broadcasted_const->dim(3).value(), D); + EXPECT_EQ(broadcasted_const->size(), N * H * W * D); + + auto const idx = luci::Array4DIndex(N, H, W, D); + + for (uint32_t n = 0; n < N; ++n) + for (uint32_t h = 0; h < H; ++h) + for (uint32_t w = 0; w < W; ++w) + for (uint32_t d = 0; d < D; ++d) + EXPECT_NEAR(broadcasted_const->at(idx(n, h, w, d)), + static_cast(n * W * D + w * D + d), std::numeric_limits::min()); +} + +TEST_F(ExpandBroadcastRank4ConstTest2, remove_broadcast_multiple_successors) +{ + auto const circle_sqrt = _g.nodes()->create(); + circle_sqrt->dtype(loco::DataType::FLOAT32); + circle_sqrt->shape({N, 1, W, D}); + circle_sqrt->x(_y); + + luci::ExpandBroadcastConstPass pass; + ASSERT_TRUE(pass.run(&_g)); + + auto broadcasted_const = dynamic_cast(_add->y()); + auto original_const = dynamic_cast(circle_sqrt->x()); + + ASSERT_NE(broadcasted_const, nullptr); + EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32); + EXPECT_EQ(broadcasted_const->dim(1).value(), H); + EXPECT_EQ(broadcasted_const->size(), N * H * W * D); + + // Check if another successor's node was left intact + ASSERT_NE(original_const, nullptr); + EXPECT_EQ(original_const->dtype(), loco::DataType::FLOAT32); + EXPECT_EQ(original_const->dim(1).value(), 1); + EXPECT_EQ(original_const->size(), N * 1 * W * D); +} + +TEST_F(ExpandBroadcastRank4ConstTest2, broadcast_impossible_NEG) +{ + _y->shape({N, H, W + 1, D + 1}); + _y->size(N * H * (W + 1) * (D + 1)); + + luci::ExpandBroadcastConstPass pass; + ASSERT_FALSE(pass.run(&_g)); +} + +TEST_F(ExpandBroadcastRank4ConstTest2, broadcast_diff_rank_NEG) +{ + _y->shape({N, H, W + 4}); + _y->size(N * H * (W + 4)); + + luci::ExpandBroadcastConstPass pass; + ASSERT_FALSE(pass.run(&_g)); +}