diff --git a/mvpoly/src/pbt.rs b/mvpoly/src/pbt.rs index 130c6c9fc4..011258f295 100644 --- a/mvpoly/src/pbt.rs +++ b/mvpoly/src/pbt.rs @@ -19,7 +19,7 @@ use crate::MVPoly; use ark_ff::PrimeField; use rand::{seq::SliceRandom, Rng}; -use std::ops::Neg; +use std::{collections::HashMap, ops::Neg}; pub fn test_mul_by_one>() { let mut rng = o1_utils::tests::make_test_rng(None); @@ -587,3 +587,137 @@ pub fn test_is_constant, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + // Generate max 40 polynomials + let n = rng.gen_range(1..10); + let polys = (0..n) + .map(|_| unsafe { T::random(&mut rng, None) }) + .collect::>(); + let eval1: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let eval2: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let u1 = F::rand(&mut rng); + let u2 = F::rand(&mut rng); + + { + let alpha1 = F::rand(&mut rng); + let alpha2 = F::rand(&mut rng); + + let combined_cross_terms = crate::compute_combined_cross_terms( + polys.clone(), + alpha1, + alpha2, + eval1, + eval2, + u1, + u2, + ); + assert_eq!(combined_cross_terms.len(), D); + } + + // Check that even if zero is given, we get the right number of cross terms + { + let alpha1 = F::zero(); + let alpha2 = F::rand(&mut rng); + let combined_cross_terms = crate::compute_combined_cross_terms( + polys.clone(), + alpha1, + alpha2, + eval1, + eval2, + u1, + u2, + ); + assert_eq!(combined_cross_terms.len(), D); + } + + { + let alpha1 = F::rand(&mut rng); + let alpha2 = F::zero(); + let combined_cross_terms = crate::compute_combined_cross_terms( + polys.clone(), + alpha1, + alpha2, + eval1, + eval2, + u1, + u2, + ); + assert_eq!(combined_cross_terms.len(), D); + } + + { + let alpha1 = F::zero(); + let alpha2 = F::zero(); + let combined_cross_terms = + crate::compute_combined_cross_terms(polys, alpha1, alpha2, eval1, eval2, u1, u2); + assert_eq!(combined_cross_terms.len(), D); + } +} + +pub fn test_compute_combined_cross_terms_right_alpha_null< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p = unsafe { T::random(&mut rng, Some(D - 1)) }; + let eval1: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let eval2: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let u1 = F::rand(&mut rng); + let u2 = F::rand(&mut rng); + let cross_terms = p.compute_cross_terms(&eval1, &eval2, u1, u2); + + let alpha2 = F::zero(); + { + let alpha1 = F::one(); + let combined_cross_terms = crate::compute_combined_cross_terms( + vec![p.clone()], + alpha1, + alpha2, + eval1, + eval2, + u1, + u2, + ); + let mut exp_cross_terms = cross_terms.clone(); + exp_cross_terms.insert(D, F::zero()); + assert_eq!(exp_cross_terms, combined_cross_terms); + } + + { + let alpha1 = F::rand(&mut rng); + let combined_cross_terms = crate::compute_combined_cross_terms( + vec![p.clone()], + alpha1, + alpha2, + eval1, + eval2, + u1, + u2, + ); + // alpha is never used as there is only one polynomial + let mut exp_cross_terms: HashMap = cross_terms.clone(); + exp_cross_terms.insert(D, F::zero()); + assert_eq!(exp_cross_terms, combined_cross_terms); + } + + // Special case with alpha1 = 0 + { + let alpha1 = F::zero(); + let combined_cross_terms = + crate::compute_combined_cross_terms(vec![p], alpha1, alpha2, eval1, eval2, u1, u2); + let mut zeroes: HashMap = HashMap::new(); + for i in 1..=D { + zeroes.insert(i, F::zero()); + } + assert_eq!(combined_cross_terms, zeroes); + } +} diff --git a/mvpoly/tests/monomials.rs b/mvpoly/tests/monomials.rs index 902a3fcfbc..ca7d265454 100644 --- a/mvpoly/tests/monomials.rs +++ b/mvpoly/tests/monomials.rs @@ -712,3 +712,33 @@ fn test_from_expr_ec_addition() { assert_eq!(eval, exp_eval); } } + +#[test] +fn test_compute_combined_cross_terms_expected_nb_of_cross_terms() { + mvpoly::pbt::test_compute_combined_cross_terms_expected_nb_of_cross_terms::< + Fp, + 4, + 2, + Sparse, + >(); + mvpoly::pbt::test_compute_combined_cross_terms_expected_nb_of_cross_terms::< + Fp, + 6, + 3, + Sparse, + >(); + mvpoly::pbt::test_compute_combined_cross_terms_expected_nb_of_cross_terms::< + Fp, + 4, + 8, + Sparse, + >(); +} + +#[test] +fn test_compute_combined_cross_terms_right_alpha_null() { + mvpoly::pbt::test_compute_combined_cross_terms_right_alpha_null::>( + ); + mvpoly::pbt::test_compute_combined_cross_terms_right_alpha_null::>(); + mvpoly::pbt::test_compute_combined_cross_terms_right_alpha_null::>(); +}