Skip to content

Commit

Permalink
#15889: Fix handling of mantissa rounding to respect ties round to even
Browse files Browse the repository at this point in the history
- Add tt_metal gtest to test convert_u32_to_bfp rounding logic for bfp8
- Remove unused functions in tt_metal/api/tt-metalium/bfloat8.hpp
  • Loading branch information
TT-BrianLiu committed Jan 22, 2025
1 parent a5cf197 commit 842747f
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 83 deletions.
1 change: 1 addition & 0 deletions tests/tt_metal/tt_metal/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ set(UNIT_TESTS_API_SRC
${CMAKE_CURRENT_SOURCE_DIR}/test_soc_descriptor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_tilize_untilize.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_worker_config_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_blockfloat_common.cpp
)

# Define the function to create test executables for each architecture
Expand Down
84 changes: 84 additions & 0 deletions tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>
#include <cstring>

#include <gtest/gtest.h>
#include <tt-metalium/blockfloat_common.hpp>

namespace {

// One parameterized case for the bfp8 mantissa-rounding round-trip tests.
struct ConvertU32ToBfpParams {
    // fp32 value fed to the conversion (the tests also run it negated).
    float float_input;
    // Expected bfp8 byte for the positive input. Kept as uint8_t so it has the same
    // type as the helper parameter it is passed to.
    uint8_t expected_mantissa;
    // Expected fp32 value after converting the bfp8 byte back.
    float expected_float_output;
};

// Converts `float_input` to a bfp8 byte and back, checking both steps:
//  - the byte returned by convert_u32_to_bfp must equal `expected_mantissa`
//  - the value returned by convert_bfp_to_u32 must equal `expected_float_output`
// The comparisons are exact (EXPECT_EQ), so callers must pass exactly representable values.
void roundtrip_test_for_mantissa_rounding_with_bfp8(
    float float_input, uint8_t expected_mantissa, float expected_float_output) {
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide to be reinterpreted as uint32_t");

    // Copy the bits with memcpy: dereferencing a reinterpret_cast'ed pointer between
    // float and uint32_t violates strict aliasing.
    uint32_t uint32_input = 0;
    std::memcpy(&uint32_input, &float_input, sizeof(uint32_input));

    // Set shared exponent as original float exponent (ie. skip logic for handling shared exponents)
    const uint32_t shared_exp = (uint32_input >> 23) & 0xFF;

    const auto output_mantissa = convert_u32_to_bfp<tt::DataFormat::Bfp8_b, false>(uint32_input, shared_exp, false);
    EXPECT_EQ(output_mantissa, expected_mantissa);

    const uint32_t uint32_output = convert_bfp_to_u32(tt::DataFormat::Bfp8_b, output_mantissa, shared_exp, false);
    float float_output = 0.0f;
    std::memcpy(&float_output, &uint32_output, sizeof(float_output));
    EXPECT_EQ(float_output, expected_float_output);
}

} // namespace

// Value-parameterized fixture: each test body receives one ConvertU32ToBfpParams case via GetParam().
class ConvertU32ToBfpTests : public ::testing::TestWithParam<ConvertU32ToBfpParams> {};

// Runs the round-trip check on the parameter values exactly as given.
TEST_P(ConvertU32ToBfpTests, MantissaRoundingWithPositiveFloat) {
    const auto& [input, mantissa, output] = GetParam();
    roundtrip_test_for_mantissa_rounding_with_bfp8(input, mantissa, output);
}

// Runs the same round-trip check on the negated input: the expected byte gains the
// sign bit (0x80) and the expected float output is negated.
TEST_P(ConvertU32ToBfpTests, MantissaRoundingWithNegativeFloat) {
    constexpr uint32_t sign_bit = 0x80;
    const auto& [input, mantissa, output] = GetParam();

    roundtrip_test_for_mantissa_rounding_with_bfp8(-input, mantissa | sign_bit, -output);
}

