diff --git a/Cargo.toml b/Cargo.toml index f59c6667..22e2d429 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ serde_arrays = { version = "0.1.0", optional = true } hex = { version = "0.4", optional = true, default-features = false, features = ["alloc", "serde"] } blake2b_simd = "1" rayon = "1.8" +unroll = "0.1.5" [features] default = ["bits"] diff --git a/src/arithmetic.rs b/src/arithmetic.rs index dc835177..04efdd8e 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -45,6 +45,30 @@ pub(crate) const fn macx(a: u64, b: u64, c: u64) -> (u64, u64) { (res as u64, (res >> 64) as u64) } +/// Returns a >= b +#[inline(always)] +pub(crate) const fn bigint_geq(a: &[u64; 4], b: &[u64; 4]) -> bool { + if a[3] > b[3] { + return true; + } else if a[3] < b[3] { + return false; + } + if a[2] > b[2] { + return true; + } else if a[2] < b[2] { + return false; + } + if a[1] > b[1] { + return true; + } else if a[1] < b[1] { + return false; + } + if a[0] >= b[0] { + return true; + } + false +} + /// Compute a * b, returning the result. #[inline(always)] pub(crate) fn mul_512(a: [u64; 4], b: [u64; 4]) -> [u64; 8] { diff --git a/src/bn256/fq.rs b/src/bn256/fq.rs index 2a2ce01c..8da18bbd 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -3,7 +3,7 @@ use crate::bn256::assembly::field_arithmetic_asm; #[cfg(not(feature = "asm"))] use crate::{arithmetic::macx, field_arithmetic, field_specific}; -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, bigint_geq, mac, sbb}; use crate::extend_field_legendre; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index 96dab322..22029e3a 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -18,7 +18,7 @@ pub use table::FR_TABLE; #[cfg(not(feature = "bn256-table"))] use crate::impl_from_u64; -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, bigint_geq, mac, sbb}; use crate::extend_field_legendre; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ diff --git a/src/derive/field.rs b/src/derive/field.rs index 0f9b7189..a6518b02 100644 --- a/src/derive/field.rs +++ b/src/derive/field.rs @@ -63,73 +63,88 @@ macro_rules! field_common { $crate::ff_ext::jacobi::jacobi::<5>(&self.0, &$modulus.0) } - #[cfg(feature = "asm")] const fn montgomery_form(val: [u64; 4], r: $field) -> $field { // Converts a 4 64-bit limb value into its congruent field representation. // If `val` represents a 256 bit value then `r` should be R^2, // if `val` represents the 256 MSB of a 512 bit value, then `r` should be R^3. - let (r0, carry) = mac(0, val[0], r.0[0], 0); - let (r1, carry) = mac(0, val[0], r.0[1], carry); - let (r2, carry) = mac(0, val[0], r.0[2], carry); - let (r3, r4) = mac(0, val[0], r.0[3], carry); - - let (r1, carry) = mac(r1, val[1], r.0[0], 0); - let (r2, carry) = mac(r2, val[1], r.0[1], carry); - let (r3, carry) = mac(r3, val[1], r.0[2], carry); - let (r4, r5) = mac(r4, val[1], r.0[3], carry); - - let (r2, carry) = mac(r2, val[2], r.0[0], 0); - let (r3, carry) = mac(r3, val[2], r.0[1], carry); - let (r4, carry) = mac(r4, val[2], r.0[2], carry); - let (r5, r6) = mac(r5, val[2], r.0[3], carry); - - let (r3, carry) = mac(r3, val[3], r.0[0], 0); - let (r4, carry) = mac(r4, val[3], r.0[1], carry); - let (r5, carry) = mac(r5, val[3], r.0[2], carry); - let (r6, r7) = mac(r6, val[3], r.0[3], carry); - - // Montgomery reduction - let k = r0.wrapping_mul($inv); - let (_, carry) = mac(r0, k, $modulus.0[0], 0); - let (r1, carry) = mac(r1, k, $modulus.0[1], carry); - let (r2, carry) = mac(r2, k, $modulus.0[2], carry); - let (r3, carry) = mac(r3, k, $modulus.0[3], carry); - let (r4, carry2) = adc(r4, 0, carry); - - let k = r1.wrapping_mul($inv); - let (_, carry) = mac(r1, k, $modulus.0[0], 0); - let (r2, carry) = mac(r2, k, $modulus.0[1], carry); - let (r3, carry) = mac(r3, k, $modulus.0[2], carry); - let (r4, carry) = mac(r4, k, $modulus.0[3], carry); - let (r5, carry2) = adc(r5, carry2, carry); - - let k = r2.wrapping_mul($inv); - let (_, carry) = mac(r2, k, $modulus.0[0], 0); - let (r3, carry) = mac(r3, k, $modulus.0[1], carry); - let (r4, carry) = mac(r4, k, $modulus.0[2], carry); - let (r5, carry) = mac(r5, k, $modulus.0[3], carry); - let (r6, carry2) = adc(r6, carry2, carry); - - let k = r3.wrapping_mul($inv); - let (_, carry) = mac(r3, k, $modulus.0[0], 0); - let (r4, carry) = mac(r4, k, $modulus.0[1], carry); - let (r5, carry) = mac(r5, k, $modulus.0[2], carry); - let (r6, carry) = mac(r6, k, $modulus.0[3], carry); - let (r7, carry2) = adc(r7, carry2, carry); - - // Result may be within MODULUS of the correct value - let (d0, borrow) = sbb(r4, $modulus.0[0], 0); - let (d1, borrow) = sbb(r5, $modulus.0[1], borrow); - let (d2, borrow) = sbb(r6, $modulus.0[2], borrow); - let (d3, borrow) = sbb(r7, $modulus.0[3], borrow); - let (_, borrow) = sbb(carry2, 0, borrow); - let (d0, carry) = adc(d0, $modulus.0[0] & borrow, 0); - let (d1, carry) = adc(d1, $modulus.0[1] & borrow, carry); - let (d2, carry) = adc(d2, $modulus.0[2] & borrow, carry); - let (d3, _) = adc(d3, $modulus.0[3] & borrow, carry); + #[cfg(feature = "asm")] + { + let (r0, carry) = mac(0, val[0], r.0[0], 0); + let (r1, carry) = mac(0, val[0], r.0[1], carry); + let (r2, carry) = mac(0, val[0], r.0[2], carry); + let (r3, r4) = mac(0, val[0], r.0[3], carry); + + let (r1, carry) = mac(r1, val[1], r.0[0], 0); + let (r2, carry) = mac(r2, val[1], r.0[1], carry); + let (r3, carry) = mac(r3, val[1], r.0[2], carry); + let (r4, r5) = mac(r4, val[1], r.0[3], carry); + + let (r2, carry) = mac(r2, val[2], r.0[0], 0); + let (r3, carry) = mac(r3, val[2], r.0[1], carry); + let (r4, carry) = mac(r4, val[2], r.0[2], carry); + let (r5, r6) = mac(r5, val[2], r.0[3], carry); + + let (r3, carry) = mac(r3, val[3], r.0[0], 0); + let (r4, carry) = mac(r4, val[3], r.0[1], carry); + let (r5, carry) = mac(r5, val[3], r.0[2], carry); + let (r6, r7) = mac(r6, val[3], r.0[3], carry); + + // Montgomery reduction + let k = r0.wrapping_mul($inv); + let (_, carry) = mac(r0, k, $modulus.0[0], 0); + let (r1, carry) = mac(r1, k, $modulus.0[1], carry); + let (r2, carry) = mac(r2, k, $modulus.0[2], carry); + let (r3, carry) = mac(r3, k, $modulus.0[3], carry); + let (r4, carry2) = adc(r4, 0, carry); + + let k = r1.wrapping_mul($inv); + let (_, carry) = mac(r1, k, $modulus.0[0], 0); + let (r2, carry) = mac(r2, k, $modulus.0[1], carry); + let (r3, carry) = mac(r3, k, $modulus.0[2], carry); + let (r4, carry) = mac(r4, k, $modulus.0[3], carry); + let (r5, carry2) = adc(r5, carry2, carry); + + let k = r2.wrapping_mul($inv); + let (_, carry) = mac(r2, k, $modulus.0[0], 0); + let (r3, carry) = mac(r3, k, $modulus.0[1], carry); + let (r4, carry) = mac(r4, k, $modulus.0[2], carry); + let (r5, carry) = mac(r5, k, $modulus.0[3], carry); + let (r6, carry2) = adc(r6, carry2, carry); + + let k = r3.wrapping_mul($inv); + let (_, carry) = mac(r3, k, $modulus.0[0], 0); + let (r4, carry) = mac(r4, k, $modulus.0[1], carry); + let (r5, carry) = mac(r5, k, $modulus.0[2], carry); + let (r6, carry) = mac(r6, k, $modulus.0[3], carry); + let (r7, carry2) = adc(r7, carry2, carry); + + // Result may be within MODULUS of the correct value + let (d0, borrow) = sbb(r4, $modulus.0[0], 0); + let (d1, borrow) = sbb(r5, $modulus.0[1], borrow); + let (d2, borrow) = sbb(r6, $modulus.0[2], borrow); + let (d3, borrow) = sbb(r7, $modulus.0[3], borrow); + let (_, borrow) = sbb(carry2, 0, borrow); + let (d0, carry) = adc(d0, $modulus.0[0] & borrow, 0); + let (d1, carry) = adc(d1, $modulus.0[1] & borrow, carry); + let (d2, carry) = adc(d2, $modulus.0[2] & borrow, carry); + let (d3, _) = adc(d3, $modulus.0[3] & borrow, carry); + + $field([d0, d1, d2, d3]) + } - $field([d0, d1, d2, d3]) + #[cfg(not(feature = "asm"))] + { + let mut val = val; + if bigint_geq(&val, &$modulus.0) { + let mut borrow = 0; + (val[0], borrow) = sbb(val[0], $modulus.0[0], borrow); + (val[1], borrow) = sbb(val[1], $modulus.0[1], borrow); + (val[2], borrow) = sbb(val[2], $modulus.0[2], borrow); + (val[3], _) = sbb(val[3], $modulus.0[3], borrow); + } + $field::mul(&$field(val), &r) + } } fn from_u512(limbs: [u64; 8]) -> $field { @@ -150,27 +165,13 @@ macro_rules! field_common { let lower_256 = [limbs[0], limbs[1], limbs[2], limbs[3]]; let upper_256 = [limbs[4], limbs[5], limbs[6], limbs[7]]; - #[cfg(feature = "asm")] - { - Self::montgomery_form(lower_256, $r2) + Self::montgomery_form(upper_256, $r3) - } - #[cfg(not(feature = "asm"))] - { - $field(lower_256) * $r2 + $field(upper_256) * $r3 - } + Self::montgomery_form(lower_256, $r2) + Self::montgomery_form(upper_256, $r3) } /// Converts from an integer represented in little endian /// into its (congruent) `$field` representation. pub const fn from_raw(val: [u64; 4]) -> Self { - #[cfg(feature = "asm")] - { - Self::montgomery_form(val, $r2) - } - #[cfg(not(feature = "asm"))] - { - (&$field(val)).mul(&$r2) - } + Self::montgomery_form(val, $r2) } /// Attempts to convert a little-endian byte representation of @@ -429,31 +430,69 @@ macro_rules! field_arithmetic { } /// Multiplies `rhs` by `self`, returning the result. - #[inline] - pub const fn mul(&self, rhs: &Self) -> $field { - // Schoolbook multiplication + #[inline(always)] + #[unroll::unroll_for_loops] + #[allow(unused_assignments)] + pub const fn mul(&self, rhs: &Self) -> Self { + // Fast Coarsely Integrated Operand Scanning (CIOS) as described + // in Algorithm 2 of EdMSM: https://eprint.iacr.org/2022/1400.pdf + // + // Cannot use the fast version (algorithm 2) if + // modulus_high_word >= (WORD_SIZE - 1) / 2 - 1 = (2^64 - 1)/2 - 1 + + if $modulus.0[3] < (u64::MAX / 2) { + const N: usize = 4; + let mut t: [u64; N] = [0u64; N]; + let mut c_2: u64; + for i in 0..4 { + let mut c: u64 = 0u64; + for j in 0..4 { + (t[j], c) = mac(t[j], self.0[j], rhs.0[i], c); + } + c_2 = c; + + let m = t[0].wrapping_mul(INV); + (_, c) = macx(t[0], m, $modulus.0[0]); + + for j in 1..4 { + (t[j - 1], c) = mac(t[j], m, $modulus.0[j], c); + } + (t[N - 1], _) = adc(c_2, c, 0); + } + + if bigint_geq(&t, &$modulus.0) { + let mut borrow = 0; + (t[0], borrow) = sbb(t[0], $modulus.0[0], borrow); + (t[1], borrow) = sbb(t[1], $modulus.0[1], borrow); + (t[2], borrow) = sbb(t[2], $modulus.0[2], borrow); + (t[3], borrow) = sbb(t[3], $modulus.0[3], borrow); + } + $field(t) + } else { + // Schoolbook multiplication - let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0); - let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry); - let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry); - let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry); + let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0); + let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry); + let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry); + let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry); - let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0); - let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry); - let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry); - let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry); + let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0); + let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry); + let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry); + let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry); - let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0); - let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry); - let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry); - let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry); + let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0); + let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry); + let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry); + let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry); - let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0); - let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry); - let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry); - let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry); + let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0); + let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry); + let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry); + let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry); - $field::montgomery_reduce(&[r0, r1, r2, r3, r4, r5, r6, r7]) + $field::montgomery_reduce(&[r0, r1, r2, r3, r4, r5, r6, r7]) + } } /// Subtracts `rhs` from `self`, returning the result. diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index db5888a2..bf7ffc5a 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, macx, sbb}; +use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb}; use crate::extend_field_legendre; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index 4f4277c3..ba7351ca 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, macx, sbb}; +use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb}; use crate::extend_field_legendre; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ diff --git a/src/secp256r1/fp.rs b/src/secp256r1/fp.rs index 43824aa9..31005371 100644 --- a/src/secp256r1/fp.rs +++ b/src/secp256r1/fp.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, macx, sbb}; +use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb}; use crate::extend_field_legendre; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ diff --git a/src/secp256r1/fq.rs b/src/secp256r1/fq.rs index d17120a6..63a0d0a6 100644 --- a/src/secp256r1/fq.rs +++ b/src/secp256r1/fq.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, macx, sbb}; +use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb}; use crate::extend_field_legendre; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use core::fmt;