From 5d8ca3a6a1768877c3890ff7aa5dc7d6c8f250d2 Mon Sep 17 00:00:00 2001
From: Brian Liu
Date: Wed, 22 Jan 2025 20:35:52 +0000
Subject: [PATCH] #15889: Fix handling of mantissa rounding to respect ties round to even

- Add tt_metal gtest to test convert_u32_to_bfp rounding logic for bfp8
- Remove unused functions in tt_metal/api/tt-metalium/bfloat8.hpp
- Update impacted tests/models to reflect minor differences
  * Minor pcc adjustment in op tests
  * Skip typecast for bf16 to bfp8 (issue 17237)
  * Update expected demo output for Falcon40B demo test
    ** This will break CI which uses cached bfp8 weights
    ** Model weights need to be re-generated and cached
---
 .../falcon40b/demo/expected_output_data.json  |  2 +-
 tests/tt_metal/tt_metal/api/CMakeLists.txt    |  1 +
 .../tt_metal/api/test_blockfloat_common.cpp   | 84 +++++++++++++++++++
 .../squeezebert/test_ttnn_squeezebert.py      |  2 +-
 .../vit/test_ttnn_optimized_sharded_vit.py    |  2 +-
 .../eltwise/test_eltwise_typecast.py          |  3 +
 .../test_distributed_layernorm_sharded.py     |  4 +-
 .../unit_tests/operations/test_maxpool2d.py   |  2 +-
 tt_metal/api/tt-metalium/bfloat8.hpp          | 81 ------------------
 .../api/tt-metalium/blockfloat_common.hpp     | 23 ++++-
 10 files changed, 114 insertions(+), 90 deletions(-)
 create mode 100644 tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp

diff --git a/models/demos/t3000/falcon40b/demo/expected_output_data.json b/models/demos/t3000/falcon40b/demo/expected_output_data.json
index 380fb5eeff1a..5f644c1f8af3 100644
--- a/models/demos/t3000/falcon40b/demo/expected_output_data.json
+++ b/models/demos/t3000/falcon40b/demo/expected_output_data.json
@@ -1 +1 @@
-["List the first 5 prime numbers \nThe first 5 prime numbers are 2, 3, 5, 7, and 11. ", "Give a brief history of the internet \nThe internet was invented in the late 1960s by computer scientists at the University of California, Los Angeles (UCLA). It was originally called ARPANET and was designed to allow scientists to share information and resources across different computer networks. In the 1990s, the internet became more widely available to the public and began to transform the way people communicate and access information. Today, the internet is a ubiquitous part of modern life, with billions of people using it daily for everything from shopping to social media to streaming entertainment. ", "Describe to me some good coding practices \nSome good coding practices include: \n\n1. Properly commenting code to make it easier to understand and maintain. \n2. Using consistent naming conventions for variables and functions. \n3. Writing clean and readable code that is easy to debug. \n4. Avoiding unnecessary complexity and keeping code simple and concise. \n5. Using version control to track changes and revert mistakes. \n6. Testing code thoroughly before deploying it. \n7. Keeping up-to-date with industry standards and best practices. \n8. Collaborating with other developers to improve code", "write a short poem about Paris in English\nParis is a city of love and romance,\nWhere the streets are filled with art and culture,\nThe Eiffel Tower stands tall and proud,\nAnd the Seine River flows through the heart of the city,\nParis is a city of dreams and possibilities,\nWhere the people are friendly and welcoming,\nThe cafes and restaurants are filled with delicious food,\nAnd the museums and galleries are filled with treasures,\nParis is a city of beauty and charm,\nWhere the architecture is stunning and the parks are lush,\nThe city is alive with energy and excitement,\nAnd the people", "Who is the inventor of the telephone?\nAlexander Graham Bell is credited with inventing the telephone in 1876. ", "write a short poem about Istanbul in English\nIstanbul is a city of contrasts,\nWhere East meets West,\nWhere ancient meets modern,\nWhere old meets new,\nWhere past meets present,\nWhere history meets future,\nWhere tradition meets innovation,\nWhere culture meets commerce,\nWhere religion meets secularism,\nWhere art meets architecture,\nWhere beauty meets chaos,\nWhere diversity meets unity,\nWhere the old city meets the new city,\nWhere the past meets the future,\nWhere the East meets the West,\nWhere the old meets the new,\nWhere the ancient meets the modern,\nWhere the", "What are the tourist attractions in Paris?\nParis is home to many famous landmarks and attractions such as the Eiffel Tower, Notre-Dame Cathedral, the Louvre Museum, the Champs-\u00c9lys\u00e9es, the Palace of Versailles, and the Seine River. Other popular attractions include the Montmartre district, the Arc de Triomphe, and the Parisian parks such as Jardin des Tuileries and Parc de la Villette. ", "How many countries are in Africa? \nThere are 54 countries in Africa. ", "what is the capital of USA? \nThe capital of USA is Washington D.C. ", "what is the capital of Canada? \nThe capital of Canada is Ottawa. ", "what is the capital of UK? \nThe capital of UK is London. ", "what is the capital of Germany? \nThe capital of Germany is Berlin. ", "what is the capital of France? \nThe capital of France is Paris. ", "what is the capital of Japan? \nThe capital of Japan is Tokyo. ", "what is the capital of India? \nThe capital of India is New Delhi. ", "what is the capital of China? \nThe capital of China is Beijing. ", "what is the currency of Cuba? \nThe currency of Cuba is the Cuban peso (CUP). ", "what is the currency of Lebanon? \nThe currency of Lebanon is the Lebanese pound (LBP). ", "what is the currency of Brazil? \nThe currency of Brazil is the Brazilian Real (BRL). ", "what is the currency of Australia? \nThe currency of Australia is the Australian dollar (AUD). ", "what is the currency of Jamaica? \nThe currency of Jamaica is the Jamaican dollar. ", "what is the currency of Egypt? \nThe currency of Egypt is the Egyptian pound (EGP). ", "what is the currency of Uzbekistan? \nThe currency of Uzbekistan is the Uzbekistani som (UZS). ", "what is the currency of Argentina? \nThe currency of Argentina is the Argentine peso. ", "describe the geographic location of London in UK\nLondon is located in the southeast of England, on the River Thames. It is the capital city of the United Kingdom and the largest city in Europe. ", "describe the geographic location of Toronto in Canada\nToronto is located in the province of Ontario, Canada. It is situated on the northwestern shore of Lake Ontario, and is the largest city in Canada. Toronto is also the fourth largest city in North America, with a population of over 2.8 million people. ", "describe the geographic location of Madrid in Spain\nMadrid is located in the center of Spain, in the region of Madrid. It is the capital city of Spain and the largest city in the country. Madrid is situated on a plateau at an elevation of 2,180 feet (660 meters) above sea level. ", "describe the geographic location of Paris in France\nParis is located in the north-central part of France, on the Seine River. It is the capital city of France and the largest city in the country. ", "describe the geographic location of Rome in Italy\nRome is located in central Italy, on the Tiber River. It is the capital city of Italy and the largest city in the country. ", "describe the geographic location of Istanbul in Turkey\nIstanbul is located in Turkey, on the Bosphorus Strait, which connects the Black Sea to the Sea of Marmara. It is the largest city in Turkey and the fifth largest city in the world. ", "describe the geographic location of Shanghai in China\nShanghai is located in eastern China, on the Yangtze River Delta. It is the largest city in China and one of the largest cities in the world. ", "describe the geographic location of Lagos in Nigeria\nLagos is located in the southwestern part of Nigeria, on the Gulf of Guinea. It is the largest city in Nigeria and the second largest city in Africa. Lagos is also the economic and cultural center of Nigeria, with a population of over 20 million people. "]
+["List the first 5 prime numbers \nThe first 5 prime numbers are 2, 3, 5, 7, and 11. ", "Give a brief history of the internet \nThe internet was invented in the 1960s as a way to connect computers together. It was originally called ARPANET and was developed by the United States Department of Defense. In the 1990s, the internet became more widely available to the public and began to grow rapidly. Today, it is a global network of computers and servers that allows people to communicate, share information, and access a vast array of resources and services. ", "Describe to me some good coding practices \nSome good coding practices include: \n\n1. Properly commenting code to make it easier to understand and maintain. \n2. Using descriptive variable names to make code easier to read and debug. \n3. Writing clean and organized code that follows a consistent style. \n4. Testing code thoroughly to ensure it works as intended. \n5. Using version control to track changes and revert mistakes. \n6. Avoiding unnecessary complexity and keeping code simple and concise. \n7. Using best practices for coding standards and code quality. \n8. Continuously learning and", "write a short poem about Paris in English\nParis is a city of love and romance,\nWhere the streets are lined with trees and flowers,\nThe air is filled with the scent of fresh pastries,\nAnd the sound of laughter and chatter fills the air.\nThe city is alive with energy and excitement,\nAnd the beauty of the city is unmatched.\nParis is a city that will always hold a special place in my heart. ", "Who is the inventor of the telephone?\nAlexander Graham Bell is credited with inventing the telephone in 1876. ", "write a short poem about Istanbul in English\nIstanbul is a city of contrasts,\nWhere East meets West,\nWhere ancient meets modern,\nWhere old meets new,\nWhere past meets present,\nWhere history meets future,\nWhere tradition meets innovation,\nWhere culture meets commerce,\nWhere religion meets secularism,\nWhere art meets architecture,\nWhere beauty meets chaos,\nWhere diversity meets unity,\nWhere the old city meets the new city,\nWhere the past meets the present,\nWhere the future meets the past,\nWhere the old meets the new,\nWhere the East meets the West,\nWhere the", "What are the tourist attractions in Paris?\nParis is home to many famous landmarks and attractions such as the Eiffel Tower, Notre-Dame Cathedral, the Louvre Museum, the Champs-\u00c9lys\u00e9es, the Palace of Versailles, and the Seine River. Other popular attractions include the Montmartre district, the Arc de Triomphe, and the Parisian parks such as Jardin des Tuileries and Parc de la Villette. ", "How many countries are in Africa? \nThere are 54 countries in Africa. ", "what is the capital of USA? \nThe capital of USA is Washington D.C. ", "what is the capital of Canada? \nThe capital of Canada is Ottawa. ", "what is the capital of UK? \nThe capital of UK is London. ", "what is the capital of Germany? \nThe capital of Germany is Berlin. ", "what is the capital of France? \nThe capital of France is Paris. ", "what is the capital of Japan? \nThe capital of Japan is Tokyo. ", "what is the capital of India? \nThe capital of India is New Delhi. ", "what is the capital of China? \nThe capital of China is Beijing. ", "what is the currency of Cuba? \nThe currency of Cuba is the Cuban peso (CUP). ", "what is the currency of Lebanon? \nThe currency of Lebanon is the Lebanese pound (LBP). ", "what is the currency of Brazil? \nThe currency of Brazil is the Brazilian Real (BRL). ", "what is the currency of Australia? \nThe currency of Australia is the Australian dollar (AUD). ", "what is the currency of Jamaica? \nThe currency of Jamaica is the Jamaican dollar (JMD). ", "what is the currency of Egypt? \nThe currency of Egypt is the Egyptian pound (EGP). ", "what is the currency of Uzbekistan? \nThe currency of Uzbekistan is the Uzbekistani som (UZS). ", "what is the currency of Argentina? \nThe currency of Argentina is the Argentine peso. ", "describe the geographic location of London in UK\nLondon is located in the southeast of England, on the River Thames. It is the capital city of the United Kingdom and the largest city in Europe. ", "describe the geographic location of Toronto in Canada\nToronto is located in the province of Ontario, in the southern part of Canada. It is situated on the northwestern shore of Lake Ontario, and is the largest city in Canada with a population of over 2.8 million people. ", "describe the geographic location of Madrid in Spain\nMadrid is located in the center of Spain, in the region of Madrid. It is the capital city of Spain and the largest city in the country. Madrid is situated on a plateau at an elevation of 2,180 feet (660 meters) above sea level. ", "describe the geographic location of Paris in France\nParis is located in the north-central part of France, on the River Seine. It is the capital city of France and the largest city in the country. ", "describe the geographic location of Rome in Italy\nRome is located in central Italy, on the Tiber River. It is the capital city of Italy and the largest city in the country. ", "describe the geographic location of Istanbul in Turkey\nIstanbul is located in the northwest corner of Turkey, on the Bosphorus Strait, which connects the Black Sea to the Sea of Marmara. It is the largest city in Turkey and the fifth largest city in the world. ", "describe the geographic location of Shanghai in China\nShanghai is located in eastern China, on the Yangtze River Delta. It is the largest city in China and one of the largest cities in the world. ", "describe the geographic location of Lagos in Nigeria\nLagos is located in the southwestern part of Nigeria, on the Gulf of Guinea. It is the largest city in Nigeria and the fifth largest city in Africa. Lagos is also the economic and cultural hub of Nigeria, with a population of over 20 million people. "]
diff --git a/tests/tt_metal/tt_metal/api/CMakeLists.txt b/tests/tt_metal/tt_metal/api/CMakeLists.txt
index 187382022b4f..4a15c487bace 100644
--- a/tests/tt_metal/tt_metal/api/CMakeLists.txt
+++ b/tests/tt_metal/tt_metal/api/CMakeLists.txt
@@ -33,6 +33,7 @@ set(UNIT_TESTS_API_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/test_soc_descriptor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test_tilize_untilize.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test_worker_config_buffer.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/test_blockfloat_common.cpp
 )
 
 # Define the function to create test executables for each architecture
diff --git a/tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp b/tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp
new file mode 100644
index 000000000000..ac2193fedb5e
--- /dev/null
+++ b/tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp
@@ -0,0 +1,84 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include
+#include
+
+namespace {
+
+struct ConvertU32ToBfpParams {
+    float float_input = 0;
+    uint32_t expected_mantissa = 0;
+    float expected_float_output = 0;
+};
+
+void roundtrip_test_for_mantissa_rounding_with_bfp8(
+    float float_input, uint8_t expected_mantissa, float expected_float_output) {
+    auto uint32_input = *reinterpret_cast(&float_input);
+    // Set shared exponent as original float exponent (ie. 
skip logic for handling shared exponents)
+    auto shared_exp = uint32_input >> 23 & 0xFF;
+
+    auto output_mantissa = convert_u32_to_bfp(uint32_input, shared_exp, false);
+    EXPECT_EQ(output_mantissa, expected_mantissa);
+
+    uint32_t uint32_output = convert_bfp_to_u32(tt::DataFormat::Bfp8_b, output_mantissa, shared_exp, false);
+    float float_output = *reinterpret_cast(&uint32_output);
+    EXPECT_EQ(float_output, expected_float_output);
+};
+
+} // namespace
+
+class ConvertU32ToBfpTests : public ::testing::TestWithParam {};
+
+TEST_P(ConvertU32ToBfpTests, MantissaRoundingWithPositiveFloat) {
+    const auto& params = GetParam();
+    roundtrip_test_for_mantissa_rounding_with_bfp8(
+        params.float_input, params.expected_mantissa, params.expected_float_output);
+}
+
+TEST_P(ConvertU32ToBfpTests, MantissaRoundingWithNegativeFloat) {
+    const auto& params = GetParam();
+    const auto float_input = -1 * params.float_input;
+    const auto expected_mantissa = params.expected_mantissa | 0x80;
+    const auto expected_float_output = -1 * params.expected_float_output;
+
+    roundtrip_test_for_mantissa_rounding_with_bfp8(float_input, expected_mantissa, expected_float_output);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    BlockfloatCommonTests,
+    ConvertU32ToBfpTests,
+    // clang-format off
+    // See tests/tt_metal/tt_metal/api/test_blockfloat_common.cpp for explanation of rounding
+    // NOTE: These float values are cherry-picked such that:
+    //   - The mantissa hits the 4 cases for rounding
+    //   - The float values match behaviour of round(float) (assuming same spec of ties round to even)
+    ::testing::Values(
+        // Round up always
+        ConvertU32ToBfpParams{
+            .float_input = 64.75, // Mantissa is 0x18000
+            .expected_mantissa = 0x41,
+            .expected_float_output = 65,
+        },
+        // Round down always
+        ConvertU32ToBfpParams{
+            .float_input = 65.25, // Mantissa is 0x28000
+            .expected_mantissa = 0x41,
+            .expected_float_output = 65,
+        },
+        // Tie: round down to nearest even
+        ConvertU32ToBfpParams{
+            .float_input = 64.5, // Mantissa is 0x10000
+            .expected_mantissa = 0x40,
+            .expected_float_output = 64,
+        },
+        // Tie: round up to nearest even
+        ConvertU32ToBfpParams{
+            .float_input = 65.5, // Mantissa is 0x30000
+            .expected_mantissa = 0x42,
+            .expected_float_output = 66,
+        }
+    ) // Values
+    // clang-format on
+);
diff --git a/tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert.py b/tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert.py
index c70b3bed5cc7..f2ecb77304f2 100644
--- a/tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert.py
+++ b/tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert.py
@@ -358,4 +358,4 @@ def test_squeezebert_for_question_answering(device, model_name, batch_size, sequ
     tt_end_logits = tt_output[..., :, 1]
 
     assert_with_pcc(torch_output.start_logits, tt_start_logits, 0.83 if is_grayskull() else 0.88)
-    assert_with_pcc(torch_output.end_logits, tt_end_logits, 0.85 if is_grayskull() else 0.93)
+    assert_with_pcc(torch_output.end_logits, tt_end_logits, 0.84 if is_grayskull() else 0.93)
diff --git a/tests/ttnn/integration_tests/vit/test_ttnn_optimized_sharded_vit.py b/tests/ttnn/integration_tests/vit/test_ttnn_optimized_sharded_vit.py
index 209f7af747dd..a5d01530e5dc 100644
--- a/tests/ttnn/integration_tests/vit/test_ttnn_optimized_sharded_vit.py
+++ b/tests/ttnn/integration_tests/vit/test_ttnn_optimized_sharded_vit.py
@@ -533,4 +533,4 @@ def test_vit(device, model_name, batch_size, image_size, image_channels, sequenc
     )
     output = ttnn.to_torch(output)
 
-    assert_with_pcc(torch_output, output[0, 0, :1000], 0.8146)
+    assert_with_pcc(torch_output, output[0, 0, :1000], 0.8139)
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_eltwise_typecast.py b/tests/ttnn/unit_tests/operations/eltwise/test_eltwise_typecast.py
index 10938b651523..c85c0185d0b3 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_eltwise_typecast.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_eltwise_typecast.py
@@ -124,6 +124,7 @@ def test_run_eltwise_typecast_op(
     )
 
 
+@pytest.mark.skip("Issue #17237: Does not work with new mantissa rounding")
 @skip_for_grayskull("Op not supported for Grayskull, supported for wormhole_b0")
 def test_typecast_bf16_to_bfp8_b(device):
     torch.manual_seed(0)
@@ -155,6 +156,7 @@ def print_mismatches(cpu, npu, num_max_print):
             break
 
 
+@pytest.mark.skip("Issue #17237: Does not work with new mantissa rounding")
 @pytest.mark.parametrize("seed", [0, 2, 4, 6, 8])
 @pytest.mark.parametrize("scale", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512])
 @pytest.mark.parametrize("bias", [0, 1, 2, 4, 8, 16, 32, 64, 128])
@@ -186,6 +188,7 @@ def test_typecast_bf16_to_bfp8_b_various_input(seed, scale, bias, device):
     assert passed
 
 
+@pytest.mark.skip("Issue #17237: Does not work with new mantissa rounding")
 @pytest.mark.parametrize("seed", [0])
 @pytest.mark.parametrize("scale", [4])
 @pytest.mark.parametrize("bias", [2])
diff --git a/tests/ttnn/unit_tests/operations/test_distributed_layernorm_sharded.py b/tests/ttnn/unit_tests/operations/test_distributed_layernorm_sharded.py
index 681444be1946..d90aafba85a8 100644
--- a/tests/ttnn/unit_tests/operations/test_distributed_layernorm_sharded.py
+++ b/tests/ttnn/unit_tests/operations/test_distributed_layernorm_sharded.py
@@ -224,7 +224,7 @@ def run_pre_allgather_layernorm(
 @pytest.mark.parametrize(
     "min_pcc_ex2",
     [
-        0.983,
+        0.982,
     ],
 )
 @pytest.mark.parametrize(("fuse_residual", "max_atol_ex2"), [(False, 0.04), (True, 0.09)])
@@ -275,7 +275,7 @@ def test_pre_allgather_layernorm(
 @pytest.mark.parametrize(("mean", "std"), ([0, 1],))
 @pytest.mark.parametrize("core_grid", ((4, 1),))
 @pytest.mark.parametrize(("min_pcc_ex", "max_atol_ex"), [(0.9997, 0.01)])
-@pytest.mark.parametrize(("min_pcc_ex2", "max_atol_ex2"), [(0.987, 0.04)])
+@pytest.mark.parametrize(("min_pcc_ex2", "max_atol_ex2"), [(0.986, 0.04)])
 def test_pre_allgather_layernorm_1d_reduce(
     device,
     use_program_cache,
diff --git a/tests/ttnn/unit_tests/operations/test_maxpool2d.py b/tests/ttnn/unit_tests/operations/test_maxpool2d.py
index 38ce1cd7e59b..279bdc724142 100644
--- a/tests/ttnn/unit_tests/operations/test_maxpool2d.py
+++ b/tests/ttnn/unit_tests/operations/test_maxpool2d.py
@@ -216,7 +216,7 @@ def run_max_pool(
 
     pcc_thresh = 1.0
     if dtype == ttnn.bfloat8_b:
-        pcc_thresh = 0.9997
+        pcc_thresh = 0.9994
 
     passing, pcc = assert_with_pcc(output_pytorch, golden_pytorch, pcc_thresh)
 
diff --git a/tt_metal/api/tt-metalium/bfloat8.hpp b/tt_metal/api/tt-metalium/bfloat8.hpp
index fb3f288d9017..493da496cde5 100644
--- a/tt_metal/api/tt-metalium/bfloat8.hpp
+++ b/tt_metal/api/tt-metalium/bfloat8.hpp
@@ -18,87 +18,6 @@
 // TODO: empty struct to facilitate Tensor template logic. Reconsider how/why templating is supported in Tensor
 struct bfloat8_b {};
 
-template
-inline uint8_t convert_u32_to_bfp8(uint32_t input, uint32_t shared_exp, bool is_exp_a) {
-    // check for both +/- 0.0
-    constexpr uint32_t EXP_MANTISSA_BMSK = ((1U << 31) - 1);
-    bool is_zero = ((input & EXP_MANTISSA_BMSK) == 0);
-
-    if (is_zero) {
-        return 0;
-    }
-
-    uint32_t mantissa = input & 0x007fffff;
-    uint32_t exp = (input & 0x7f800000) >> 23;
-    uint32_t sign = (input & 0x80000000) >> 31;
-
-    if (is_exp_a) {
-        int32_t se = static_cast(exp);
-        // rebias
-        se = se - 127 + 15;
-        // check for saturation
-        if (se > 31) {
-            se = 31;
-            mantissa = 0x007fffff;
-        } else if (se < 0) {
-            se = 0;
-            mantissa = 0x0;
-        }
-
-        exp = static_cast(se);
-    }
-
-    // float mantissa is 23 bits + hidden bit = 24 bits
-    // add hidden 1
-    mantissa = (1 << 23) | mantissa;
-
-    if (shared_exp >= exp) {
-        int exp_diff = shared_exp - exp;
-        // shift mantissa further down by exp diff
-        // In bit-shift operation (A >> B), the result is undefined if B is greater than or equal to the number of bits
-        // in A
-        while (exp_diff > 31) {
-            mantissa = mantissa >> 31;
-            exp_diff -= 31;
-        }
-        mantissa = mantissa >> exp_diff;
-    }
-
-    // this needs to become 7 bits so shift 17 times
-    if (truncate_bfp_mantissa) {
-        // Truncation: Round down
-        mantissa = mantissa >> 17;
-    } else {
-        // Round mantissa to nearest even
-        mantissa += 1 << 16;
-        mantissa = mantissa >> 17;
-        if (mantissa > 127) {
-            mantissa = 127;
-        }
-    }
-
-    // add sign bit only if result is not 0
-    if (0 == mantissa) {
-        sign = 0;
-    }
-    mantissa = (sign << 7) | mantissa;
-    return mantissa;
-}
-
-inline uint32_t create_packed_bfp8_packed_as_u32(
-    const std::vector& u32_vec, uint32_t shared_exp, bool is_exp_a) {
-    int nums_in_dword = 4;
-    uint32_t tmp_o = 0;
-    uint32_t mask = (1 << (32 / nums_in_dword)) - 1;
-    for (int i = nums_in_dword - 1; i >= 0; --i) // [0] in LSBs of dword
-    {
-        uint32_t conv_num = convert_u32_to_bfp8(u32_vec[i], shared_exp, is_exp_a);
-        tmp_o = tmp_o << (32 / nums_in_dword);
-        tmp_o = tmp_o | (conv_num & mask);
-    }
-    return tmp_o;
-}
-
 inline std::vector pack_fp32_vec_as_bfp8_tiles(
     tt::stl::Span fp32_vec,
     bool row_major_input,
diff --git a/tt_metal/api/tt-metalium/blockfloat_common.hpp b/tt_metal/api/tt-metalium/blockfloat_common.hpp
index c094c1dec962..e98718e58897 100644
--- a/tt_metal/api/tt-metalium/blockfloat_common.hpp
+++ b/tt_metal/api/tt-metalium/blockfloat_common.hpp
@@ -130,7 +130,7 @@ inline uint8_t convert_u32_to_bfp(uint32_t input, uint32_t shared_exp, bool is_e
     // add hidden 1
     mantissa = (1 << 23) | mantissa;
 
-    if (shared_exp >= exp) {
+    if (shared_exp > exp) {
         int exp_diff = shared_exp - exp;
         // shift mantissa further down by exp diff
         // In bit-shift operation (A >> B), the result is undefined if B is greater than or equal to the number of bits
@@ -147,9 +147,26 @@ inline uint8_t convert_u32_to_bfp(uint32_t input, uint32_t shared_exp, bool is_e
         // Truncation: Round down
         mantissa = mantissa >> MANTISSA_BFP_SHIFT;
     } else {
-        // Round mantissa to nearest even
-        mantissa += 1 << (MANTISSA_BFP_SHIFT - 1);
+        // Round mantissa to nearest; ties round to even
+        // Implementation of rounding process (example is for bfp8):
+        // - We want to round 23 bit mantissa to 6 bits with extra hidden bit
+        // - Mantissa is broken down to: <5> bits | guard bit | <17> bits of round value
+        //   * If round value < 0x10000, round down (ie. mantissa is just <5> bits | guard bit)
+        //   * If round value > 0x10000, round up (ie. add 1 to <5> bits | guard bit)
+        //   * If round value = 0x10000, we have a tie and round to nearest even:
+        //     ** If guard bit = 0, mantissa is even so round down
+        //     ** If guard bit = 1, mantissa is odd so round up
+        constexpr uint32_t MANTISSA_ROUND_MASK = (1 << MANTISSA_BFP_SHIFT) - 1;
+        constexpr uint32_t TIE_VALUE = 1 << (MANTISSA_BFP_SHIFT - 1);
+        uint32_t round_value = mantissa & MANTISSA_ROUND_MASK;
         mantissa = mantissa >> MANTISSA_BFP_SHIFT;
+        uint32_t guard_bit = mantissa & 0x1;
+
+        if (round_value > TIE_VALUE or (round_value == TIE_VALUE and guard_bit == 1)) {
+            // Round up
+            mantissa += 1;
+        }
+
         if (mantissa > MANTISSA_BFP_MAX_VAL) {
             mantissa = MANTISSA_BFP_MAX_VAL;
         }
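
For readers who want to check the four test cases by hand, the rounding rule described in the comment above can be sketched in isolation. This is a minimal illustration, not the tt-metal implementation: the helper names `round_mantissa_ties_to_even` and `round_mantissa_half_up` are hypothetical, the shift of 17 is taken from the bfp8 example in the patch, and the final clamp to `MANTISSA_BFP_MAX_VAL` is deliberately left out.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of tt-metal): drop the low `shift` bits of a
// 24-bit mantissa (hidden bit included), rounding to nearest with ties to even.
constexpr uint32_t round_mantissa_ties_to_even(uint32_t mantissa, uint32_t shift) {
    const uint32_t round_mask = (1u << shift) - 1;       // bits that get discarded
    const uint32_t tie_value = 1u << (shift - 1);        // exactly half of one kept LSB
    const uint32_t round_value = mantissa & round_mask;  // discarded bits
    uint32_t kept = mantissa >> shift;
    const uint32_t guard_bit = kept & 0x1;               // lowest kept bit: odd or even
    if (round_value > tie_value || (round_value == tie_value && guard_bit == 1)) {
        kept += 1;  // round up
    }
    return kept;
}

// The rule the patch replaces, for contrast: add half, then shift.
// Every tie rounds up, regardless of the guard bit.
constexpr uint32_t round_mantissa_half_up(uint32_t mantissa, uint32_t shift) {
    return (mantissa + (1u << (shift - 1))) >> shift;
}

// Inputs are 0x800000 (hidden bit) | the 23-bit fraction quoted in the test comments.
static_assert(round_mantissa_ties_to_even(0x818000u, 17) == 0x41u, "64.75 rounds up");
static_assert(round_mantissa_ties_to_even(0x828000u, 17) == 0x41u, "65.25 rounds down");
static_assert(round_mantissa_ties_to_even(0x810000u, 17) == 0x40u, "64.5 ties down to even");
static_assert(round_mantissa_ties_to_even(0x830000u, 17) == 0x42u, "65.5 ties up to even");
static_assert(round_mantissa_half_up(0x810000u, 17) == 0x41u, "old rule sends 64.5 up");
```

The only input among these where the two rules disagree is the tie with an even guard bit (64.5), which is presumably why the PCC thresholds and the demo output above shift only slightly.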