Skip to content

Commit

Permalink
#10226: Added or refactored BH SFPU functions
Browse files Browse the repository at this point in the history
1. SFPU mask - added _int version
2. SFPU reciprocal - modified to use submodule LLKs
3. SFPU sqrt - modified to use submodule LLKs
4. SFPU trunc OP - enabled for BH
  • Loading branch information
ncvetkovicTT committed Sep 10, 2024
1 parent e789c24 commit a0d89ce
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,23 @@ inline void calculate_mask() {
#pragma GCC unroll 8
for (int d = 0; d < ITERATIONS; d++) {
vFloat mask = dst_reg[mask_val_idx];
v_if(_sfpu_is_fp16_zero_(mask, exponent_size_8)) { dst_reg[0] = vConst0; }
v_if(_sfpu_is_fp16_zero_(mask, exponent_size_8)) {
dst_reg[0] = vConst0;
}
v_endif;
dst_reg++;
}
}

template <bool APPROXIMATION_MODE, int ITERATIONS=8>
inline void calculate_int_mask() {
const int mask_idx = 32;
#pragma GCC unroll 8
for (int d = 0; d < ITERATIONS; d++) {
vInt mask = dst_reg[mask_idx];
v_if (mask == 0) {
dst_reg[0] = vConst0;
}
v_endif;
dst_reg++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,70 +12,19 @@ using namespace sfpi;
namespace ckernel {
namespace sfpu {

template <int max_iter = 3, bool save_reg = true>
template <int max_iter = 3, bool save_reg = true /* Unused. Enough registers available. */>
sfpi_inline vFloat sfpu_reciprocal(const vFloat in) {
// Force sign to 1 (make number negative)
vFloat val = setsgn(in, 1);

val = setexp(val, 126); // Set exponent to 126 to make the number in 0.5-1
// Use 1.44 as first guess at x, ideal value would be 1.33, but we happen to have 1.44 available, so use that to
// avoid a load
vFloat vConstLn2Recip = vConstFloatPrgm0;

vFloat two;
if constexpr (save_reg) {
two = vConstFloatPrgm1;
}

vFloat result = vConstLn2Recip * (val * vConstLn2Recip + (save_reg ? 2.0 : two));

for (int s_iter = 0; s_iter < (max_iter - 1); s_iter++) {
result = result * (val * result + (save_reg ? 2.0 : two));
}

vInt orig_exp = exexp(in);
vInt new_exp = exexp(result);

// "Subtract" exponents, and re-bias.
// Execute: -1 - exp, then exp += 127
new_exp -= orig_exp;
new_exp += 126;

v_if(new_exp < 0) {
// If rebiased exponent is negative, we need to saturate at 0.
// This means the initial number was too big so reciprocal result should be 0
result = 0.0F;
new_exp = 0;
}
v_endif;

// Set newly denormalized exponent to result exponent field
return setexp(result, new_exp);
return _sfpu_reciprocal_<max_iter>(in);
}

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_reciprocal() {
#pragma GCC unroll 8
for (int d = 0; d < ITERATIONS; d++) {
vFloat in = dst_reg[0];
vFloat out = sfpu_reciprocal < APPROXIMATION_MODE ? 2 : 3, true > (in);

v_if(in < 0.0F) {
// Invert sign on calculated value if CC=1 (number is negative)
out = -out;
}
v_endif;

dst_reg[0] = out;

dst_reg++;
}
_calculate_reciprocal_<APPROXIMATION_MODE, ITERATIONS>(ITERATIONS);
}

template <bool APPROXIMATION_MODE>
void recip_init() {
vConstFloatPrgm0 = 1.442695f; // ln2_recip
vConstFloatPrgm1 = 2.0f;
_init_reciprocal_<APPROXIMATION_MODE>();
}

} // namespace sfpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,50 +14,12 @@ namespace sfpu {

template <bool APPROXIMATION_MODE, int ITERATIONS = 8, int RECIPROCAL_ITERATIONS = 2>
inline void calculate_sqrt() {
#pragma GCC unroll 8
for (int d = 0; d < ITERATIONS; d++) {
vFloat val = dst_reg[0];

if constexpr (APPROXIMATION_MODE) {
vUInt magic = vConstIntPrgm0;

// sqrt initial approximation
// adjust bias
vUInt val_s = magic + reinterpret<vUInt>(val);

// approximation of square root
val_s >>= 1;
dst_reg[0] = reinterpret<vFloat>(val_s);
} else {
// Recip root method
//// Init approx
// u.i = SQRT_MAGIC_F - (u.i >> 1);
v_if(val != 0.0f) {
vUInt magic = vConstIntPrgm0;
vFloat approx = reinterpret<vFloat>(magic - (reinterpret<vUInt>(val) >> 1));

// Reciproot iterations
for (int r = 0; r < RECIPROCAL_ITERATIONS; r++) {
// x*r*(1.5f - xhalf*r*r);
approx = ((approx * approx) * (val * -0.5f) + 1.5f) * approx;
}

dst_reg[0] = approx * val;
}
v_endif;
}

dst_reg++;
}
_calculate_sqrt_<APPROXIMATION_MODE, ITERATIONS, RECIPROCAL_ITERATIONS>(ITERATIONS);
}

template <bool APPROXIMATION_MODE>
void sqrt_init() {
if (APPROXIMATION_MODE) {
vConstFloatPrgm0 = s2vFloat16b(127 << 7);
} else {
vConstFloatPrgm0 = s2vFloat16b(0x5f37);
}
_init_sqrt_<APPROXIMATION_MODE>();
}
} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ inline void llk_math_eltwise_unary_sfpu_mask(uint dst_index, DataFormat data_for
if (data_format == DataFormat::Float16_b || data_format == DataFormat::Float16) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_mask<APPROXIMATE>, dst_index, vector_mode);
} else if (data_format == DataFormat::Int32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_int_mask<APPROXIMATE>, dst_index, vector_mode);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ Tensor _swish(const Tensor& a, const std::optional<MemoryConfig>& output_mem_con

Tensor _trunc(const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
auto arch = input.device()->arch();
TT_FATAL(arch == tt::ARCH::WORMHOLE_B0, "Op is only supported on Wormhole");
TT_FATAL(arch != tt::ARCH::GRAYSKULL, "Op is not supported on Grayskull");
Tensor floor_res = ttnn::floor(input, output_mem_config);
Tensor trunc_res = ttnn::where(ttnn::ne(input, floor_res), ttnn::add(floor_res, 1.0f, std::nullopt, output_mem_config), floor_res);
Tensor result = ttnn::where(ttnn::gtz(input, output_mem_config), floor_res, trunc_res);
Expand Down

0 comments on commit a0d89ce

Please sign in to comment.