diff --git a/src/lib.nr b/src/lib.nr index b2587dea..32026120 100644 --- a/src/lib.nr +++ b/src/lib.nr @@ -20,7 +20,7 @@ pub struct BigNum { **/ // // trait BigNumParamsTrait where Params: RuntimeBigNumParamsTrait, RuntimeBigNumInstance: RuntimeBigNumInstanceTrait> { -trait BigNumParamsTrait where Self: RuntimeBigNumParamsTrait { +pub trait BigNumParamsTrait where Self: RuntimeBigNumParamsTrait { fn get_instance() -> RuntimeBigNumInstance where Self: RuntimeBigNumParamsTrait;// ; @@ -32,7 +32,7 @@ trait BigNumParamsTrait where Self: RuntimeBigNumParamsTrait { fn has_multiplicative_inverse() -> bool { true } } -trait BigNumTrait where BigNumTrait: std::ops::Add + std::ops::Sub + std::ops::Mul + std::ops::Div + std::ops::Eq + RuntimeBigNumTrait { +pub trait BigNumTrait where BigNumTrait: std::ops::Add + std::ops::Sub + std::ops::Mul + std::ops::Div + std::ops::Eq + RuntimeBigNumTrait { // TODO: this crashes the compiler? v0.32 // fn default() -> Self { std::default::Default::default () } fn from(limbs: [Field]) -> Self { RuntimeBigNumTrait::from(limbs) } @@ -41,6 +41,7 @@ trait BigNumTrait where BigNumTrait: std::ops::Add + std::ops::Sub + std::ops::M fn modulus() -> Self; fn modulus_bits(self) -> u32; fn num_limbs(self) -> u32; + fn derive_from_seed(seed: [u8; SeedBytes]) -> Self; unconstrained fn __derive_from_seed(seed: [u8; SeedBytes]) -> Self; unconstrained fn __pow(self, exponent: Self) -> Self; unconstrained fn __neg(self) -> Self; @@ -72,6 +73,7 @@ trait BigNumTrait where BigNumTrait: std::ops::Add + std::ops::Sub + std::ops::M fn set_limb(&mut self, idx: u32, value: Field) { RuntimeBigNumTrait::set_limb(self, idx, value) } fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self { RuntimeBigNumTrait::conditional_select(lhs, rhs, predicate) } fn to_le_bytes(self) -> [u8; X] { RuntimeBigNumTrait::to_le_bytes(self) } + unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option; } impl BigNumTrait for BigNum where Params: BigNumParamsTrait + RuntimeBigNumParamsTrait { @@ -117,6 +119,9 @@ impl BigNumTrait for BigNum where Params: BigNumP Params::get_instance().__derive_from_seed(seed) } + fn derive_from_seed(seed: [u8; SeedBytes]) -> Self { + Params::get_instance().derive_from_seed(seed) + } unconstrained fn __neg(self) -> Self { Params::get_instance().__neg(self) } @@ -265,6 +270,10 @@ impl BigNumTrait for BigNum where Params: BigNumP fn umod(self, divisor: Self) -> Self { Params::get_instance().umod(self, divisor) } + + unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option { + Params::get_instance().__tonelli_shanks_sqrt(self) + } } impl BigNum where Params: BigNumParamsTrait + RuntimeBigNumParamsTrait {} diff --git a/src/runtime_bignum.nr b/src/runtime_bignum.nr index 27e3fa80..98c708d2 100644 --- a/src/runtime_bignum.nr +++ b/src/runtime_bignum.nr @@ -36,6 +36,7 @@ comptime global BARRETT_REDUCTION_OVERFLOW_BITS: u32 = 4; pub trait BigNumInstanceTrait where BN: BigNumTrait { fn modulus(self) -> BN; fn eq(self, lhs: BN, rhs: BN) -> bool; + fn derive_from_seed(self, seed: [u8; SeedBytes]) -> BN; unconstrained fn __derive_from_seed(self, seed: [u8; SeedBytes]) -> BN; unconstrained fn __neg(self, val: BN) -> BN; unconstrained fn __add(self, lhs: BN, rhs: BN) -> BN; @@ -78,6 +79,7 @@ pub trait BigNumInstanceTrait where BN: BigNumTrait { fn udiv(self, numerator: BN, divisor: BN) -> BN; fn umod(self, numerator: BN, divisor: BN) -> BN; +unconstrained fn __tonelli_shanks_sqrt(self, input: BN) -> std::option::Option; } /** @@ -287,8 +289,279 @@ impl BigNum where Params: BigNumParamsTrait { } } +impl BigNumInstance where Params: BigNumParamsTrait { + + /** + * @brief compute the log of the size of the primitive root + * @details find the maximum value k where x^k = 1, where x = primitive root + * This is needed for our Tonelli-Shanks sqrt algorithm + **/ + unconstrained fn primitive_root_log_size(self) -> u32 { + let mut target: U60Repr = self.modulus_u60 - U60Repr::one(); + let mut result: u32 = 0; + for _ in 0..Params::modulus_bits() { + let lsb_is_one = (target.limbs[0] & 1) == 1; + if (!lsb_is_one) { + result += 1; + target.shr1(); + } else { + break; + } + } + result + } + + /** + * @brief inner loop fn for `find_multiplive_generator` + * @details recursive function to get around the lack of a `while` keyword + **/ + unconstrained fn recursively_find_multiplicative_generator( + self, + target: BigNum, + p_minus_one_over_two: BigNum + ) -> (bool, BigNum) { + let exped = (self.__pow(target, p_minus_one_over_two)); + let found = exped.__eq(self.__neg(BigNum::one())); + let mut result: (bool, BigNum) = (found, target); + if (!found) { + let _target = unsafe { + self.__add(target, BigNum::one()) + }; + result = self.recursively_find_multiplicative_generator(_target, p_minus_one_over_two); + } + result + } + + /** + * @brief find multiplicative generator `g` where `g` is the smallest value that is not a quadratic residue + * i.e. smallest g where g^2 = -1 + * @note WARNING if multiplicative generator does not exist, this function will enter an infinite loop! + **/ + unconstrained fn multiplicative_generator(self) -> BigNum { + let mut target = BigNum::one(); + let p_minus_one_over_two: U60Repr = (self.modulus_u60 - U60Repr::one()).shr(1); + let p_minus_one_over_two: BigNum = BigNum::from_array(U60Repr::into(p_minus_one_over_two)); + + let (_, target) = self.recursively_find_multiplicative_generator(target, p_minus_one_over_two); + target + } + + unconstrained fn __tonelli_shanks_sqrt_inner_loop_check(self, t2m: BigNum, i: u32) -> u32 { + let is_one = t2m.__eq(BigNum::one()); + let mut result = i; + if (!is_one) { + let t2m = self.__mul(t2m, t2m); + let i = i + 1; + result = self.__tonelli_shanks_sqrt_inner_loop_check(t2m, i); + } + result + } +} impl BigNumInstanceTrait> for BigNumInstance where Params: BigNumParamsTrait { + /** + * @brief compute a modular square root using the Tonelli-Shanks algorithm + * @details only use for prime fields! Function may infinite loop if used for non-prime fields + * @note this is unconstrained fn. To constrain a square root, validate that output^2 = self + * TODO: create fn that constrains nonexistence of square root (i.e. find x where x^2 = -self) + **/ + unconstrained fn __tonelli_shanks_sqrt( + self, + input: BigNum + ) -> std::option::Option> { + // Tonelli-shanks algorithm begins by finding a field element Q and integer S, + // such that (p - 1) = Q.2^{s} + + // We can compute the square root of a, by considering a^{(Q + 1) / 2} = R + // Once we have found such an R, we have + // R^{2} = a^{Q + 1} = a^{Q}a + // If a^{Q} = 1, we have found our square root. + // Otherwise, we have a^{Q} = t, where t is a 2^{s-1}'th root of unity. + // This is because t^{2^{s-1}} = a^{Q.2^{s-1}}. + // We know that (p - 1) = Q.w^{s}, therefore t^{2^{s-1}} = a^{(p - 1) / 2} + // From Euler's criterion, if a is a quadratic residue, a^{(p - 1) / 2} = 1 + // i.e. t^{2^{s-1}} = 1 + + // To proceed with computing our square root, we want to transform t into a smaller subgroup, + // specifically, the (s-2)'th roots of unity. + // We do this by finding some value b,such that + // (t.b^2)^{2^{s-2}} = 1 and R' = R.b + // Finding such a b is trivial, because from Euler's criterion, we know that, + // for any quadratic non-residue z, z^{(p - 1) / 2} = -1 + // i.e. z^{Q.2^{s-1}} = -1 + // => z^Q is a 2^{s-1}'th root of -1 + // => z^{Q^2} is a 2^{s-2}'th root of -1 + // Since t^{2^{s-1}} = 1, we know that t^{2^{s - 2}} = -1 + // => t.z^{Q^2} is a 2^{s - 2}'th root of unity. + + // We can iteratively transform t into ever smaller subgroups, until t = 1. + // At each iteration, we need to find a new value for b, which we can obtain + // by repeatedly squaring z^{Q} + + let one: U60Repr = unsafe { + U60Repr::one() + }; + let primitive_root_log_size = self.primitive_root_log_size(); + + let mut Q = (self.modulus_u60 - one).shr(primitive_root_log_size - 1); + let mut Q_minus_one_over_two = (Q - one).shr(2); + let mut Q_minus_one_over_two = BigNum::from_array(U60Repr::into(Q_minus_one_over_two)); + let mut z = self.multiplicative_generator(); // the generator is a non-residue + let mut b = self.__pow(input, Q_minus_one_over_two); + let mut r = self.__mul(input, b); + let mut t = self.__mul(r, b); + + let mut check: BigNum = t; + for _ in 0..primitive_root_log_size - 1 { + check = self.__mul(check, check); + } + let mut found_root = false; + if (check.__eq(BigNum::one()) == false) {} else { + let mut t1 = self.__pow(z, Q_minus_one_over_two); + let mut t2 = self.__mul(t1, z); + let mut c = self.__mul(t2, t1); + let mut m: u32 = primitive_root_log_size; + // tonelli shanks inner 1 + // (if t2m == 1) then skip + // else increase i and square t2m and go again + // algorithm runtime should only be max the number of bits in modulus + let num_bits: u32 = Params::modulus_bits(); + for _ in 0..num_bits { + if (t.__eq(BigNum::one())) { + found_root = true; + break; + } + let mut t2m = t; + // while loop time + let i = self.__tonelli_shanks_sqrt_inner_loop_check(t2m, 0); + let mut j = m - i - 1; + b = c; + + for _ in 0..j { // how big + if (j == 0) { + break; + } + b = self.__mul(b, b); + //j -= 1; + } + c = self.__mul(b, b); + t = self.__mul(t, c); + r = self.__mul(r, b); + m = i; + } + } + let mut result = std::option::Option { _value: r, _is_some: found_root }; + result + } + + /** + * @brief given an input seed, generate a pseudorandom BigNum value + * @details we hash the input seed into `modulus_bits * 2` bits of entropy, + * which is then reduced into a BigNum value + * We use a hash function that can be modelled as a random oracle + * This function *should* produce an output that is a uniformly randomly distributed value modulo BigNum::modulus() + **/ + fn derive_from_seed(self, seed: [u8; SeedBytes]) -> BigNum { + let mut rolling_seed: [u8; SeedBytes + 1] = [0; SeedBytes + 1]; + for i in 0..SeedBytes { + rolling_seed[i] = seed[i]; + assert_eq(rolling_seed[i], seed[i]); + } + + let mut hash_buffer: [u8; N * 2 * 15] = [0; N * 2 * 15]; + + let mut rolling_hash_fields: [Field; (SeedBytes / 31) + 1] = [0; (SeedBytes / 31) + 1]; + let mut seed_ptr = 0; + for i in 0..(SeedBytes / 31) + 1 { + let mut packed: Field = 0; + for _ in 0..31 { + if (seed_ptr < SeedBytes) { + packed *= 256; + packed += seed[seed_ptr] as Field; + seed_ptr += 1; + } + } + rolling_hash_fields[i] = packed; + } + + let compressed = std::hash::poseidon2::Poseidon2::hash(rolling_hash_fields, (SeedBytes / 31) + 1); + let mut rolling_hash: [Field; 2] = [compressed, 0]; + let num_hashes = (30 * N) / 32 + (((30 * N) % 32) != 0) as u32; + for i in 0..num_hashes - 1 { + let hash: Field = std::hash::poseidon2::Poseidon2::hash(rolling_hash, 2); + let hash: [u8; 32] = hash.to_le_bytes(); + for j in 0..30 { + hash_buffer[i * 30 + j] = hash[j]; + } + rolling_hash[1] += 1; + } + { + let hash: Field = std::hash::poseidon2::Poseidon2::hash(rolling_hash, 2); + let hash: [u8; 32] = hash.to_le_bytes(); + let remaining_bytes = 30 * N - (num_hashes - 1) * 30; + for j in 0..remaining_bytes { + hash_buffer[(num_hashes - 1) * 30 + j] = hash[j]; + } + } + + let num_bits = Params::modulus_bits() * 2; + let num_bytes = num_bits / 8 + ((num_bits % 8) != 0) as u32; + + let bits_in_last_byte = num_bits as u8 % 8; + let last_byte_mask = (1 as u8 << bits_in_last_byte) - 1; + hash_buffer[num_bytes - 1] = hash_buffer[num_bytes - 1] & last_byte_mask; + + let num_bigfield_chunks = (2 * N) / (N - 1) + (((2 * N) % (N - 1)) != 0) as u32; + let mut byte_ptr = 0; + + // we want to convert our byte array into bigfield chunks + // each chunk has at most N-1 limbs + // to determine the exact number of chunks, we need the `!=` or `>` operator which is not avaiable when defining array sizes + // so we overestimate at 4 + // e.g. if N = 20, then we have 40 limbs we want to reduce, but each bigfield chunk is 19 limbs, so we need 3 + // if N = 2, we have 4 limbs we want to reduce but each bigfield chunk is only 1 limb, so we need 4 + // max possible number of chunks is therefore 4 + + let mut bigfield_chunks: [[Field; N]; 4] = [[0; N]; 4]; + for k in 0..num_bigfield_chunks { + let mut bigfield_limbs: [Field; N] = [0; N]; + let mut num_filled_bytes = (k * 30); + let mut num_remaining_bytes = num_bytes - num_filled_bytes; + let mut num_remaining_limbs = (num_remaining_bytes / 15) + (num_remaining_bytes % 15 > 0) as u32; + let mut more_than_N_minus_one_limbs = (num_remaining_limbs > (N - 1)) as u32; + let mut num_limbs_in_bigfield = more_than_N_minus_one_limbs * (N - 1) + + num_remaining_limbs * (1 - more_than_N_minus_one_limbs); + + for j in 0..num_limbs_in_bigfield { + let mut limb: Field = 0; + for _ in 0..15 { + let need_more_bytes = (byte_ptr < num_bytes); + let mut byte = hash_buffer[byte_ptr]; + limb *= (256 * need_more_bytes as Field + (1 - need_more_bytes as Field)); + limb += byte as Field * need_more_bytes as Field; + byte_ptr += need_more_bytes as u32; + } + bigfield_limbs[num_limbs_in_bigfield - 1 - j] = limb; + } + bigfield_chunks[num_bigfield_chunks - 1 - k] = bigfield_limbs; + } + + let mut bigfield_rhs_limbs: [Field; N] = [0; N]; + bigfield_rhs_limbs[N-1] = 1; + let bigfield_rhs: BigNum = BigNum::from_array(bigfield_rhs_limbs); + bigfield_rhs.validate_in_range(); + let mut result: BigNum = BigNum::new(); + + for i in 0..num_bigfield_chunks { + let bigfield_limbs = bigfield_chunks[i]; + let bigfield_lhs: BigNum = BigNum::from_array(bigfield_limbs); + + result = self.mul(result, bigfield_rhs); + result = self.add(result, bigfield_lhs); + } + result + } + fn modulus(self) -> BigNum { BigNum { limbs: self.modulus } } @@ -414,14 +687,16 @@ impl BigNumInstanceTrait> for BigNumInstan ) { // use an unconstrained function to compute the value of the quotient let (quotient, _, borrow_flags): (BigNum, BigNum, [Field; 2 * N]) = unsafe { - self.__compute_quadratic_expression_with_borrow_flags( - lhs_terms, - lhs_flags, - rhs_terms, - rhs_flags, - linear_terms, - linear_flags - ) + unsafe { + self.__compute_quadratic_expression_with_borrow_flags( + lhs_terms, + lhs_flags, + rhs_terms, + rhs_flags, + linear_terms, + linear_flags + ) + } }; // constrain the quotient to be in the range [0, ..., 2^{m} - 1], where `m` is log2(modulus) rounded up. // Additionally, validate quotient limbs are also in the range [0, ..., 2^{120} - 1] @@ -700,7 +975,9 @@ impl BigNumInstanceTrait> for BigNumInstan // a - b = r // p + a - b - r = 0 let (result, carry_flags, borrow_flags, underflow) = unsafe { - self.__sub_with_flags(lhs, rhs) + unsafe { + self.__sub_with_flags(lhs, rhs) + } }; result.validate_in_range(); let modulus = self.modulus; @@ -835,7 +1112,21 @@ impl BigNumInstance where Params: BigNumParamsTra } unconstrained fn __derive_from_seed_impl(self, seed: [u8; SeedBytes]) -> BigNum { - let mut rolling_seed = seed; + let mut rolling_hash_fields: [Field; (SeedBytes / 31) + 1] = [0; (SeedBytes / 31) + 1]; + let mut seed_ptr = 0; + for i in 0..(SeedBytes / 31) + 1 { + let mut packed: Field = 0; + for _ in 0..31 { + if (seed_ptr < SeedBytes) { + packed *= 256; + packed += seed[seed_ptr] as Field; + seed_ptr += 1; + } + } + rolling_hash_fields[i] = packed; + } + let compressed = std::hash::poseidon2::Poseidon2::hash(rolling_hash_fields, (SeedBytes / 31) + 1); + let mut rolling_hash: [Field; 2] = [compressed, 0]; let mut to_reduce: [Field; 2 * N] = [0; 2 * N]; @@ -852,7 +1143,8 @@ impl BigNumInstance where Params: BigNumParamsTra } for i in 0..(N - 1) { - let hash: [u8; 32] = std::hash::sha256(rolling_seed); + let hash = std::hash::poseidon2::Poseidon2::hash(rolling_hash, 2); + let hash : [u8; 30] = hash.to_le_bytes(); let mut lo: Field = 0; let mut hi: Field = 0; for j in 0..15 { @@ -866,11 +1158,12 @@ impl BigNumInstance where Params: BigNumParamsTra } to_reduce[2 * i] = lo; to_reduce[2 * i + 1] = hi; - rolling_seed[0] += 1; + rolling_hash[1] += 1; } { - let hash: [u8; 32] = std::hash::sha256(rolling_seed); + let hash = std::hash::poseidon2::Poseidon2::hash(rolling_hash, 2); + let hash : [u8; 30] = hash.to_le_bytes(); let mut hi: Field = 0; for j in 0..(last_limb_bytes - 1) { hi *= 256; @@ -883,7 +1176,6 @@ impl BigNumInstance where Params: BigNumParamsTra hi += last_bits as Field; to_reduce[2 * N - 2] = hi; } - let (_, remainder) = __barrett_reduction( to_reduce, self.redc_param, @@ -943,8 +1235,9 @@ impl BigNumInstance where Params: BigNumParamsTra mul[i + j] += lhs.limbs[i] * rhs.limbs[j]; } } + let to_reduce = split_bits::__normalize_limbs(mul, 2 * N); let (q, r) = __barrett_reduction( - split_bits::__normalize_limbs(mul, 2 * N), + to_reduce, self.redc_param, Params::modulus_bits(), self.modulus, @@ -1268,7 +1561,7 @@ impl BigNumInstance where Params: BigNumParamsTra if (flags[i]) { for j in 0..N { sum[j] = sum[j] + modulus2[j] - x[i].limbs[j]; - assert(x[i].limbs[j].lt(modulus2[j])); + // assert(x[i].limbs[j].lt(modulus2[j])); } } else { for j in 0..N { @@ -1341,7 +1634,8 @@ impl BigNumInstance where Params: BigNumParamsTra linear_flags ); let mut relation_result: [Field; 2 * N] = split_bits::__normalize_limbs(mulout, 2 * N); - + // size 4 + // a[3] * b[3] = a[6] = 7 // TODO: ugly! Will fail if input slice is empty let k = Params::modulus_bits(); @@ -1524,7 +1818,7 @@ unconstrained fn __barrett_reduction( mulout[i + j] += x[i] * redc_param[j]; } } - mulout = split_bits::__normalize_limbs(mulout, 3 * N - 1); + mulout = split_bits::__normalize_limbs(mulout, 3 * N - 2); let mulout_u60: U60Repr = U60Repr::new(mulout); // When we apply the barrett reduction, the maximum value of the output will be @@ -1557,7 +1851,6 @@ unconstrained fn __barrett_reduction( quotient_mul_modulus[i + j] += partial_quotient[i] * modulus[j]; } } - for i in 0..(N + N) { let (lo, hi) = split_bits::split_120_bits(quotient_mul_modulus[i]); quotient_mul_modulus_normalized[i] = lo; @@ -1576,7 +1869,6 @@ unconstrained fn __barrett_reduction( remainder_u60 = remainder_u60 - modulus_u60; quotient_u60.increment(); } else {} - let q: [Field; N] = U60Repr::into(quotient_u60); let r: [Field; N] = U60Repr::into(remainder_u60); diff --git a/src/runtime_bignum_test.nr b/src/runtime_bignum_test.nr index 5cffc72c..e9d67e17 100644 --- a/src/runtime_bignum_test.nr +++ b/src/runtime_bignum_test.nr @@ -1,15 +1,67 @@ use crate::BigNum; use crate::runtime_bignum::BigNumInstance; use crate::runtime_bignum::BigNumParamsTrait; +use crate::BigNumParamsTrait as NotRuntimeBigNumParamsTrait; + use crate::utils::u60_representation::U60Repr; use crate::fields::bn254Fq::BNParams as BNParams; use crate::fields::secp256k1Fq::Secp256k1_Fq_Params; use crate::fields::bls12_381Fq::BLS12_381_Fq_Params; +use crate::fields::bls12_381Fr::BLS12_381_Fr_Params; +use crate::fields::bls12_377Fq::BLS12_377_Fq_Params; +use crate::fields::bls12_377Fr::BLS12_377_Fr_Params; + +struct Test2048Params {} + +// See https://github.com/noir-lang/noir/issues/6172 +#[test] +fn silence_warning() { + let _ = Test2048Params {}; +} + +impl BigNumParamsTrait<18> for Test2048Params { + fn modulus_bits() -> u32 { + 2048 + } +} +impl NotRuntimeBigNumParamsTrait<18> for Test2048Params { + fn modulus_bits() -> u32 { + 2048 + } -// + fn get_instance() -> BigNumInstance<18, Test2048Params> { + let modulus: [Field; 18] = [ + 0x0000000000000000000000000000000000c0a197a5ae0fcdceb052c9732614fe, + 0x0000000000000000000000000000000000656ae034423283422243918ab83be3, + 0x00000000000000000000000000000000006bf590da48a7c1070b7d5aabaac678, + 0x00000000000000000000000000000000000cce39f530238b606f24b296e2bda9, + 0x000000000000000000000000000000000001e1fef9bb9c1c3ead98f226f1bfa0, + 0x0000000000000000000000000000000000ad8c1c816e12e0ed1379055e373abf, + 0x0000000000000000000000000000000000cebe80e474f753aa9d1461c435123d, + 0x0000000000000000000000000000000000aee5a18ceedef88d115a8b93c167ad, + 0x0000000000000000000000000000000000268ba83c4a65c4307427fc495d9e44, + 0x0000000000000000000000000000000000dd2777926848667b7df79f342639d4, + 0x0000000000000000000000000000000000f455074c96855ca0068668efe7da3d, + 0x00000000000000000000000000000000005ddba6b30bbc168bfb3a1225f27d65, + 0x0000000000000000000000000000000000591fec484f36707524133bcd6f4258, + 0x000000000000000000000000000000000059641b756766aeebe66781dd01d062, + 0x000000000000000000000000000000000058bc5eaff4b165e142bf9e2480eebb, + 0x0000000000000000000000000000000000667a3964f08e06df772ce64b229a72, + 0x00000000000000000000000000000000009c1fdb18907711bfe3e3c1cf918395, + 0x00000000000000000000000000000000000000000000000000000000000000b8 + ]; + let redc_param: [Field; 18] = [ + 0x1697def7100cd5cf8d890b4ef2ec3f, 0x765ba8304214dac764d3f4adc31859, 0x8404bd14d927ea230e60d4bebf9406, 0xc4d53a23bacc251ecbfc4b7ba5a0b4, 0x3eaf3499474a6f5b2fff83f1259c87, 0xbff4c737b97281f1a5f2384a8c16d9, 0x1b4cf2f55358476b53237829990555, 0xe7a804e8eacfe3a2a5673bc3885b86, 0xabadeae4282906c817adf70eab4ae1, 0x66f7df257fe2bf27f0809aceed9b0e, 0xd90fb7428901b8bed11f6b81e36bf1, 0x36e6ba885c60b7024c563605df7e07, 0x2b7c58d2fb5d2c8478963ae6d4a44f, 0x6ee761de26635f114ccc3f7d74f855, 0x3fb726a10cf2220897513f05243de3, 0x43a26bbd732496eb4d828591b8056e, 0xf4e42304e60fb3a54fca735499f2cf, 0x162f + ]; + BigNumInstance::new(modulus, redc_param) + } +} +/** + * @brief this example was failing - sanity test to validate it now works + **/ #[test] -fn test_bls() { +fn test_bls_reduction() { let X1: BigNum<4, BLS12_381_Fq_Params> = BigNum { limbs: [ 0x55e83ff97a1aeffb3af00adb22c6bb, 0x8c4f9774b905a14e3a3f171bac586c, 0xa73197d7942695638c4fa9ac0fc368, 0x17f1d3 @@ -28,6 +80,10 @@ fn test_bls() { }; XX_mul_3.validate_in_field(); } + +/** + * @brief experimenting with macro madness and code generation to make some tests that apply to multiple BigNum parametrisations! + **/ comptime fn make_test(f: StructDefinition, N: Quoted, typ: Quoted) -> Quoted { let k = f.name(); quote{ @@ -54,7 +110,78 @@ fn test_add() { } assert(e.limbs[0] == 2); } - + +#[test] +fn test_sub() { + let bn = $typ::get_instance(); + // 0 - 1 should equal p - 1 + let mut a: BigNum<$N, $typ> = BigNum::new(); + let mut b: BigNum<$N, $typ> = BigNum::one(); + let mut expected: BigNum<$N, $typ> = bn.modulus(); + expected.limbs[0] -= 1; // p - 1 + + let result = bn.sub(a, b); + assert(bn.eq(result, expected)); +} + + +#[test] +fn test_sub_modulus_limit() { + let instance = $typ::get_instance(); + // if we underflow, maximum result should be ... + // 0 - 1 = o-1 + // 0 - p = 0 + let mut a: BigNum<$N, $typ> = BigNum::new(); + let mut b: BigNum<$N, $typ> = instance.modulus(); + let mut expected: BigNum<$N, $typ> = BigNum::new(); + let result = instance.sub(a, b); + assert(instance.eq(result, expected)); +} + + +#[test(should_fail_with = "call to assert_max_bit_size")] +fn test_sub_modulus_underflow() { + let instance = $typ::get_instance(); + + // 0 - (p + 1) is smaller than p and should produce unsatisfiable constraints + let mut a: BigNum<$N, $typ> = BigNum::new(); + let mut b: BigNum<$N, $typ> = instance.modulus(); + b.limbs[0] += 1; + let mut expected: BigNum<$N, $typ> = BigNum::one(); + + let result = instance.sub(a, b); + + assert(instance.eq(result, expected)); +} + +#[test] +fn test_add_modulus_limit() { + let instance = $typ::get_instance(); + // p + 2^{modulus_bits()} - 1 should be the maximum allowed value fed into an add operation + // when adding, if the result overflows the modulus, we conditionally subtract the modulus, producing 2^{254} - 1 + // this is the largest value that will satisfy the range check applied when constructing a bignum + let p : U60Repr<$N, 2> = U60Repr::from(instance.modulus().limbs); + let one = unsafe{ U60Repr::one() }; + + let a: BigNum<$N, $typ> = unsafe{ BigNum { limbs: U60Repr::into(p) } }; + let mut two_pow_modulus_bits_minus_one: U60Repr<$N, 2> = unsafe{ one.shl(a.modulus_bits()) - one }; + let b: BigNum<$N, $typ> = BigNum { limbs: U60Repr::into(two_pow_modulus_bits_minus_one) }; + let result = instance.add(a, b); + assert(instance.eq(result, b)); +} + +#[test(should_fail_with = "call to assert_max_bit_size")] +fn test_add_modulus_overflow() { + + let instance = $typ::get_instance(); + let p : U60Repr<$N, 2> = U60Repr::from(instance.modulus().limbs); + let one = unsafe{ U60Repr::one() }; + let a: BigNum<$N, $typ> = unsafe{ BigNum { limbs: U60Repr::into(p + one) } }; + let mut two_pow_modulus_bits_minus_one: U60Repr<$N, 2> = unsafe{ one.shl(a.modulus_bits()) - one }; + let b: BigNum<$N, $typ> = BigNum { limbs: U60Repr::into(two_pow_modulus_bits_minus_one) }; + let result = instance.add(a, b); + assert(instance.eq(result, b)); +} #[test] fn test_mul() { @@ -176,6 +303,42 @@ fn assert_is_not_equal_overloaded_fail() { let b_plus_modulus: BigNum<$N, $typ> = BigNum { limbs: U60Repr::into(t1 + t2) }; bn.assert_is_not_equal(a_plus_modulus, b_plus_modulus); } + +#[test] +fn test_derive() +{ + let bn = $typ ::get_instance(); + let a: BigNum<$N, $typ> = bn.derive_from_seed("hello".as_bytes()); + let b: BigNum<$N, $typ> = unsafe { + bn.__derive_from_seed("hello".as_bytes()) + }; + assert(bn.eq(a, b)); +} + +#[test] +fn test_eq() { + let bn = $typ ::get_instance(); + let a: BigNum<$N, $typ> = unsafe { + bn.__derive_from_seed([1, 2, 3, 4]) + }; + let b: BigNum<$N, $typ> = unsafe { + bn.__derive_from_seed([1, 2, 3, 4]) + }; + let c: BigNum<$N, $typ> = unsafe { + bn.__derive_from_seed([2, 2, 3, 4]) + }; + + let modulus: BigNum<$N, $typ> = bn.modulus(); + let t0: U60Repr<$N, 2> = (U60Repr::from(modulus.limbs)); + let t1: U60Repr<$N, 2> = (U60Repr::from(b.limbs)); + let b_plus_modulus: BigNum<$N, $typ> = BigNum { limbs: U60Repr::into(t0 + t1) }; + + assert(bn.eq(a, b) == true); + assert(bn.eq(a, b_plus_modulus) == true); + assert(bn.eq(c, b) == false); + assert(bn.eq(c, a) == false); +} + } } } @@ -187,130 +350,23 @@ pub struct Secp256K1FqTests{} #[make_test(quote{4},quote{BLS12_381_Fq_Params})] pub struct BLS12_381FqTests{} -struct Test2048Params {} +#[make_test(quote{18},quote{Test2048Params})] +pub struct Test2048Tests{} -// See https://github.com/noir-lang/noir/issues/6172 -#[test] -fn silence_warning() { - let _ = Test2048Params {}; -} +#[make_test(quote{3},quote{BLS12_381_Fr_Params})] +pub struct BLS12_381_Fr_ParamsTests{} -impl BigNumParamsTrait<18> for Test2048Params { - fn modulus_bits() -> u32 { - 2048 - } -} +#[make_test(quote{4},quote{BLS12_377_Fq_Params})] +pub struct BLS12_377_Fq_ParamsTests{} -fn get_2048_BN_instance() -> BigNumInstance<18, Test2048Params> { - let modulus: [Field; 18] = [ - 0x0000000000000000000000000000000000c0a197a5ae0fcdceb052c9732614fe, - 0x0000000000000000000000000000000000656ae034423283422243918ab83be3, - 0x00000000000000000000000000000000006bf590da48a7c1070b7d5aabaac678, - 0x00000000000000000000000000000000000cce39f530238b606f24b296e2bda9, - 0x000000000000000000000000000000000001e1fef9bb9c1c3ead98f226f1bfa0, - 0x0000000000000000000000000000000000ad8c1c816e12e0ed1379055e373abf, - 0x0000000000000000000000000000000000cebe80e474f753aa9d1461c435123d, - 0x0000000000000000000000000000000000aee5a18ceedef88d115a8b93c167ad, - 0x0000000000000000000000000000000000268ba83c4a65c4307427fc495d9e44, - 0x0000000000000000000000000000000000dd2777926848667b7df79f342639d4, - 0x0000000000000000000000000000000000f455074c96855ca0068668efe7da3d, - 0x00000000000000000000000000000000005ddba6b30bbc168bfb3a1225f27d65, - 0x0000000000000000000000000000000000591fec484f36707524133bcd6f4258, - 0x000000000000000000000000000000000059641b756766aeebe66781dd01d062, - 0x000000000000000000000000000000000058bc5eaff4b165e142bf9e2480eebb, - 0x0000000000000000000000000000000000667a3964f08e06df772ce64b229a72, - 0x00000000000000000000000000000000009c1fdb18907711bfe3e3c1cf918395, - 0x00000000000000000000000000000000000000000000000000000000000000b8 - ]; - let redc_param: [Field; 18] = [ - 0x1697def7100cd5cf8d890b4ef2ec3f, 0x765ba8304214dac764d3f4adc31859, 0x8404bd14d927ea230e60d4bebf9406, 0xc4d53a23bacc251ecbfc4b7ba5a0b4, 0x3eaf3499474a6f5b2fff83f1259c87, 0xbff4c737b97281f1a5f2384a8c16d9, 0x1b4cf2f55358476b53237829990555, 0xe7a804e8eacfe3a2a5673bc3885b86, 0xabadeae4282906c817adf70eab4ae1, 0x66f7df257fe2bf27f0809aceed9b0e, 0xd90fb7428901b8bed11f6b81e36bf1, 0x36e6ba885c60b7024c563605df7e07, 0x2b7c58d2fb5d2c8478963ae6d4a44f, 0x6ee761de26635f114ccc3f7d74f855, 0x3fb726a10cf2220897513f05243de3, 0x43a26bbd732496eb4d828591b8056e, 0xf4e42304e60fb3a54fca735499f2cf, 0x162f - ]; - BigNumInstance::new(modulus, redc_param) -} +#[make_test(quote{3},quote{BLS12_377_Fr_Params})] +pub struct BLS12_377_Fr_ParamsTests{} type Fq = BigNum<3, BNParams>; -// type FqInstance = BigNumInstance<3, BNParams>; -// type Fqq = BigNum<18, Test2048Params>; -// type FqqInstance = BigNumInstance<18, Test2048Params>; - -fn test_derive(BNInstance: BigNumInstance) where Params: BigNumParamsTrait { - let a: BigNum = unsafe { - BNInstance.__derive_from_seed("hello".as_bytes()) - }; - let b: BigNum = unsafe { - BNInstance.__derive_from_seed("hello".as_bytes()) - }; - assert(BNInstance.eq(a, b)); -} - -fn test_eq(BNInstance: BigNumInstance) where Params: BigNumParamsTrait { - let a: BigNum = unsafe { - BNInstance.__derive_from_seed([1, 2, 3, 4]) - }; - let b: BigNum = unsafe { - BNInstance.__derive_from_seed([1, 2, 3, 4]) - }; - let c: BigNum = unsafe { - BNInstance.__derive_from_seed([2, 2, 3, 4]) - }; - - let modulus: BigNum = BNInstance.modulus(); - let t0: U60Repr = (U60Repr::from(modulus.limbs)); - let t1: U60Repr = (U60Repr::from(b.limbs)); - let b_plus_modulus: BigNum = BigNum { limbs: U60Repr::into(t0 + t1) }; - - assert(BNInstance.eq(a, b) == true); - assert(BNInstance.eq(a, b_plus_modulus) == true); - assert(BNInstance.eq(c, b) == false); - assert(BNInstance.eq(c, a) == false); -} // 98760 // 99689 // 929 gates for a 2048 bit mul -fn test_mul(BNInstance: BigNumInstance) where Params: BigNumParamsTrait { - let a: BigNum = unsafe { - BNInstance.__derive_from_seed([1, 2, 3, 4]) - }; - let b: BigNum = unsafe { - BNInstance.__derive_from_seed([4, 5, 6, 7]) - }; - - let c = BNInstance.mul(BNInstance.add(a, b), BNInstance.add(a, b)); - let d = BNInstance.add( - BNInstance.add( - BNInstance.add(BNInstance.mul(a, a), BNInstance.mul(b, b)), - BNInstance.mul(a, b) - ), - BNInstance.mul(a, b) - ); - assert(BNInstance.eq(c, d)); -} - -fn test_add(bn: BigNumInstance) where Params: BigNumParamsTrait { - let a: BigNum = unsafe { - bn.__derive_from_seed([1, 2, 3, 4]) - }; - let b: BigNum = unsafe { - bn.__derive_from_seed([4, 5, 6, 7]) - }; - let one: BigNum = BigNum::one(); - a.validate_in_range(); - bn.validate_in_field(a); - b.validate_in_range(); - bn.validate_in_field(b); - - let mut c = bn.add(a, b); - c = bn.add(c, c); - let d = bn.mul(bn.add(a, b), bn.add(one, one)); - assert(bn.eq(c, d)); - - let e = bn.add(one, one); - for i in 1..N { - assert(e.limbs[i] == 0); - } - assert(e.limbs[0] == 2); -} fn test_div(bn: BigNumInstance) where Params: BigNumParamsTrait { let a: BigNum = unsafe { @@ -340,187 +396,6 @@ fn test_invmod(bn: BigNumInstance) where Params: } } -fn assert_is_not_equal(bn: BigNumInstance) where Params: BigNumParamsTrait { - let a: BigNum = unsafe { - bn.__derive_from_seed([1, 2, 3, 4]) - }; - let b: BigNum = unsafe { - bn.__derive_from_seed([4, 5, 6, 7]) - }; - - bn.assert_is_not_equal(a, b); -} - -fn assert_is_not_equal_fail(bn: BigNumInstance) where Params: BigNumParamsTrait { - let a: BigNum = unsafe { - bn.__derive_from_seed([1, 2, 3, 4]) - }; - let b: BigNum = unsafe { - bn.__derive_from_seed([1, 2, 3, 4]) - }; - - bn.assert_is_not_equal(a, b); -} - -fn assert_is_not_equal_overloaded_lhs_fail(bn: BigNumInstance) where Params: BigNumParamsTrait { - let a: BigNum = unsafe { - bn.__derive_from_seed([1, 2, 3, 4]) - }; - let b: BigNum = unsafe { - bn.__derive_from_seed([1, 2, 3, 4]) - }; - - let modulus = bn.modulus(); - - let t0: U60Repr = U60Repr::from(a.limbs); - let t1: U60Repr = U60Repr::from(modulus.limbs); - let a_plus_modulus: BigNum = BigNum { limbs: U60Repr::into(t0 + t1) }; - bn.assert_is_not_equal(a_plus_modulus, b); -} - -fn assert_is_not_equal_overloaded_rhs_fail(bn: BigNumInstance) where Params: BigNumParamsTrait { - let a: BigNum = unsafe { - bn.__derive_from_seed([1, 2, 3, 4]) - }; - let b: BigNum = unsafe { - bn.__derive_from_seed([1, 2, 3, 4]) - }; - - let modulus = bn.modulus(); - - let t0: U60Repr = U60Repr::from(b.limbs); - let t1: U60Repr = U60Repr::from(modulus.limbs); - let b_plus_modulus: BigNum = BigNum { limbs: U60Repr::into(t0 + t1) }; - bn.assert_is_not_equal(a, b_plus_modulus); -} - -fn assert_is_not_equal_overloaded_fail(bn: BigNumInstance) where Params: BigNumParamsTrait { - let a: BigNum = unsafe { - bn.__derive_from_seed([1, 2, 3, 4]) - }; - let b: BigNum = unsafe { - bn.__derive_from_seed([1, 2, 3, 4]) - }; - - let modulus = bn.modulus(); - - let t0: U60Repr = U60Repr::from(a.limbs); - let t1: U60Repr = U60Repr::from(b.limbs); - let t2: U60Repr = U60Repr::from(modulus.limbs); - let a_plus_modulus: BigNum = BigNum { limbs: U60Repr::into(t0 + t2) }; - let b_plus_modulus: BigNum = BigNum { limbs: U60Repr::into(t1 + t2) }; - bn.assert_is_not_equal(a_plus_modulus, b_plus_modulus); -} - -#[test] -fn test_derive_bn() { - test_derive(BNParams::get_instance()); -} - -// MNT6753FqParams -#[test] -fn test_eq_BN() { - let instance = BNParams::get_instance(); - test_eq(instance); -} -#[test] -fn test_add_BN() { - let instance = BNParams::get_instance(); - - let mut a: Fq = instance.modulus(); - let mut b: Fq = instance.modulus(); - let mut expected: Fq = instance.modulus(); - - a.limbs[0] -= 1; - b.limbs[0] -= 1; - expected.limbs[0] -= 2; - - let result = instance.add(a, b); - assert(instance.eq(result, expected)); -} - -#[test] -fn test_sub_test_BN() { - let instance = BNParams::get_instance(); - // 0 - 1 should equal p - 1 - let mut a: Fq = BigNum::new(); - let mut b: Fq = BigNum::one(); - let mut expected: Fq = instance.modulus(); - expected.limbs[0] -= 1; // p - 1 - - let result = instance.sub(a, b); - assert(instance.eq(result, expected)); -} - -#[test] -fn test_sub_modulus_limit() { - let instance = BNParams::get_instance(); - // if we underflow, maximum result should be ... - // 0 - 1 = o-1 - // 0 - p = 0 - let mut a: Fq = BigNum::new(); - let mut b: Fq = instance.modulus(); - let mut expected: Fq = BigNum::new(); - - let result = instance.sub(a, b); - assert(instance.eq(result, expected)); -} - -#[test(should_fail_with = "call to assert_max_bit_size")] -fn test_sub_modulus_underflow() { - let instance = BNParams::get_instance(); - - // 0 - (p + 1) is smaller than p and should produce unsatisfiable constraints - let mut a: Fq = BigNum::new(); - let mut b: Fq = instance.modulus(); - b.limbs[0] += 1; - let mut expected: Fq = BigNum::one(); - - let result = instance.sub(a, b); - - assert(instance.eq(result, expected)); -} - -#[test] -fn test_add_modulus_limit() { - let instance = BNParams::get_instance(); - // p + 2^{254} - 1 should be the maximum allowed value fed into an add operation - // when adding, if the result overflows the modulus, we conditionally subtract the modulus, producing 2^{254} - 1 - // this is the largest value that will satisfy the range check applied when constructing a bignum - let p : U60Repr<3, 2> = U60Repr::from(instance.modulus().limbs); - let two_pow_254_minus_1: U60Repr<3, 2> = U60Repr::from([0xffffffffffffffffffffffffffffff, 0xffffffffffffffffffffffffffffff, 0x3fff]); - let a: Fq = BigNum { limbs: U60Repr::into(p) }; - let b: Fq = BigNum { limbs: U60Repr::into(two_pow_254_minus_1) }; - let result = instance.add(a, b); - assert(instance.eq(result, b)); -} - -#[test(should_fail_with = "call to assert_max_bit_size")] -fn test_add_modulus_overflow() { - let instance = BNParams::get_instance(); - //(2^{254} - 1) + (p - 1) = 2^{254} + p - // after subtracting modulus, result is 2^{254} will does not satisfy the range check applied when constructing a BigNum - let p : U60Repr<3, 2> = U60Repr::from(instance.modulus().limbs); - let two_pow_254_minus_1: U60Repr<3, 2> = U60Repr::from([0xffffffffffffffffffffffffffffff, 0xffffffffffffffffffffffffffffff, 0x3fff]); - let one = U60Repr::from([1, 0, 0]); - let a: Fq = BigNum { limbs: U60Repr::into(p + one) }; - let b: Fq = BigNum { limbs: U60Repr::into(two_pow_254_minus_1) }; - let result = instance.add(a, b); - assert(instance.eq(result, b)); -} - -#[test] -fn test_mul_BN() { - let instance = BNParams::get_instance(); - test_mul(instance); -} - -#[test] -fn test_add_BN2() { - let instance = BNParams::get_instance(); - test_add(instance); -} - #[test] fn test_div_BN() { let instance = BNParams::get_instance(); @@ -533,84 +408,6 @@ fn test_invmod_BN() { test_invmod(instance); } -#[test] -fn test_assert_is_not_equal_BN() { - let instance = BNParams::get_instance(); - assert_is_not_equal(instance); -} - -#[test(should_fail_with = "asssert_is_not_equal fail")] -fn test_assert_is_not_equal_fail_BN() { - let instance = BNParams::get_instance(); - assert_is_not_equal_fail(instance); -} - -#[test(should_fail_with = "asssert_is_not_equal fail")] -fn test_assert_is_not_equal_overloaded_lhs_fail_BN() { - let instance = BNParams::get_instance(); - assert_is_not_equal_overloaded_lhs_fail(instance); -} - -#[test(should_fail_with = "asssert_is_not_equal fail")] -fn test_assert_is_not_equal_overloaded_rhs_fail_BN() { - let instance = BNParams::get_instance(); - assert_is_not_equal_overloaded_rhs_fail(instance); -} - -#[test(should_fail_with = "asssert_is_not_equal fail")] -fn test_assert_is_not_equal_overloaded_fail_BN() { - let instance = BNParams::get_instance(); - assert_is_not_equal_overloaded_fail(instance); -} - -#[test] -fn test_eq_2048() { - let instance = get_2048_BN_instance(); - test_eq(instance); -} - -#[test] -fn test_mul_2048() { - let instance = get_2048_BN_instance(); - test_mul(instance); -} - -#[test] -fn test_add_2048() { - let instance = get_2048_BN_instance(); - test_add(instance); -} - -#[test] -fn test_assert_is_not_equal_2048() { - let instance = get_2048_BN_instance(); - assert_is_not_equal(instance); -} - -#[test(should_fail_with = "asssert_is_not_equal fail")] -fn test_assert_is_not_equal_fail_2048() { - let instance = get_2048_BN_instance(); - assert_is_not_equal_fail(instance); -} - -#[test(should_fail_with = "asssert_is_not_equal fail")] -fn test_assert_is_not_equal_overloaded_lhs_fail_2048() { - let instance = get_2048_BN_instance(); - assert_is_not_equal_overloaded_lhs_fail(instance); -} - -#[test(should_fail_with = "asssert_is_not_equal fail")] -fn test_assert_is_not_equal_overloaded_rhs_fail_2048() { - let instance = get_2048_BN_instance(); - assert_is_not_equal_overloaded_rhs_fail(instance); -} - -#[test(should_fail_with = "asssert_is_not_equal fail")] -fn test_assert_is_not_equal_overloaded_fail_2048() { - let instance = get_2048_BN_instance(); - assert_is_not_equal_overloaded_fail(instance); -} - // N.B. witness generation times make these tests take ~15 minutes each! Uncomment at your peril // #[test] // fn test_div_2048() { @@ -627,7 +424,7 @@ fn test_assert_is_not_equal_overloaded_fail_2048() { #[test] fn test_2048_bit_quadratic_expression() { - let instance = get_2048_BN_instance(); + let instance = Test2048Params::get_instance(); let a: [Field; 18] = [ 0x000000000000000000000000000000000083684820ff40795b8d9f1be2220cba, 0x0000000000000000000000000000000000d4924fbdc522b07b6cd0ef5508fd66,