INSTANTIATE_TEST_SUITE_P(
BlockfloatCommonTests,
ConvertU32ToBfpTests,
// clang-format off
// See convert_u32_to_bfp in tt_metal/api/tt-metalium/blockfloat_common.hpp for explanation of rounding
// NOTE: These float values are cherry-picked such that:
// - The low 17 bits of the fp32 mantissa (the "round value" dropped for bfp8) hit the 4 cases
//   for rounding relative to the tie value 0x10000
// - The expected outputs equal the inputs rounded to the nearest integer with ties to even
//   (ie. std::nearbyint in the default rounding mode; std::round rounds ties away from zero instead)
::testing::Values(
// Round value 0x18000 > tie value: round up always
ConvertU32ToBfpParams{
.float_input = 64.75, // Mantissa is 0x18000
.expected_mantissa = 0x41,
.expected_float_output = 65,
},
// Round value 0x08000 < tie value: round down always
ConvertU32ToBfpParams{
.float_input = 65.25, // Mantissa is 0x28000
.expected_mantissa = 0x41,
.expected_float_output = 65,
},
// Tie (round value 0x10000) with even kept mantissa 0x40: round down to nearest even
ConvertU32ToBfpParams{
.float_input = 64.5, // Mantissa is 0x10000
.expected_mantissa = 0x40,
.expected_float_output = 64,
},
// Tie (round value 0x10000) with odd kept mantissa 0x41: round up to nearest even
ConvertU32ToBfpParams{
.float_input = 65.5, // Mantissa is 0x30000
.expected_mantissa = 0x42,
.expected_float_output = 66,
}
) // Values
// clang-format on
);
81 changes: 0 additions & 81 deletions tt_metal/api/tt-metalium/bfloat8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,87 +18,6 @@
// Empty struct (no members) that exists to facilitate Tensor template logic.
// TODO: Reconsider how/why templating is supported in Tensor
struct bfloat8_b {};

// Converts the bit pattern of one fp32 value into a bfp8 byte: sign in bit 7 and a
// 7-bit magnitude (hidden bit + 6 mantissa bits) in bits 6..0, expressed relative to
// `shared_exp`.
//
// Template parameter:
//   truncate_bfp_mantissa - true: drop the low 17 mantissa bits (round toward zero).
//                           false: round to nearest, ties to even, saturating at 127.
// Parameters:
//   input      - raw IEEE-754 binary32 bits of the value to convert.
//   shared_exp - exponent the result is expressed against. When it is larger than the
//                value's own exponent, the mantissa is shifted right by the difference.
//   is_exp_a   - when true, the fp32 exponent is first rebiased from 127 to 15 and
//                saturated to [0, 31] before being compared with `shared_exp`.
// Returns:
//   The packed byte. +/-0.0 inputs and magnitudes that round to 0 return 0 (no sign bit).
template <bool truncate_bfp_mantissa = false>
inline uint8_t convert_u32_to_bfp8(uint32_t input, uint32_t shared_exp, bool is_exp_a) {
    constexpr uint32_t EXP_MANTISSA_BMSK = (1U << 31) - 1;
    constexpr uint32_t HIDDEN_BIT = 1U << 23;
    // 24-bit mantissa (incl. hidden bit) must become 7 bits, so 17 bits are dropped.
    constexpr uint32_t MANTISSA_SHIFT = 17;
    constexpr uint32_t ROUND_MASK = (1U << MANTISSA_SHIFT) - 1;
    constexpr uint32_t TIE_VALUE = 1U << (MANTISSA_SHIFT - 1);
    constexpr uint32_t MANTISSA_MAX = 127;

    // check for both +/- 0.0
    if ((input & EXP_MANTISSA_BMSK) == 0) {
        return 0;
    }

    uint32_t mantissa = input & 0x007fffff;
    uint32_t exp = (input & 0x7f800000) >> 23;
    uint32_t sign = (input & 0x80000000) >> 31;

    if (is_exp_a) {
        // rebias
        int32_t se = static_cast<int32_t>(exp) - 127 + 15;
        // check for saturation
        if (se > 31) {
            se = 31;
            mantissa = 0x007fffff;
        } else if (se < 0) {
            se = 0;
            mantissa = 0x0;
        }
        exp = static_cast<uint32_t>(se);
    }

    // float mantissa is 23 bits + hidden bit = 24 bits
    mantissa |= HIDDEN_BIT;

    if (shared_exp >= exp) {
        // Shift mantissa further down by the exponent difference. A shift count >= 32
        // is undefined behaviour, and the mantissa has only 24 significant bits, so any
        // difference above 31 simply yields 0.
        const uint32_t exp_diff = shared_exp - exp;
        mantissa = (exp_diff > 31) ? 0 : (mantissa >> exp_diff);
    }

    if (truncate_bfp_mantissa) {
        // Truncation: Round down
        mantissa >>= MANTISSA_SHIFT;
    } else {
        // Round to nearest; ties round to even.
        //  - round_value < TIE_VALUE: round down (keep the shifted mantissa)
        //  - round_value > TIE_VALUE: round up
        //  - round_value == TIE_VALUE: round up only if the kept mantissa is odd
        const uint32_t round_value = mantissa & ROUND_MASK;
        mantissa >>= MANTISSA_SHIFT;
        const bool kept_is_odd = (mantissa & 0x1) != 0;

        if (round_value > TIE_VALUE || (round_value == TIE_VALUE && kept_is_odd)) {
            mantissa += 1;
        }
        // Rounding up can carry out of 7 bits; saturate instead of clobbering the sign bit.
        if (mantissa > MANTISSA_MAX) {
            mantissa = MANTISSA_MAX;
        }
    }

    // add sign bit only if result is not 0
    if (mantissa == 0) {
        sign = 0;
    }
    return static_cast<uint8_t>((sign << 7) | mantissa);
}

