From 842747f1677d2ab596de3aa9b1838987a49dbcbf Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Wed, 22 Jan 2025 20:35:52 +0000 Subject: [PATCH] #15889: Fix handling of mantissa rounding to respect ties round to even - Add tt_metal gtest to test convert_u32_to_bfp rounding logic for bfp8 - Remove unused functions in tt_metal/api/tt-metalium/bfloat8.hpp --- tests/tt_metal/tt_metal/api/CMakeLists.txt | 1 + .../tt_metal/api/test_blockfloat_common.cpp | 84 +++++++++++++++++++ tt_metal/api/tt-metalium/bfloat8.hpp | 81 ------------------ .../api/tt-metalium/blockfloat_common.hpp | 21 ++++- 4 files changed, 104 insertions(+), 83 deletions(-) create mode 100644 tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp diff --git a/tests/tt_metal/tt_metal/api/CMakeLists.txt b/tests/tt_metal/tt_metal/api/CMakeLists.txt index 187382022b4..4a15c487bac 100644 --- a/tests/tt_metal/tt_metal/api/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/api/CMakeLists.txt @@ -33,6 +33,7 @@ set(UNIT_TESTS_API_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_soc_descriptor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_tilize_untilize.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_worker_config_buffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_blockfloat_common.cpp ) # Define the function to create test executables for each architecture diff --git a/tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp b/tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp new file mode 100644 index 00000000000..fc50c6c62f8 --- /dev/null +++ b/tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +namespace { + +struct ConvertU32ToBfpParams { + float float_input; + uint32_t expected_mantissa; + float expected_float_output; +}; + +void roundtrip_test_for_mantissa_rounding_with_bfp8( + float float_input, uint8_t expected_mantissa, float expected_float_output) { + auto uint32_input = *reinterpret_cast(&float_input); + // Set shared exponent as original float exponent (ie. skip logic for handling shared exponents) + auto shared_exp = uint32_input >> 23 & 0xFF; + + auto output_mantissa = convert_u32_to_bfp(uint32_input, shared_exp, false); + EXPECT_EQ(output_mantissa, expected_mantissa); + + uint32_t uint32_output = convert_bfp_to_u32(tt::DataFormat::Bfp8_b, output_mantissa, shared_exp, false); + float float_output = *reinterpret_cast(&uint32_output); + EXPECT_EQ(float_output, expected_float_output); +}; + +} // namespace + +class ConvertU32ToBfpTests : public ::testing::TestWithParam {}; + +TEST_P(ConvertU32ToBfpTests, MantissaRoundingWithPositiveFloat) { + const auto& params = GetParam(); + roundtrip_test_for_mantissa_rounding_with_bfp8( + params.float_input, params.expected_mantissa, params.expected_float_output); +} + +TEST_P(ConvertU32ToBfpTests, MantissaRoundingWithNegativeFloat) { + const auto& params = GetParam(); + const auto float_input = -1 * params.float_input; + const auto expected_mantissa = params.expected_mantissa | 0x80; + const auto expected_float_output = -1 * params.expected_float_output; + + roundtrip_test_for_mantissa_rounding_with_bfp8(float_input, expected_mantissa, expected_float_output); +} + +INSTANTIATE_TEST_SUITE_P( + BlockfloatCommonTests, + ConvertU32ToBfpTests, + // clang-format off + // See convert_u32_to_bfp in tt_metal/api/tt-metalium/blockfloat_common.hpp for explanation of rounding + // NOTE: These float values are cherry-picked such that: + // - The mantissa hits the 4 cases for rounding + // - The float values match behaviour of round(float) (assuming same spec of ties 
round to even) + ::testing::Values( + // Round up always + ConvertU32ToBfpParams{ + .float_input = 64.75, // Mantissa is 0x18000 + .expected_mantissa = 0x41, + .expected_float_output = 65, + }, + // Round down always + ConvertU32ToBfpParams{ + .float_input = 65.25, // Mantissa is 0x28000 + .expected_mantissa = 0x41, + .expected_float_output = 65, + }, + // Tie: round down to nearest even + ConvertU32ToBfpParams{ + .float_input = 64.5, // Mantissa is 0x10000 + .expected_mantissa = 0x40, + .expected_float_output = 64, + }, + // Tie: round up to nearest even + ConvertU32ToBfpParams{ + .float_input = 65.5, // Mantissa is 0x30000 + .expected_mantissa = 0x42, + .expected_float_output = 66, + } + ) // Values + // clang-format on +); diff --git a/tt_metal/api/tt-metalium/bfloat8.hpp b/tt_metal/api/tt-metalium/bfloat8.hpp index fb3f288d901..493da496cde 100644 --- a/tt_metal/api/tt-metalium/bfloat8.hpp +++ b/tt_metal/api/tt-metalium/bfloat8.hpp @@ -18,87 +18,6 @@ // TODO: empty struct to facilitate Tensor template logic. 
Reconsider how/why templating is supported in Tensor struct bfloat8_b {}; -template -inline uint8_t convert_u32_to_bfp8(uint32_t input, uint32_t shared_exp, bool is_exp_a) { - // check for both +/- 0.0 - constexpr uint32_t EXP_MANTISSA_BMSK = ((1U << 31) - 1); - bool is_zero = ((input & EXP_MANTISSA_BMSK) == 0); - - if (is_zero) { - return 0; - } - - uint32_t mantissa = input & 0x007fffff; - uint32_t exp = (input & 0x7f800000) >> 23; - uint32_t sign = (input & 0x80000000) >> 31; - - if (is_exp_a) { - int32_t se = static_cast(exp); - // rebias - se = se - 127 + 15; - // check for saturation - if (se > 31) { - se = 31; - mantissa = 0x007fffff; - } else if (se < 0) { - se = 0; - mantissa = 0x0; - } - - exp = static_cast(se); - } - - // float mantissa is 23 bits + hidden bit = 24 bits - // add hidden 1 - mantissa = (1 << 23) | mantissa; - - if (shared_exp >= exp) { - int exp_diff = shared_exp - exp; - // shift mantissa further down by exp diff - // In bit-shift operation (A >> B), the result is undefined if B is greater than or equal to the number of bits - // in A - while (exp_diff > 31) { - mantissa = mantissa >> 31; - exp_diff -= 31; - } - mantissa = mantissa >> exp_diff; - } - - // this needs to become 7 bits so shift 17 times - if (truncate_bfp_mantissa) { - // Truncation: Round down - mantissa = mantissa >> 17; - } else { - // Round mantissa to nearest even - mantissa += 1 << 16; - mantissa = mantissa >> 17; - if (mantissa > 127) { - mantissa = 127; - } - } - - // add sign bit only if result is not 0 - if (0 == mantissa) { - sign = 0; - } - mantissa = (sign << 7) | mantissa; - return mantissa; -} - -inline uint32_t create_packed_bfp8_packed_as_u32( - const std::vector& u32_vec, uint32_t shared_exp, bool is_exp_a) { - int nums_in_dword = 4; - uint32_t tmp_o = 0; - uint32_t mask = (1 << (32 / nums_in_dword)) - 1; - for (int i = nums_in_dword - 1; i >= 0; --i) // [0] in LSBs of dword - { - uint32_t conv_num = convert_u32_to_bfp8(u32_vec[i], shared_exp, is_exp_a); - 
tmp_o = tmp_o << (32 / nums_in_dword); - tmp_o = tmp_o | (conv_num & mask); - } - return tmp_o; -} - inline std::vector pack_fp32_vec_as_bfp8_tiles( tt::stl::Span fp32_vec, bool row_major_input, diff --git a/tt_metal/api/tt-metalium/blockfloat_common.hpp b/tt_metal/api/tt-metalium/blockfloat_common.hpp index c094c1dec96..c43eb88e304 100644 --- a/tt_metal/api/tt-metalium/blockfloat_common.hpp +++ b/tt_metal/api/tt-metalium/blockfloat_common.hpp @@ -147,9 +147,26 @@ inline uint8_t convert_u32_to_bfp(uint32_t input, uint32_t shared_exp, bool is_e // Truncation: Round down mantissa = mantissa >> MANTISSA_BFP_SHIFT; } else { - // Round mantissa to nearest even - mantissa += 1 << (MANTISSA_BFP_SHIFT - 1); + // Round mantissa to nearest; ties round to even + // Implementation of rounding process (example is for bfp8): + // - We want to round 23 bit mantissa to 6 bits with extra hidden bit + // - Mantissa is broken down to: <5> bits | guard bit | <17> bits of round value + // * If round value < 0x10000, round down (ie. mantissa is just <5> bits | guard bit) + // * If round value > 0x10000, round up (ie. add 1 to <5> bits | guard bit) + // * If round value = 0x10000, we have a tie and round to nearest even: + // ** If guard bit = 0, mantissa is even so round down + // ** If guard bit = 1, mantissa is odd so round up + constexpr uint32_t MANTISSA_ROUND_MASK = (1 << MANTISSA_BFP_SHIFT) - 1; + constexpr uint32_t TIE_VALUE = 1 << (MANTISSA_BFP_SHIFT - 1); + uint32_t round_value = mantissa & MANTISSA_ROUND_MASK; mantissa = mantissa >> MANTISSA_BFP_SHIFT; + uint32_t guard_bit = mantissa & 0x1; + + if (round_value > TIE_VALUE or (round_value == TIE_VALUE and guard_bit == 1)) { + // Round up + mantissa += 1; + } + if (mantissa > MANTISSA_BFP_MAX_VAL) { mantissa = MANTISSA_BFP_MAX_VAL; }