Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#15889: Fix handling of mantissa rounding to respect ties round to even #16997

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/tt_metal/tt_metal/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ set(UNIT_TESTS_API_SRC
${CMAKE_CURRENT_SOURCE_DIR}/test_soc_descriptor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_tilize_untilize.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_worker_config_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_blockfloat_common.cpp
)

# Define the function to create test executables for each architecture
Expand Down
84 changes: 84 additions & 0 deletions tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include <tt-metalium/blockfloat_common.hpp>

namespace {

struct ConvertU32ToBfpParams {
float float_input;
uint32_t expected_mantissa;
float expected_float_output;
Comment on lines +11 to +13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please init

};

void roundtrip_test_for_mantissa_rounding_with_bfp8(
float float_input, uint8_t expected_mantissa, float expected_float_output) {
auto uint32_input = *reinterpret_cast<const uint32_t*>(&float_input);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we allow conversion functions to accept/return or float or uint32t?

// Set shared exponent as original float exponent (ie. skip logic for handling shared exponents)
auto shared_exp = uint32_input >> 23 & 0xFF;

auto output_mantissa = convert_u32_to_bfp<tt::DataFormat::Bfp8_b, false>(uint32_input, shared_exp, false);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what false means here, in both cases?

EXPECT_EQ(output_mantissa, expected_mantissa);

uint32_t uint32_output = convert_bfp_to_u32(tt::DataFormat::Bfp8_b, output_mantissa, shared_exp, false);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wonder why one function is templated and the other accepts bfp type as argument

float float_output = *reinterpret_cast<float*>(&uint32_output);
EXPECT_EQ(float_output, expected_float_output);
};

} // namespace

class ConvertU32ToBfpTests : public ::testing::TestWithParam<ConvertU32ToBfpParams> {};

TEST_P(ConvertU32ToBfpTests, MantissaRoundingWithPositiveFloat) {
const auto& params = GetParam();
roundtrip_test_for_mantissa_rounding_with_bfp8(
params.float_input, params.expected_mantissa, params.expected_float_output);
}

TEST_P(ConvertU32ToBfpTests, MantissaRoundingWithNegativeFloat) {
const auto& params = GetParam();
const auto float_input = -1 * params.float_input;
const auto expected_mantissa = params.expected_mantissa | 0x80;
const auto expected_float_output = -1 * params.expected_float_output;

roundtrip_test_for_mantissa_rounding_with_bfp8(float_input, expected_mantissa, expected_float_output);
}

