Skip to content

Commit

Permalink
MVPoly/Cross-terms: add PBT
Browse files Browse the repository at this point in the history
  • Loading branch information
dannywillems committed Oct 16, 2024
1 parent 1b900cd commit 9b7b501
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 1 deletion.
136 changes: 135 additions & 1 deletion mvpoly/src/pbt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use crate::MVPoly;
use ark_ff::PrimeField;
use rand::{seq::SliceRandom, Rng};
use std::ops::Neg;
use std::{collections::HashMap, ops::Neg};

pub fn test_mul_by_one<F: PrimeField, const N: usize, const D: usize, T: MVPoly<F, N, D>>() {
let mut rng = o1_utils::tests::make_test_rng(None);
Expand Down Expand Up @@ -587,3 +587,137 @@ pub fn test_is_constant<F: PrimeField, const N: usize, const D: usize, T: MVPoly
let p = unsafe { T::random(&mut rng, None) };
assert!(!p.is_constant());
}

pub fn test_compute_combined_cross_terms_expected_nb_of_cross_terms<
F: PrimeField,
const N: usize,
const D: usize,
T: MVPoly<F, N, D>,
>() {
let mut rng = o1_utils::tests::make_test_rng(None);
// Generate max 40 polynomials
let n = rng.gen_range(1..10);
let polys = (0..n)
.map(|_| unsafe { T::random(&mut rng, None) })
.collect::<Vec<_>>();
let eval1: [F; N] = std::array::from_fn(|_| F::rand(&mut rng));
let eval2: [F; N] = std::array::from_fn(|_| F::rand(&mut rng));
let u1 = F::rand(&mut rng);
let u2 = F::rand(&mut rng);

{
let alpha1 = F::rand(&mut rng);
let alpha2 = F::rand(&mut rng);

let combined_cross_terms = crate::compute_combined_cross_terms(
polys.clone(),
alpha1,
alpha2,
eval1,
eval2,
u1,
u2,
);
assert_eq!(combined_cross_terms.len(), D);
}

// Check that even if zero is given, we get the right number of cross terms
{
let alpha1 = F::zero();
let alpha2 = F::rand(&mut rng);
let combined_cross_terms = crate::compute_combined_cross_terms(
polys.clone(),
alpha1,
alpha2,
eval1,
eval2,
u1,
u2,
);
assert_eq!(combined_cross_terms.len(), D);
}

{
let alpha1 = F::rand(&mut rng);
let alpha2 = F::zero();
let combined_cross_terms = crate::compute_combined_cross_terms(
polys.clone(),
alpha1,
alpha2,
eval1,
eval2,
u1,
u2,
);
assert_eq!(combined_cross_terms.len(), D);
}

{
let alpha1 = F::zero();
let alpha2 = F::zero();
let combined_cross_terms =
crate::compute_combined_cross_terms(polys, alpha1, alpha2, eval1, eval2, u1, u2);
assert_eq!(combined_cross_terms.len(), D);
}
}

pub fn test_compute_combined_cross_terms_right_alpha_null<
F: PrimeField,
const N: usize,
const D: usize,
T: MVPoly<F, N, D>,
>() {
let mut rng = o1_utils::tests::make_test_rng(None);
let p = unsafe { T::random(&mut rng, Some(D - 1)) };
let eval1: [F; N] = std::array::from_fn(|_| F::rand(&mut rng));
let eval2: [F; N] = std::array::from_fn(|_| F::rand(&mut rng));
let u1 = F::rand(&mut rng);
let u2 = F::rand(&mut rng);
let cross_terms = p.compute_cross_terms(&eval1, &eval2, u1, u2);

let alpha2 = F::zero();
{
let alpha1 = F::one();
let combined_cross_terms = crate::compute_combined_cross_terms(
vec![p.clone()],
alpha1,
alpha2,
eval1,
eval2,
u1,
u2,
);
let mut exp_cross_terms = cross_terms.clone();
exp_cross_terms.insert(D, F::zero());
assert_eq!(exp_cross_terms, combined_cross_terms);
}

{
let alpha1 = F::rand(&mut rng);
let combined_cross_terms = crate::compute_combined_cross_terms(
vec![p.clone()],
alpha1,
alpha2,
eval1,
eval2,
u1,
u2,
);
// alpha is never used as there is only one polynomial
let mut exp_cross_terms: HashMap<usize, F> = cross_terms.clone();
exp_cross_terms.insert(D, F::zero());
assert_eq!(exp_cross_terms, combined_cross_terms);
}

// Special case with alpha1 = 0
{
let alpha1 = F::zero();
let combined_cross_terms =
crate::compute_combined_cross_terms(vec![p], alpha1, alpha2, eval1, eval2, u1, u2);
let mut zeroes: HashMap<usize, F> = HashMap::new();
for i in 1..=D {
zeroes.insert(i, F::zero());
}
assert_eq!(combined_cross_terms, zeroes);
}
}
30 changes: 30 additions & 0 deletions mvpoly/tests/monomials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,3 +712,33 @@ fn test_from_expr_ec_addition() {
assert_eq!(eval, exp_eval);
}
}

#[test]
fn test_compute_combined_cross_terms_expected_nb_of_cross_terms() {
mvpoly::pbt::test_compute_combined_cross_terms_expected_nb_of_cross_terms::<
Fp,
4,
2,
Sparse<Fp, 4, 2>,
>();
mvpoly::pbt::test_compute_combined_cross_terms_expected_nb_of_cross_terms::<
Fp,
6,
3,
Sparse<Fp, 6, 3>,
>();
mvpoly::pbt::test_compute_combined_cross_terms_expected_nb_of_cross_terms::<
Fp,
4,
8,
Sparse<Fp, 4, 8>,
>();
}

#[test]
fn test_compute_combined_cross_terms_right_alpha_null() {
mvpoly::pbt::test_compute_combined_cross_terms_right_alpha_null::<Fp, 15, 2, Sparse<Fp, 15, 2>>(
);
mvpoly::pbt::test_compute_combined_cross_terms_right_alpha_null::<Fp, 6, 3, Sparse<Fp, 6, 3>>();
mvpoly::pbt::test_compute_combined_cross_terms_right_alpha_null::<Fp, 4, 8, Sparse<Fp, 4, 8>>();
}

0 comments on commit 9b7b501

Please sign in to comment.