// Packs four values into one 32-bit word, one byte each: every element of `u32_vec`
// is run through convert_u32_to_bfp8 with the given `shared_exp` / `is_exp_a`, and
// element [0] lands in the least-significant byte. Reads u32_vec[0..3] without a size check.
inline uint32_t create_packed_bfp8_packed_as_u32(
    const std::vector<uint32_t>& u32_vec, uint32_t shared_exp, bool is_exp_a) {
    constexpr int values_per_word = 4;
    constexpr int bits_per_value = 32 / values_per_word;
    constexpr uint32_t value_mask = (1u << bits_per_value) - 1;

    uint32_t packed = 0;
    // Walk from the last element down so that [0] ends up in the LSBs of the word.
    for (int idx = values_per_word; idx-- > 0;) {
        const uint32_t converted = convert_u32_to_bfp8(u32_vec[idx], shared_exp, is_exp_a);
        packed = (packed << bits_per_value) | (converted & value_mask);
    }
    return packed;
}

inline std::vector<uint32_t> pack_fp32_vec_as_bfp8_tiles(
tt::stl::Span<const float> fp32_vec,
bool row_major_input,
Expand Down
21 changes: 19 additions & 2 deletions tt_metal/api/tt-metalium/blockfloat_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,26 @@ inline uint8_t convert_u32_to_bfp(uint32_t input, uint32_t shared_exp, bool is_e
// Truncation: Round down
mantissa = mantissa >> MANTISSA_BFP_SHIFT;
} else {
// Round mantissa to nearest even
mantissa += 1 << (MANTISSA_BFP_SHIFT - 1);
// Round mantissa to nearest; ties round to even
// Implementation of rounding process (example is for bfp8):
// - We want to round 23 bit mantissa to 6 bits with extra hidden bit
// - Mantissa is broken down to: <5> bits | guard bit | <17> bits of round value
// * If round value < 0x10000, round down (ie. mantissa is just <5> bits | guard bit)
// * If round value > 0x10000, round up (ie. add 1 to <5> bits | guard bit)
// * If round value = 0x10000, we have a tie and round to nearest even:
// ** If guard bit = 0, mantissa is even so round down
// ** If guard bit = 1, mantissa is odd so round up
constexpr uint32_t MANTISSA_ROUND_MASK = (1 << MANTISSA_BFP_SHIFT) - 1;
constexpr uint32_t TIE_VALUE = 1 << (MANTISSA_BFP_SHIFT - 1);
uint32_t round_value = mantissa & MANTISSA_ROUND_MASK;
mantissa = mantissa >> MANTISSA_BFP_SHIFT;
uint32_t guard_bit = mantissa & 0x1;

if (round_value > TIE_VALUE or (round_value == TIE_VALUE and guard_bit == 1)) {
// Round up
mantissa += 1;
}

if (mantissa > MANTISSA_BFP_MAX_VAL) {
mantissa = MANTISSA_BFP_MAX_VAL;
}
Expand Down

0 comments on commit 842747f

Please sign in to comment.