INSTANTIATE_TEST_SUITE_P(
BlockfloatCommonTests,
ConvertU32ToBfpTests,
// clang-format off
// See tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp for explanation of rounding
// NOTE: These float values are cherry-picked such that:
// - The mantissa hits the 4 cases for rounding
// - The float values match behaviour of round(float) (assuming same spec of ties round to even)
::testing::Values(
// Round up always
ConvertU32ToBfpParams{
.float_input = 64.75, // Mantissa is 0x18000
.expected_mantissa = 0x41,
.expected_float_output = 65,
},
// Round down always
ConvertU32ToBfpParams{
.float_input = 65.25, // Mantissa is 0x28000
.expected_mantissa = 0x41,
.expected_float_output = 65,
},
// Tie: round down to nearest even
ConvertU32ToBfpParams{
.float_input = 64.5, // Mantissa is 0x10000
.expected_mantissa = 0x40,
.expected_float_output = 64,
},
// Tie: round up to nearest even
ConvertU32ToBfpParams{
.float_input = 65.5, // Mantissa is 0x30000
.expected_mantissa = 0x42,
.expected_float_output = 66,
}
) // Values
// clang-format on
);
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def run_pre_allgather_layernorm(
@pytest.mark.parametrize(
"min_pcc_ex2",
[
0.983,
0.982,
],
)
@pytest.mark.parametrize(("fuse_residual", "max_atol_ex2"), [(False, 0.04), (True, 0.09)])
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_pre_allgather_layernorm(
@pytest.mark.parametrize(("mean", "std"), ([0, 1],))
@pytest.mark.parametrize("core_grid", ((4, 1),))
@pytest.mark.parametrize(("min_pcc_ex", "max_atol_ex"), [(0.9997, 0.01)])
@pytest.mark.parametrize(("min_pcc_ex2", "max_atol_ex2"), [(0.987, 0.04)])
@pytest.mark.parametrize(("min_pcc_ex2", "max_atol_ex2"), [(0.986, 0.04)])
def test_pre_allgather_layernorm_1d_reduce(
device,
use_program_cache,
Expand Down
81 changes: 0 additions & 81 deletions tt_metal/api/tt-metalium/bfloat8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,87 +18,6 @@
// TODO: empty struct to facilitate Tensor template logic. Reconsider how/why templating is supported in Tensor
struct bfloat8_b {};

template <bool truncate_bfp_mantissa = false>
inline uint8_t convert_u32_to_bfp8(uint32_t input, uint32_t shared_exp, bool is_exp_a) {
// check for both +/- 0.0
constexpr uint32_t EXP_MANTISSA_BMSK = ((1U << 31) - 1);
bool is_zero = ((input & EXP_MANTISSA_BMSK) == 0);

if (is_zero) {
return 0;
}

uint32_t mantissa = input & 0x007fffff;
uint32_t exp = (input & 0x7f800000) >> 23;
uint32_t sign = (input & 0x80000000) >> 31;

if (is_exp_a) {
int32_t se = static_cast<int32_t>(exp);
// rebias
se = se - 127 + 15;
// check for saturation
if (se > 31) {
se = 31;
mantissa = 0x007fffff;
} else if (se < 0) {
se = 0;
mantissa = 0x0;
}

exp = static_cast<uint32_t>(se);
}

// float mantissa is 23 bits + hidden bit = 24 bits
// add hidden 1
mantissa = (1 << 23) | mantissa;

if (shared_exp >= exp) {
int exp_diff = shared_exp - exp;
// shift mantissa further down by exp diff
// In bit-shift operation (A >> B), the result is undefined if B is greater than or equal to the number of bits
// in A
while (exp_diff > 31) {
mantissa = mantissa >> 31;
exp_diff -= 31;
}
mantissa = mantissa >> exp_diff;
}

// this needs to become 7 bits so shift 17 times
if (truncate_bfp_mantissa) {
// Truncation: Round down
mantissa = mantissa >> 17;
} else {
// Round mantissa to nearest even
mantissa += 1 << 16;
mantissa = mantissa >> 17;
if (mantissa > 127) {
mantissa = 127;
}
}

// add sign bit only if result is not 0
if (0 == mantissa) {
sign = 0;
}
mantissa = (sign << 7) | mantissa;
return mantissa;
}

inline uint32_t create_packed_bfp8_packed_as_u32(
const std::vector<uint32_t>& u32_vec, uint32_t shared_exp, bool is_exp_a) {
int nums_in_dword = 4;
uint32_t tmp_o = 0;
uint32_t mask = (1 << (32 / nums_in_dword)) - 1;
for (int i = nums_in_dword - 1; i >= 0; --i) // [0] in LSBs of dword
{
uint32_t conv_num = convert_u32_to_bfp8(u32_vec[i], shared_exp, is_exp_a);
tmp_o = tmp_o << (32 / nums_in_dword);
tmp_o = tmp_o | (conv_num & mask);
}
return tmp_o;
}

inline std::vector<uint32_t> pack_fp32_vec_as_bfp8_tiles(
tt::stl::Span<const float> fp32_vec,
bool row_major_input,
Expand Down
23 changes: 20 additions & 3 deletions tt_metal/api/tt-metalium/blockfloat_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ inline uint8_t convert_u32_to_bfp(uint32_t input, uint32_t shared_exp, bool is_e
// add hidden 1
mantissa = (1 << 23) | mantissa;

if (shared_exp >= exp) {
if (shared_exp > exp) {
int exp_diff = shared_exp - exp;
// shift mantissa further down by exp diff
// In bit-shift operation (A >> B), the result is undefined if B is greater than or equal to the number of bits
Expand All @@ -147,9 +147,26 @@ inline uint8_t convert_u32_to_bfp(uint32_t input, uint32_t shared_exp, bool is_e
// Truncation: Round down
mantissa = mantissa >> MANTISSA_BFP_SHIFT;
} else {
// Round mantissa to nearest even
mantissa += 1 << (MANTISSA_BFP_SHIFT - 1);
// Round mantissa to nearest; ties round to even
// Implementation of rounding process (example is for bfp8):
// - We want to round 23 bit mantissa to 6 bits with extra hidden bit
// - Mantissa is broken down to: <5> bits | guard bit | <17> bits of round value
// * If round value < 0x10000, round down (ie. mantissa is just <5> bits | guard bit)
// * If round value > 0x10000, round up (ie. add 1 to <5> bits | guard bit)
// * If round value = 0x10000, we have a tie and round to nearest even:
// ** If guard bit = 0, mantissa is even so round down
// ** If guard bit = 1, mantissa is odd so round up
constexpr uint32_t MANTISSA_ROUND_MASK = (1 << MANTISSA_BFP_SHIFT) - 1;
constexpr uint32_t TIE_VALUE = 1 << (MANTISSA_BFP_SHIFT - 1);
uint32_t round_value = mantissa & MANTISSA_ROUND_MASK;
mantissa = mantissa >> MANTISSA_BFP_SHIFT;
uint32_t guard_bit = mantissa & 0x1;

if (round_value > TIE_VALUE or (round_value == TIE_VALUE and guard_bit == 1)) {
// Round up
mantissa += 1;
}

if (mantissa > MANTISSA_BFP_MAX_VAL) {
mantissa = MANTISSA_BFP_MAX_VAL;
}
Expand Down
Loading