Skip to content

Commit

Permalink
refactor: inline reduce in the field12 and field25 base packages (PRO…
Browse files Browse the repository at this point in the history
…OF-823) (#108)

* inline reduce in the field12 base package

* inline reduce in the field25 base package
  • Loading branch information
jacobtrombetta authored Mar 18, 2024
1 parent eb11456 commit 88e190d
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 132 deletions.
6 changes: 3 additions & 3 deletions sxt/field12/base/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ sxt_cc_component(
sxt_cc_component(
name = "reduce",
impl_deps = [
":constants",
":subtract_p",
"//sxt/base/field:arithmetic_utility",
"//sxt/base/type:narrow_cast",
],
is_cuda = True,
Expand All @@ -57,6 +54,9 @@ sxt_cc_component(
"//sxt/base/test:unit_test",
],
deps = [
":constants",
":subtract_p",
"//sxt/base/field:arithmetic_utility",
"//sxt/base/macro:cuda_callable",
],
)
Expand Down
76 changes: 0 additions & 76 deletions sxt/field12/base/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,85 +25,9 @@
*/
#include "sxt/field12/base/reduce.h"

#include "sxt/base/field/arithmetic_utility.h"
#include "sxt/base/type/narrow_cast.h"
#include "sxt/field12/base/constants.h"
#include "sxt/field12/base/subtract_p.h"

namespace sxt::f12b {
//--------------------------------------------------------------------------------------------------
// reduce
//--------------------------------------------------------------------------------------------------
CUDA_CALLABLE void reduce(uint64_t h[6], const uint64_t t[12]) noexcept {
uint64_t tmp = 0;
uint64_t carry = 0;
uint64_t ret[12];

uint64_t k = t[0] * inv_v;
basfld::mac(tmp, carry, t[0], k, p_v[0]);
basfld::mac(ret[1], carry, t[1], k, p_v[1]);
basfld::mac(ret[2], carry, t[2], k, p_v[2]);
basfld::mac(ret[3], carry, t[3], k, p_v[3]);
basfld::mac(ret[4], carry, t[4], k, p_v[4]);
basfld::mac(ret[5], carry, t[5], k, p_v[5]);
basfld::adc(ret[6], ret[7], t[6], 0, carry);

carry = 0;
k = ret[1] * inv_v;
basfld::mac(tmp, carry, ret[1], k, p_v[0]);
basfld::mac(ret[2], carry, ret[2], k, p_v[1]);
basfld::mac(ret[3], carry, ret[3], k, p_v[2]);
basfld::mac(ret[4], carry, ret[4], k, p_v[3]);
basfld::mac(ret[5], carry, ret[5], k, p_v[4]);
basfld::mac(ret[6], carry, ret[6], k, p_v[5]);
basfld::adc(ret[7], ret[8], t[7], ret[7], carry);

carry = 0;
k = ret[2] * inv_v;
basfld::mac(tmp, carry, ret[2], k, p_v[0]);
basfld::mac(ret[3], carry, ret[3], k, p_v[1]);
basfld::mac(ret[4], carry, ret[4], k, p_v[2]);
basfld::mac(ret[5], carry, ret[5], k, p_v[3]);
basfld::mac(ret[6], carry, ret[6], k, p_v[4]);
basfld::mac(ret[7], carry, ret[7], k, p_v[5]);
basfld::adc(ret[8], ret[9], t[8], ret[8], carry);

carry = 0;
k = ret[3] * inv_v;
basfld::mac(tmp, carry, ret[3], k, p_v[0]);
basfld::mac(ret[4], carry, ret[4], k, p_v[1]);
basfld::mac(ret[5], carry, ret[5], k, p_v[2]);
basfld::mac(ret[6], carry, ret[6], k, p_v[3]);
basfld::mac(ret[7], carry, ret[7], k, p_v[4]);
basfld::mac(ret[8], carry, ret[8], k, p_v[5]);
basfld::adc(ret[9], ret[10], t[9], ret[9], carry);

carry = 0;
k = ret[4] * inv_v;
basfld::mac(tmp, carry, ret[4], k, p_v[0]);
basfld::mac(ret[5], carry, ret[5], k, p_v[1]);
basfld::mac(ret[6], carry, ret[6], k, p_v[2]);
basfld::mac(ret[7], carry, ret[7], k, p_v[3]);
basfld::mac(ret[8], carry, ret[8], k, p_v[4]);
basfld::mac(ret[9], carry, ret[9], k, p_v[5]);
basfld::adc(ret[10], ret[11], t[10], ret[10], carry);

carry = 0;
k = ret[5] * inv_v;
basfld::mac(tmp, carry, ret[5], k, p_v[0]);
basfld::mac(ret[6], carry, ret[6], k, p_v[1]);
basfld::mac(ret[7], carry, ret[7], k, p_v[2]);
basfld::mac(ret[8], carry, ret[8], k, p_v[3]);
basfld::mac(ret[9], carry, ret[9], k, p_v[4]);
basfld::mac(ret[10], carry, ret[10], k, p_v[5]);
basfld::adc(ret[11], tmp, t[11], ret[11], carry);

// Attempt to subtract the modulus,
// to ensure the value is smaller than the modulus.
uint64_t a[6] = {ret[6], ret[7], ret[8], ret[9], ret[10], ret[11]};
subtract_p(h, a);
}

//--------------------------------------------------------------------------------------------------
// is_below_modulus
//--------------------------------------------------------------------------------------------------
Expand Down
73 changes: 72 additions & 1 deletion sxt/field12/base/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@

#include <cstdint>

#include "sxt/base/field/arithmetic_utility.h"
#include "sxt/base/macro/cuda_callable.h"
#include "sxt/field12/base/constants.h"
#include "sxt/field12/base/subtract_p.h"

namespace sxt::f12b {
//--------------------------------------------------------------------------------------------------
Expand All @@ -38,7 +41,75 @@ namespace sxt::f12b {
* Handbook of Applied Cryptography
* <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
*/
CUDA_CALLABLE void reduce(uint64_t h[6], const uint64_t t[12]) noexcept;
CUDA_CALLABLE inline void reduce(uint64_t h[6], const uint64_t t[12]) noexcept {
uint64_t tmp = 0;
uint64_t carry = 0;
uint64_t ret[12];

uint64_t k = t[0] * inv_v;
basfld::mac(tmp, carry, t[0], k, p_v[0]);
basfld::mac(ret[1], carry, t[1], k, p_v[1]);
basfld::mac(ret[2], carry, t[2], k, p_v[2]);
basfld::mac(ret[3], carry, t[3], k, p_v[3]);
basfld::mac(ret[4], carry, t[4], k, p_v[4]);
basfld::mac(ret[5], carry, t[5], k, p_v[5]);
basfld::adc(ret[6], ret[7], t[6], 0, carry);

carry = 0;
k = ret[1] * inv_v;
basfld::mac(tmp, carry, ret[1], k, p_v[0]);
basfld::mac(ret[2], carry, ret[2], k, p_v[1]);
basfld::mac(ret[3], carry, ret[3], k, p_v[2]);
basfld::mac(ret[4], carry, ret[4], k, p_v[3]);
basfld::mac(ret[5], carry, ret[5], k, p_v[4]);
basfld::mac(ret[6], carry, ret[6], k, p_v[5]);
basfld::adc(ret[7], ret[8], t[7], ret[7], carry);

carry = 0;
k = ret[2] * inv_v;
basfld::mac(tmp, carry, ret[2], k, p_v[0]);
basfld::mac(ret[3], carry, ret[3], k, p_v[1]);
basfld::mac(ret[4], carry, ret[4], k, p_v[2]);
basfld::mac(ret[5], carry, ret[5], k, p_v[3]);
basfld::mac(ret[6], carry, ret[6], k, p_v[4]);
basfld::mac(ret[7], carry, ret[7], k, p_v[5]);
basfld::adc(ret[8], ret[9], t[8], ret[8], carry);

carry = 0;
k = ret[3] * inv_v;
basfld::mac(tmp, carry, ret[3], k, p_v[0]);
basfld::mac(ret[4], carry, ret[4], k, p_v[1]);
basfld::mac(ret[5], carry, ret[5], k, p_v[2]);
basfld::mac(ret[6], carry, ret[6], k, p_v[3]);
basfld::mac(ret[7], carry, ret[7], k, p_v[4]);
basfld::mac(ret[8], carry, ret[8], k, p_v[5]);
basfld::adc(ret[9], ret[10], t[9], ret[9], carry);

carry = 0;
k = ret[4] * inv_v;
basfld::mac(tmp, carry, ret[4], k, p_v[0]);
basfld::mac(ret[5], carry, ret[5], k, p_v[1]);
basfld::mac(ret[6], carry, ret[6], k, p_v[2]);
basfld::mac(ret[7], carry, ret[7], k, p_v[3]);
basfld::mac(ret[8], carry, ret[8], k, p_v[4]);
basfld::mac(ret[9], carry, ret[9], k, p_v[5]);
basfld::adc(ret[10], ret[11], t[10], ret[10], carry);

carry = 0;
k = ret[5] * inv_v;
basfld::mac(tmp, carry, ret[5], k, p_v[0]);
basfld::mac(ret[6], carry, ret[6], k, p_v[1]);
basfld::mac(ret[7], carry, ret[7], k, p_v[2]);
basfld::mac(ret[8], carry, ret[8], k, p_v[3]);
basfld::mac(ret[9], carry, ret[9], k, p_v[4]);
basfld::mac(ret[10], carry, ret[10], k, p_v[5]);
basfld::adc(ret[11], tmp, t[11], ret[11], carry);

// Attempt to subtract the modulus,
// to ensure the value is smaller than the modulus.
uint64_t a[6] = {ret[6], ret[7], ret[8], ret[9], ret[10], ret[11]};
subtract_p(h, a);
}

//--------------------------------------------------------------------------------------------------
// is_below_modulus
Expand Down
6 changes: 3 additions & 3 deletions sxt/field25/base/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ sxt_cc_component(
sxt_cc_component(
name = "reduce",
impl_deps = [
":constants",
":subtract_p",
"//sxt/base/field:arithmetic_utility",
"//sxt/base/type:narrow_cast",
],
is_cuda = True,
Expand All @@ -57,6 +54,9 @@ sxt_cc_component(
"//sxt/base/test:unit_test",
],
deps = [
":constants",
":subtract_p",
"//sxt/base/field:arithmetic_utility",
"//sxt/base/macro:cuda_callable",
],
)
Expand Down
48 changes: 0 additions & 48 deletions sxt/field25/base/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,57 +25,9 @@
*/
#include "sxt/field25/base/reduce.h"

#include "sxt/base/field/arithmetic_utility.h"
#include "sxt/base/type/narrow_cast.h"
#include "sxt/field25/base/constants.h"
#include "sxt/field25/base/subtract_p.h"

namespace sxt::f25b {
//--------------------------------------------------------------------------------------------------
// reduce
//--------------------------------------------------------------------------------------------------
CUDA_CALLABLE void reduce(uint64_t h[4], const uint64_t t[8]) noexcept {
uint64_t tmp = 0;
uint64_t carry = 0;
uint64_t ret[8];

uint64_t k = t[0] * inv_v;
basfld::mac(tmp, carry, t[0], k, p_v[0]);
basfld::mac(ret[1], carry, t[1], k, p_v[1]);
basfld::mac(ret[2], carry, t[2], k, p_v[2]);
basfld::mac(ret[3], carry, t[3], k, p_v[3]);
basfld::adc(ret[4], ret[5], t[4], 0, carry);

carry = 0;
k = ret[1] * inv_v;
basfld::mac(tmp, carry, ret[1], k, p_v[0]);
basfld::mac(ret[2], carry, ret[2], k, p_v[1]);
basfld::mac(ret[3], carry, ret[3], k, p_v[2]);
basfld::mac(ret[4], carry, ret[4], k, p_v[3]);
basfld::adc(ret[5], ret[6], t[5], ret[5], carry);

carry = 0;
k = ret[2] * inv_v;
basfld::mac(tmp, carry, ret[2], k, p_v[0]);
basfld::mac(ret[3], carry, ret[3], k, p_v[1]);
basfld::mac(ret[4], carry, ret[4], k, p_v[2]);
basfld::mac(ret[5], carry, ret[5], k, p_v[3]);
basfld::adc(ret[6], ret[7], t[6], ret[6], carry);

carry = 0;
k = ret[3] * inv_v;
basfld::mac(tmp, carry, ret[3], k, p_v[0]);
basfld::mac(ret[4], carry, ret[4], k, p_v[1]);
basfld::mac(ret[5], carry, ret[5], k, p_v[2]);
basfld::mac(ret[6], carry, ret[6], k, p_v[3]);
basfld::adc(ret[7], tmp, t[7], ret[7], carry);

// Attempt to subtract the modulus,
// to ensure the value is smaller than the modulus.
uint64_t a[4] = {ret[4], ret[5], ret[6], ret[7]};
subtract_p(h, a);
}

//--------------------------------------------------------------------------------------------------
// is_below_modulus
//--------------------------------------------------------------------------------------------------
Expand Down
45 changes: 44 additions & 1 deletion sxt/field25/base/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@

#include <cstdint>

#include "sxt/base/field/arithmetic_utility.h"
#include "sxt/base/macro/cuda_callable.h"
#include "sxt/field25/base/constants.h"
#include "sxt/field25/base/subtract_p.h"

namespace sxt::f25b {
//--------------------------------------------------------------------------------------------------
Expand All @@ -38,7 +41,47 @@ namespace sxt::f25b {
* Handbook of Applied Cryptography
* <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
*/
CUDA_CALLABLE void reduce(uint64_t h[4], const uint64_t t[8]) noexcept;
CUDA_CALLABLE inline void reduce(uint64_t h[4], const uint64_t t[8]) noexcept {
uint64_t tmp = 0;
uint64_t carry = 0;
uint64_t ret[8];

uint64_t k = t[0] * inv_v;
basfld::mac(tmp, carry, t[0], k, p_v[0]);
basfld::mac(ret[1], carry, t[1], k, p_v[1]);
basfld::mac(ret[2], carry, t[2], k, p_v[2]);
basfld::mac(ret[3], carry, t[3], k, p_v[3]);
basfld::adc(ret[4], ret[5], t[4], 0, carry);

carry = 0;
k = ret[1] * inv_v;
basfld::mac(tmp, carry, ret[1], k, p_v[0]);
basfld::mac(ret[2], carry, ret[2], k, p_v[1]);
basfld::mac(ret[3], carry, ret[3], k, p_v[2]);
basfld::mac(ret[4], carry, ret[4], k, p_v[3]);
basfld::adc(ret[5], ret[6], t[5], ret[5], carry);

carry = 0;
k = ret[2] * inv_v;
basfld::mac(tmp, carry, ret[2], k, p_v[0]);
basfld::mac(ret[3], carry, ret[3], k, p_v[1]);
basfld::mac(ret[4], carry, ret[4], k, p_v[2]);
basfld::mac(ret[5], carry, ret[5], k, p_v[3]);
basfld::adc(ret[6], ret[7], t[6], ret[6], carry);

carry = 0;
k = ret[3] * inv_v;
basfld::mac(tmp, carry, ret[3], k, p_v[0]);
basfld::mac(ret[4], carry, ret[4], k, p_v[1]);
basfld::mac(ret[5], carry, ret[5], k, p_v[2]);
basfld::mac(ret[6], carry, ret[6], k, p_v[3]);
basfld::adc(ret[7], tmp, t[7], ret[7], carry);

// Attempt to subtract the modulus,
// to ensure the value is smaller than the modulus.
uint64_t a[4] = {ret[4], ret[5], ret[6], ret[7]};
subtract_p(h, a);
}

//--------------------------------------------------------------------------------------------------
// is_below_modulus
Expand Down

0 comments on commit 88e190d

Please sign in to comment.