diff --git a/sxt/field12/base/BUILD b/sxt/field12/base/BUILD index 0c43e6868..9bfbda896 100644 --- a/sxt/field12/base/BUILD +++ b/sxt/field12/base/BUILD @@ -46,9 +46,6 @@ sxt_cc_component( sxt_cc_component( name = "reduce", impl_deps = [ - ":constants", - ":subtract_p", - "//sxt/base/field:arithmetic_utility", "//sxt/base/type:narrow_cast", ], is_cuda = True, @@ -57,6 +54,9 @@ sxt_cc_component( "//sxt/base/test:unit_test", ], deps = [ + ":constants", + ":subtract_p", + "//sxt/base/field:arithmetic_utility", "//sxt/base/macro:cuda_callable", ], ) diff --git a/sxt/field12/base/reduce.cc b/sxt/field12/base/reduce.cc index 15fdc5069..95e6fe2a8 100644 --- a/sxt/field12/base/reduce.cc +++ b/sxt/field12/base/reduce.cc @@ -25,85 +25,9 @@ */ #include "sxt/field12/base/reduce.h" -#include "sxt/base/field/arithmetic_utility.h" #include "sxt/base/type/narrow_cast.h" -#include "sxt/field12/base/constants.h" -#include "sxt/field12/base/subtract_p.h" namespace sxt::f12b { -//-------------------------------------------------------------------------------------------------- -// reduce -//-------------------------------------------------------------------------------------------------- -CUDA_CALLABLE void reduce(uint64_t h[6], const uint64_t t[12]) noexcept { - uint64_t tmp = 0; - uint64_t carry = 0; - uint64_t ret[12]; - - uint64_t k = t[0] * inv_v; - basfld::mac(tmp, carry, t[0], k, p_v[0]); - basfld::mac(ret[1], carry, t[1], k, p_v[1]); - basfld::mac(ret[2], carry, t[2], k, p_v[2]); - basfld::mac(ret[3], carry, t[3], k, p_v[3]); - basfld::mac(ret[4], carry, t[4], k, p_v[4]); - basfld::mac(ret[5], carry, t[5], k, p_v[5]); - basfld::adc(ret[6], ret[7], t[6], 0, carry); - - carry = 0; - k = ret[1] * inv_v; - basfld::mac(tmp, carry, ret[1], k, p_v[0]); - basfld::mac(ret[2], carry, ret[2], k, p_v[1]); - basfld::mac(ret[3], carry, ret[3], k, p_v[2]); - basfld::mac(ret[4], carry, ret[4], k, p_v[3]); - basfld::mac(ret[5], carry, ret[5], k, p_v[4]); - basfld::mac(ret[6], carry, ret[6], k, p_v[5]); - basfld::adc(ret[7], ret[8], t[7], ret[7], carry); - - carry = 0; - k = ret[2] * inv_v; - basfld::mac(tmp, carry, ret[2], k, p_v[0]); - basfld::mac(ret[3], carry, ret[3], k, p_v[1]); - basfld::mac(ret[4], carry, ret[4], k, p_v[2]); - basfld::mac(ret[5], carry, ret[5], k, p_v[3]); - basfld::mac(ret[6], carry, ret[6], k, p_v[4]); - basfld::mac(ret[7], carry, ret[7], k, p_v[5]); - basfld::adc(ret[8], ret[9], t[8], ret[8], carry); - - carry = 0; - k = ret[3] * inv_v; - basfld::mac(tmp, carry, ret[3], k, p_v[0]); - basfld::mac(ret[4], carry, ret[4], k, p_v[1]); - basfld::mac(ret[5], carry, ret[5], k, p_v[2]); - basfld::mac(ret[6], carry, ret[6], k, p_v[3]); - basfld::mac(ret[7], carry, ret[7], k, p_v[4]); - basfld::mac(ret[8], carry, ret[8], k, p_v[5]); - basfld::adc(ret[9], ret[10], t[9], ret[9], carry); - - carry = 0; - k = ret[4] * inv_v; - basfld::mac(tmp, carry, ret[4], k, p_v[0]); - basfld::mac(ret[5], carry, ret[5], k, p_v[1]); - basfld::mac(ret[6], carry, ret[6], k, p_v[2]); - basfld::mac(ret[7], carry, ret[7], k, p_v[3]); - basfld::mac(ret[8], carry, ret[8], k, p_v[4]); - basfld::mac(ret[9], carry, ret[9], k, p_v[5]); - basfld::adc(ret[10], ret[11], t[10], ret[10], carry); - - carry = 0; - k = ret[5] * inv_v; - basfld::mac(tmp, carry, ret[5], k, p_v[0]); - basfld::mac(ret[6], carry, ret[6], k, p_v[1]); - basfld::mac(ret[7], carry, ret[7], k, p_v[2]); - basfld::mac(ret[8], carry, ret[8], k, p_v[3]); - basfld::mac(ret[9], carry, ret[9], k, p_v[4]); - basfld::mac(ret[10], carry, ret[10], k, p_v[5]); - basfld::adc(ret[11], tmp, t[11], ret[11], carry); - - // Attempt to subtract the modulus, - // to ensure the value is smaller than the modulus. - uint64_t a[6] = {ret[6], ret[7], ret[8], ret[9], ret[10], ret[11]}; - subtract_p(h, a); -} - //-------------------------------------------------------------------------------------------------- // is_below_modulus //-------------------------------------------------------------------------------------------------- diff --git a/sxt/field12/base/reduce.h b/sxt/field12/base/reduce.h index 60c62b11f..8f60a0944 100644 --- a/sxt/field12/base/reduce.h +++ b/sxt/field12/base/reduce.h @@ -27,7 +27,10 @@ #include +#include "sxt/base/field/arithmetic_utility.h" #include "sxt/base/macro/cuda_callable.h" +#include "sxt/field12/base/constants.h" +#include "sxt/field12/base/subtract_p.h" namespace sxt::f12b { //-------------------------------------------------------------------------------------------------- @@ -38,7 +41,75 @@ namespace sxt::f12b { * Handbook of Applied Cryptography * . */ -CUDA_CALLABLE void reduce(uint64_t h[6], const uint64_t t[12]) noexcept; +CUDA_CALLABLE inline void reduce(uint64_t h[6], const uint64_t t[12]) noexcept { + uint64_t tmp = 0; + uint64_t carry = 0; + uint64_t ret[12]; + + uint64_t k = t[0] * inv_v; + basfld::mac(tmp, carry, t[0], k, p_v[0]); + basfld::mac(ret[1], carry, t[1], k, p_v[1]); + basfld::mac(ret[2], carry, t[2], k, p_v[2]); + basfld::mac(ret[3], carry, t[3], k, p_v[3]); + basfld::mac(ret[4], carry, t[4], k, p_v[4]); + basfld::mac(ret[5], carry, t[5], k, p_v[5]); + basfld::adc(ret[6], ret[7], t[6], 0, carry); + + carry = 0; + k = ret[1] * inv_v; + basfld::mac(tmp, carry, ret[1], k, p_v[0]); + basfld::mac(ret[2], carry, ret[2], k, p_v[1]); + basfld::mac(ret[3], carry, ret[3], k, p_v[2]); + basfld::mac(ret[4], carry, ret[4], k, p_v[3]); + basfld::mac(ret[5], carry, ret[5], k, p_v[4]); + basfld::mac(ret[6], carry, ret[6], k, p_v[5]); + basfld::adc(ret[7], ret[8], t[7], ret[7], carry); + + carry = 0; + k = ret[2] * inv_v; + basfld::mac(tmp, carry, ret[2], k, p_v[0]); + basfld::mac(ret[3], carry, ret[3], k, p_v[1]); + basfld::mac(ret[4], carry, ret[4], k, p_v[2]); + basfld::mac(ret[5], carry, ret[5], k, p_v[3]); + basfld::mac(ret[6], carry, ret[6], k, p_v[4]); + basfld::mac(ret[7], carry, ret[7], k, p_v[5]); + basfld::adc(ret[8], ret[9], t[8], ret[8], carry); + + carry = 0; + k = ret[3] * inv_v; + basfld::mac(tmp, carry, ret[3], k, p_v[0]); + basfld::mac(ret[4], carry, ret[4], k, p_v[1]); + basfld::mac(ret[5], carry, ret[5], k, p_v[2]); + basfld::mac(ret[6], carry, ret[6], k, p_v[3]); + basfld::mac(ret[7], carry, ret[7], k, p_v[4]); + basfld::mac(ret[8], carry, ret[8], k, p_v[5]); + basfld::adc(ret[9], ret[10], t[9], ret[9], carry); + + carry = 0; + k = ret[4] * inv_v; + basfld::mac(tmp, carry, ret[4], k, p_v[0]); + basfld::mac(ret[5], carry, ret[5], k, p_v[1]); + basfld::mac(ret[6], carry, ret[6], k, p_v[2]); + basfld::mac(ret[7], carry, ret[7], k, p_v[3]); + basfld::mac(ret[8], carry, ret[8], k, p_v[4]); + basfld::mac(ret[9], carry, ret[9], k, p_v[5]); + basfld::adc(ret[10], ret[11], t[10], ret[10], carry); + + carry = 0; + k = ret[5] * inv_v; + basfld::mac(tmp, carry, ret[5], k, p_v[0]); + basfld::mac(ret[6], carry, ret[6], k, p_v[1]); + basfld::mac(ret[7], carry, ret[7], k, p_v[2]); + basfld::mac(ret[8], carry, ret[8], k, p_v[3]); + basfld::mac(ret[9], carry, ret[9], k, p_v[4]); + basfld::mac(ret[10], carry, ret[10], k, p_v[5]); + basfld::adc(ret[11], tmp, t[11], ret[11], carry); + + // Attempt to subtract the modulus, + // to ensure the value is smaller than the modulus. + uint64_t a[6] = {ret[6], ret[7], ret[8], ret[9], ret[10], ret[11]}; + subtract_p(h, a); +} //-------------------------------------------------------------------------------------------------- // is_below_modulus diff --git a/sxt/field25/base/BUILD b/sxt/field25/base/BUILD index 0c43e6868..9bfbda896 100644 --- a/sxt/field25/base/BUILD +++ b/sxt/field25/base/BUILD @@ -46,9 +46,6 @@ sxt_cc_component( sxt_cc_component( name = "reduce", impl_deps = [ - ":constants", - ":subtract_p", - "//sxt/base/field:arithmetic_utility", "//sxt/base/type:narrow_cast", ], is_cuda = True, @@ -57,6 +54,9 @@ sxt_cc_component( "//sxt/base/test:unit_test", ], deps = [ + ":constants", + ":subtract_p", + "//sxt/base/field:arithmetic_utility", "//sxt/base/macro:cuda_callable", ], ) diff --git a/sxt/field25/base/reduce.cc b/sxt/field25/base/reduce.cc index 31a491e8a..719a8cd14 100644 --- a/sxt/field25/base/reduce.cc +++ b/sxt/field25/base/reduce.cc @@ -25,57 +25,9 @@ */ #include "sxt/field25/base/reduce.h" -#include "sxt/base/field/arithmetic_utility.h" #include "sxt/base/type/narrow_cast.h" -#include "sxt/field25/base/constants.h" -#include "sxt/field25/base/subtract_p.h" namespace sxt::f25b { -//-------------------------------------------------------------------------------------------------- -// reduce -//-------------------------------------------------------------------------------------------------- -CUDA_CALLABLE void reduce(uint64_t h[4], const uint64_t t[8]) noexcept { - uint64_t tmp = 0; - uint64_t carry = 0; - uint64_t ret[8]; - - uint64_t k = t[0] * inv_v; - basfld::mac(tmp, carry, t[0], k, p_v[0]); - basfld::mac(ret[1], carry, t[1], k, p_v[1]); - basfld::mac(ret[2], carry, t[2], k, p_v[2]); - basfld::mac(ret[3], carry, t[3], k, p_v[3]); - basfld::adc(ret[4], ret[5], t[4], 0, carry); - - carry = 0; - k = ret[1] * inv_v; - basfld::mac(tmp, carry, ret[1], k, p_v[0]); - basfld::mac(ret[2], carry, ret[2], k, p_v[1]); - basfld::mac(ret[3], carry, ret[3], k, p_v[2]); - basfld::mac(ret[4], carry, ret[4], k, p_v[3]); - basfld::adc(ret[5], ret[6], t[5], ret[5], carry); - - carry = 0; - k = ret[2] * inv_v; - basfld::mac(tmp, carry, ret[2], k, p_v[0]); - basfld::mac(ret[3], carry, ret[3], k, p_v[1]); - basfld::mac(ret[4], carry, ret[4], k, p_v[2]); - basfld::mac(ret[5], carry, ret[5], k, p_v[3]); - basfld::adc(ret[6], ret[7], t[6], ret[6], carry); - - carry = 0; - k = ret[3] * inv_v; - basfld::mac(tmp, carry, ret[3], k, p_v[0]); - basfld::mac(ret[4], carry, ret[4], k, p_v[1]); - basfld::mac(ret[5], carry, ret[5], k, p_v[2]); - basfld::mac(ret[6], carry, ret[6], k, p_v[3]); - basfld::adc(ret[7], tmp, t[7], ret[7], carry); - - // Attempt to subtract the modulus, - // to ensure the value is smaller than the modulus. - uint64_t a[4] = {ret[4], ret[5], ret[6], ret[7]}; - subtract_p(h, a); -} - //-------------------------------------------------------------------------------------------------- // is_below_modulus //-------------------------------------------------------------------------------------------------- diff --git a/sxt/field25/base/reduce.h b/sxt/field25/base/reduce.h index dd55d56bb..a36b229cb 100644 --- a/sxt/field25/base/reduce.h +++ b/sxt/field25/base/reduce.h @@ -27,7 +27,10 @@ #include +#include "sxt/base/field/arithmetic_utility.h" #include "sxt/base/macro/cuda_callable.h" +#include "sxt/field25/base/constants.h" +#include "sxt/field25/base/subtract_p.h" namespace sxt::f25b { //-------------------------------------------------------------------------------------------------- @@ -38,7 +41,47 @@ namespace sxt::f25b { * Handbook of Applied Cryptography * . */ -CUDA_CALLABLE void reduce(uint64_t h[4], const uint64_t t[8]) noexcept; +CUDA_CALLABLE inline void reduce(uint64_t h[4], const uint64_t t[8]) noexcept { + uint64_t tmp = 0; + uint64_t carry = 0; + uint64_t ret[8]; + + uint64_t k = t[0] * inv_v; + basfld::mac(tmp, carry, t[0], k, p_v[0]); + basfld::mac(ret[1], carry, t[1], k, p_v[1]); + basfld::mac(ret[2], carry, t[2], k, p_v[2]); + basfld::mac(ret[3], carry, t[3], k, p_v[3]); + basfld::adc(ret[4], ret[5], t[4], 0, carry); + + carry = 0; + k = ret[1] * inv_v; + basfld::mac(tmp, carry, ret[1], k, p_v[0]); + basfld::mac(ret[2], carry, ret[2], k, p_v[1]); + basfld::mac(ret[3], carry, ret[3], k, p_v[2]); + basfld::mac(ret[4], carry, ret[4], k, p_v[3]); + basfld::adc(ret[5], ret[6], t[5], ret[5], carry); + + carry = 0; + k = ret[2] * inv_v; + basfld::mac(tmp, carry, ret[2], k, p_v[0]); + basfld::mac(ret[3], carry, ret[3], k, p_v[1]); + basfld::mac(ret[4], carry, ret[4], k, p_v[2]); + basfld::mac(ret[5], carry, ret[5], k, p_v[3]); + basfld::adc(ret[6], ret[7], t[6], ret[6], carry); + + carry = 0; + k = ret[3] * inv_v; + basfld::mac(tmp, carry, ret[3], k, p_v[0]); + basfld::mac(ret[4], carry, ret[4], k, p_v[1]); + basfld::mac(ret[5], carry, ret[5], k, p_v[2]); + basfld::mac(ret[6], carry, ret[6], k, p_v[3]); + basfld::adc(ret[7], tmp, t[7], ret[7], carry); + + // Attempt to subtract the modulus, + // to ensure the value is smaller than the modulus. + uint64_t a[4] = {ret[4], ret[5], ret[6], ret[7]}; + subtract_p(h, a); +} //-------------------------------------------------------------------------------------------------- // is_below_modulus