From 6dd526f70450ea14c6ee538aa4b4fbe7ff228ce3 Mon Sep 17 00:00:00 2001 From: Nikola Cvetkovic Date: Mon, 9 Sep 2024 14:43:30 +0000 Subject: [PATCH] #10226: Added or refactored BH SFPU functions 1. SFPU mask - added _int version 2. SFPU reciprocal - modified to use submodule LLKs 3. SFPU sqrt - modified to use submodule LLKs 4. SFPU trunc OP - enabled for BH --- .../llk_api/llk_sfpu/ckernel_sfpu_mask.h | 18 +++++- .../llk_api/llk_sfpu/ckernel_sfpu_recip.h | 59 ++----------------- .../llk_api/llk_sfpu/ckernel_sfpu_sqrt.h | 42 +------------ .../llk_math_eltwise_unary_sfpu_mask.h | 3 + .../unary/device/unary_composite_op.cpp | 2 +- 5 files changed, 27 insertions(+), 97 deletions(-) diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_mask.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_mask.h index 342edf8b23f..5b314ad8d9a 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_mask.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_mask.h @@ -20,7 +20,23 @@ inline void calculate_mask() { #pragma GCC unroll 8 for (int d = 0; d < ITERATIONS; d++) { vFloat mask = dst_reg[mask_val_idx]; - v_if(_sfpu_is_fp16_zero_(mask, exponent_size_8)) { dst_reg[0] = vConst0; } + v_if(_sfpu_is_fp16_zero_(mask, exponent_size_8)) { + dst_reg[0] = vConst0; + } + v_endif; + dst_reg++; + } +} + +template +inline void calculate_int_mask() { + const int mask_idx = 32; + #pragma GCC unroll 8 + for (int d = 0; d < ITERATIONS; d++) { + vInt mask = dst_reg[mask_idx]; + v_if (mask == 0) { + dst_reg[0] = vConst0; + } v_endif; dst_reg++; } diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_recip.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_recip.h index c4ad4b34288..ce5e51894ae 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_recip.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_recip.h @@ -12,70 +12,19 @@ using namespace sfpi; namespace ckernel { namespace sfpu { -template +template sfpi_inline vFloat sfpu_reciprocal(const vFloat in) { - // Force sign to 1 (make number negative) - vFloat val = setsgn(in, 1); - - val = setexp(val, 126); // Set exponent to 126 to make the number in 0.5-1 - // Use 1.44 as first guess at x, ideal value would be 1.33, but we happen to have 1.44 available, so use that to - // avoid a load - vFloat vConstLn2Recip = vConstFloatPrgm0; - - vFloat two; - if constexpr (save_reg) { - two = vConstFloatPrgm1; - } - - vFloat result = vConstLn2Recip * (val * vConstLn2Recip + (save_reg ? 2.0 : two)); - - for (int s_iter = 0; s_iter < (max_iter - 1); s_iter++) { - result = result * (val * result + (save_reg ? 2.0 : two)); - } - - vInt orig_exp = exexp(in); - vInt new_exp = exexp(result); - - // "Subtract" exponents, and re-bias. - // Execute: -1 - exp, then exp += 127 - new_exp -= orig_exp; - new_exp += 126; - - v_if(new_exp < 0) { - // If rebiased exponent is negative, we need to saturate at 0. - // This means the initial number was too big so reciprocal result should be 0 - result = 0.0F; - new_exp = 0; - } - v_endif; - - // Set newly denormalized exponent to result exponent field - return setexp(result, new_exp); + return _sfpu_reciprocal_(in); } template inline void calculate_reciprocal() { -#pragma GCC unroll 8 - for (int d = 0; d < ITERATIONS; d++) { - vFloat in = dst_reg[0]; - vFloat out = sfpu_reciprocal < APPROXIMATION_MODE ? 2 : 3, true > (in); - - v_if(in < 0.0F) { - // Invert sign on calculated value if CC=1 (number is negative) - out = -out; - } - v_endif; - - dst_reg[0] = out; - - dst_reg++; - } + _calculate_reciprocal_(ITERATIONS); } template void recip_init() { - vConstFloatPrgm0 = 1.442695f; // ln2_recip - vConstFloatPrgm1 = 2.0f; + _init_reciprocal_(); } } // namespace sfpu diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_sqrt.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_sqrt.h index 717b92723c6..0030c495e56 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_sqrt.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_sqrt.h @@ -14,50 +14,12 @@ namespace sfpu { template inline void calculate_sqrt() { -#pragma GCC unroll 8 - for (int d = 0; d < ITERATIONS; d++) { - vFloat val = dst_reg[0]; - - if constexpr (APPROXIMATION_MODE) { - vUInt magic = vConstIntPrgm0; - - // sqrt initial approximation - // adjust bias - vUInt val_s = magic + reinterpret(val); - - // approximation of square root - val_s >>= 1; - dst_reg[0] = reinterpret(val_s); - } else { - // Recip root method - //// Init approx - // u.i = SQRT_MAGIC_F - (u.i >> 1); - v_if(val != 0.0f) { - vUInt magic = vConstIntPrgm0; - vFloat approx = reinterpret(magic - (reinterpret(val) >> 1)); - - // Reciproot iterations - for (int r = 0; r < RECIPROCAL_ITERATIONS; r++) { - // x*r*(1.5f - xhalf*r*r); - approx = ((approx * approx) * (val * -0.5f) + 1.5f) * approx; - } - - dst_reg[0] = approx * val; - } - v_endif; - } - - dst_reg++; - } + _calculate_sqrt_(ITERATIONS); } template void sqrt_init() { - if (APPROXIMATION_MODE) { - vConstFloatPrgm0 = s2vFloat16b(127 << 7); - } else { - vConstFloatPrgm0 = s2vFloat16b(0x5f37); - } + _init_sqrt_(); } } // namespace sfpu } // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_mask.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_mask.h index 82831f2a995..172b87b9dc8 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_mask.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_mask.h @@ -22,6 +22,9 @@ inline void llk_math_eltwise_unary_sfpu_mask(uint dst_index, DataFormat data_for if (data_format == DataFormat::Float16_b || data_format == DataFormat::Float16) { llk_math_eltwise_unary_sfpu_params( ckernel::sfpu::calculate_mask, dst_index, vector_mode); + } else if (data_format == DataFormat::Int32) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_int_mask, dst_index, vector_mode); } } diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp index 91756bd3a6d..107dc8476d7 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp @@ -367,7 +367,7 @@ Tensor _swish(const Tensor& a, const std::optional& output_mem_con Tensor _trunc(const Tensor& input, const std::optional& output_mem_config) { auto arch = input.device()->arch(); - TT_FATAL(arch == tt::ARCH::WORMHOLE_B0, "Op is only supported on Wormhole"); + TT_FATAL(arch != tt::ARCH::GRAYSKULL, "Op is not supported on Grayskull"); Tensor floor_res = ttnn::floor(input, output_mem_config); Tensor trunc_res = ttnn::where(ttnn::ne(input, floor_res), ttnn::add(floor_res, 1.0f, std::nullopt, output_mem_config), floor_res); Tensor result = ttnn::where(ttnn::gtz(input, output_mem_config), floor_res, trunc_res);