From 505f22e7288491bbeb12fa0f02c50fe66396cfb9 Mon Sep 17 00:00:00 2001 From: v0-e Date: Mon, 9 Dec 2024 20:40:59 +0000 Subject: [PATCH 1/3] Remove `RefCounter` --- latticefold/src/nifs/folding.rs | 4 +- latticefold/src/nifs/folding/tests/mod.rs | 20 +++---- latticefold/src/nifs/folding/utils.rs | 53 ++++++------------- latticefold/src/nifs/linearization.rs | 8 +-- .../src/nifs/linearization/tests/mod.rs | 4 +- latticefold/src/nifs/linearization/utils.rs | 6 +-- latticefold/src/utils/sumcheck.rs | 16 +++--- latticefold/src/utils/sumcheck/prover.rs | 12 +---- latticefold/src/utils/sumcheck/utils.rs | 12 +++-- 9 files changed, 53 insertions(+), 82 deletions(-) diff --git a/latticefold/src/nifs/folding.rs b/latticefold/src/nifs/folding.rs index cbcd3d89..626c03f8 100644 --- a/latticefold/src/nifs/folding.rs +++ b/latticefold/src/nifs/folding.rs @@ -159,7 +159,7 @@ impl> FoldingProver( log_m, - &f_hat_mles, + f_hat_mles.clone(), &alpha_s, &prechallenged_Ms_1, &prechallenged_Ms_2, @@ -172,7 +172,7 @@ impl> FoldingProver( ccs.s, - &f_hat_mles, + f_hat_mles, &alpha_s, &prechallenged_Ms_1, &prechallenged_Ms_2, @@ -277,7 +277,7 @@ fn test_get_sumcheck_randomness() { // Compute sumcheck proof let (_, prover_state) = - MLSumcheck::prove_as_subprotocol(&mut transcript, &g_mles, ccs.s, g_degree, comb_fn); + MLSumcheck::prove_as_subprotocol(&mut transcript, g_mles, ccs.s, g_degree, comb_fn); // Derive randomness let r_0 = LFFoldingProver::>::get_sumcheck_randomness( prover_state, @@ -317,7 +317,7 @@ fn test_get_thetas() { .unwrap(); let (g_mles, g_degree) = create_sumcheck_polynomial::<_, DP>( ccs.s, - &f_hat_mles, + f_hat_mles.clone(), &alpha_s, &prechallenged_Ms_1, &prechallenged_Ms_2, @@ -331,7 +331,7 @@ fn test_get_thetas() { |vals: &[RqNTT]| -> RqNTT { sumcheck_polynomial_comb_fn::(vals, &mu_s) }; let (_, prover_state) = - MLSumcheck::prove_as_subprotocol(&mut transcript, &g_mles, ccs.s, g_degree, comb_fn); + MLSumcheck::prove_as_subprotocol(&mut transcript, g_mles, ccs.s, g_degree, comb_fn); let r_0 = LFFoldingProver::>::get_sumcheck_randomness( prover_state, ); @@ -389,7 +389,7 @@ fn test_get_etas() { .unwrap(); let (g_mles, g_degree) = create_sumcheck_polynomial::<_, DP>( ccs.s, - &f_hat_mles, + f_hat_mles, &alpha_s, &prechallenged_Ms_1, &prechallenged_Ms_2, @@ -403,7 +403,7 @@ fn test_get_etas() { |vals: &[RqNTT]| -> RqNTT { sumcheck_polynomial_comb_fn::(vals, &mu_s) }; let (_, prover_state) = - MLSumcheck::prove_as_subprotocol(&mut transcript, &g_mles, ccs.s, g_degree, comb_fn); + MLSumcheck::prove_as_subprotocol(&mut transcript, g_mles, ccs.s, g_degree, comb_fn); let r_0 = LFFoldingProver::>::get_sumcheck_randomness( prover_state, ); @@ -489,7 +489,7 @@ fn test_prepare_public_output() { .unwrap(); let (g_mles, g_degree) = create_sumcheck_polynomial::<_, DP>( ccs.s, - &f_hat_mles, + f_hat_mles.clone(), &alpha_s, &prechallenged_Ms_1, &prechallenged_Ms_2, @@ -503,7 +503,7 @@ fn test_prepare_public_output() { |vals: &[RqNTT]| -> RqNTT { sumcheck_polynomial_comb_fn::(vals, &mu_s) }; let (_, prover_state) = - MLSumcheck::prove_as_subprotocol(&mut transcript, &g_mles, ccs.s, g_degree, comb_fn); + MLSumcheck::prove_as_subprotocol(&mut transcript, g_mles, ccs.s, g_degree, comb_fn); let r_0 = LFFoldingProver::>::get_sumcheck_randomness( prover_state, ); @@ -571,7 +571,7 @@ fn test_compute_f_0() { .unwrap(); let (g_mles, g_degree) = create_sumcheck_polynomial::<_, DP>( ccs.s, - &f_hat_mles, + f_hat_mles.clone(), &alpha_s, &prechallenged_Ms_1, &prechallenged_Ms_2, @@ -585,7 +585,7 @@ fn test_compute_f_0() { |vals: &[RqNTT]| -> RqNTT { sumcheck_polynomial_comb_fn::(vals, &mu_s) }; let (_, prover_state) = - MLSumcheck::prove_as_subprotocol(&mut transcript, &g_mles, ccs.s, g_degree, comb_fn); + MLSumcheck::prove_as_subprotocol(&mut transcript, g_mles, ccs.s, g_degree, comb_fn); let r_0 = LFFoldingProver::>::get_sumcheck_randomness( prover_state, ); diff --git a/latticefold/src/nifs/folding/utils.rs b/latticefold/src/nifs/folding/utils.rs index 1e917822..587ff26e 100644 --- a/latticefold/src/nifs/folding/utils.rs +++ b/latticefold/src/nifs/folding/utils.rs @@ -7,7 +7,6 @@ use ark_std::iterable::Iterable; // use ark_std::sync::Arc; use cyclotomic_rings::{rings::SuitableRing, rotation::rot_lin_combination}; -use lattirust_poly::polynomials::RefCounter; use lattirust_ring::{cyclotomic_ring::CRT, Ring}; use crate::ark_base::*; @@ -97,14 +96,14 @@ pub(super) fn get_rhos< #[allow(clippy::too_many_arguments)] pub(super) fn create_sumcheck_polynomial( log_m: usize, - f_hat_mles: &[Vec>], + f_hat_mles: Vec>>, alpha_s: &[NTT], challenged_Ms_1: &DenseMultilinearExtension, challenged_Ms_2: &DenseMultilinearExtension, r_s: &[Vec], beta_s: &[NTT], mu_s: &[NTT], -) -> Result<(Vec>>, usize), FoldingError> { +) -> Result<(Vec>, usize), FoldingError> { if alpha_s.len() != 2 * DP::K || f_hat_mles.len() != 2 * DP::K || r_s.len() != 2 * DP::K @@ -123,21 +122,8 @@ pub(super) fn create_sumcheck_polynomial>>> = f_hat_mles - .iter() - .map(|f_hat_mles_i| { - f_hat_mles_i - .clone() - .into_iter() - .map(RefCounter::from) - .collect::>() - }) - .collect(); - let len = 2 + 2 + // g1 + g3 - 1 + DP::K + DP::K; // g2 + 1 + f_hat_mles.len() * f_hat_mles[0].len(); // g2 let mut mles = Vec::with_capacity(len); // We assume here that decomposition subprotocol puts the same r challenge point @@ -161,15 +147,8 @@ pub(super) fn create_sumcheck_polynomial( } fn prepare_g1_and_3_k_mles_list( - mles: &mut Vec>>, - r_i_eq: RefCounter>, - f_hat_mle_s: &[Vec>>], + mles: &mut Vec>, + r_i_eq: DenseMultilinearExtension, + f_hat_mle_s: &[Vec>], alpha_s: &[NTT], challenged_Ms: &DenseMultilinearExtension, ) { @@ -345,7 +324,7 @@ fn prepare_g1_and_3_k_mles_list( for (fi_hat_mle_s, alpha_i) in f_hat_mle_s.iter().zip(alpha_s.iter()) { let mut mle = DenseMultilinearExtension::zero(); for fi_hat_mle in fi_hat_mle_s.iter().rev() { - mle += fi_hat_mle.as_ref(); + mle += fi_hat_mle; mle *= *alpha_i; } combined_mle += mle; @@ -354,14 +333,16 @@ fn prepare_g1_and_3_k_mles_list( combined_mle += challenged_Ms; mles.push(r_i_eq); - mles.push(RefCounter::from(combined_mle)); + mles.push(combined_mle); } fn prepare_g2_i_mle_list( - mles: &mut Vec>>, - fi_hat_mle_s: &[RefCounter>], + mles: &mut Vec>, + beta_eq_x: DenseMultilinearExtension, + f_hat_mles: Vec>>, ) { - for fi_hat_mle in fi_hat_mle_s.iter() { - mles.push(fi_hat_mle.clone()); - } + mles.push(beta_eq_x); + f_hat_mles + .into_iter() + .for_each(|fhms| fhms.into_iter().for_each(|fhm| mles.push(fhm))) } diff --git a/latticefold/src/nifs/linearization.rs b/latticefold/src/nifs/linearization.rs index b69546fc..bc547d33 100644 --- a/latticefold/src/nifs/linearization.rs +++ b/latticefold/src/nifs/linearization.rs @@ -1,5 +1,5 @@ use cyclotomic_rings::rings::SuitableRing; -use lattirust_poly::{mle::DenseMultilinearExtension, polynomials::RefCounter}; +use lattirust_poly::mle::DenseMultilinearExtension; use lattirust_ring::OverField; use utils::{compute_u, prepare_lin_sumcheck_polynomial, sumcheck_polynomial_comb_fn}; @@ -84,7 +84,7 @@ impl> LFLinearizationProver { ccs: &CCS, ) -> Result< ( - Vec>>, + Vec>, usize, Vec>, ), @@ -106,7 +106,7 @@ impl> LFLinearizationProver { /// Step 2: Run linearization sum-check protocol. fn generate_sumcheck_proof( transcript: &mut impl Transcript, - mles: &[RefCounter>], + mles: Vec>, nvars: usize, degree: usize, comb_fn: impl Fn(&[NTT]) -> NTT + Sync + Send, @@ -161,7 +161,7 @@ impl> LinearizationProver // Run sumcheck protocol. let (sumcheck_proof, point_r) = - Self::generate_sumcheck_proof(transcript, &g_mles, ccs.s, g_degree, comb_fn)?; + Self::generate_sumcheck_proof(transcript, g_mles, ccs.s, g_degree, comb_fn)?; // Step 3: Compute v, u_vector. let (point_r, v, u) = Self::compute_evaluation_vectors(wit, &point_r, &Mz_mles)?; diff --git a/latticefold/src/nifs/linearization/tests/mod.rs b/latticefold/src/nifs/linearization/tests/mod.rs index e4b1fecf..2378d680 100644 --- a/latticefold/src/nifs/linearization/tests/mod.rs +++ b/latticefold/src/nifs/linearization/tests/mod.rs @@ -116,7 +116,7 @@ fn test_generate_sumcheck() { let (_, point_r) = LFLinearizationProver::>::generate_sumcheck_proof( &mut transcript, - &g_mles, + g_mles, ccs.s, g_degree, comb_fn, @@ -148,7 +148,7 @@ fn prepare_test_vectors> let (_, point_r) = LFLinearizationProver::>::generate_sumcheck_proof( &mut transcript, - &g_mles, + g_mles, ccs.s, g_degree, comb_fn, diff --git a/latticefold/src/nifs/linearization/utils.rs b/latticefold/src/nifs/linearization/utils.rs index 60244775..812b651b 100644 --- a/latticefold/src/nifs/linearization/utils.rs +++ b/latticefold/src/nifs/linearization/utils.rs @@ -1,7 +1,7 @@ use crate::{ark_base::Vec, utils::mle_helpers::evaluate_mles}; use ark_ff::PrimeField; -use lattirust_poly::{mle::DenseMultilinearExtension, polynomials::RefCounter}; +use lattirust_poly::mle::DenseMultilinearExtension; use lattirust_ring::OverField; use crate::nifs::{error::LinearizationError, CCS}; @@ -66,7 +66,7 @@ pub fn prepare_lin_sumcheck_polynomial( M_mles: &[DenseMultilinearExtension], S: &[Vec], beta_s: &[NTT], -) -> Result<(Vec>>, usize), LinearizationError> { +) -> Result<(Vec>, usize), LinearizationError> { let len = 1 + c .iter() .enumerate() @@ -78,7 +78,7 @@ pub fn prepare_lin_sumcheck_polynomial( for (i, _) in c.iter().enumerate().filter(|(_, c)| !c.is_zero()) { for &j in &S[i] { - mles.push(RefCounter::new(M_mles[j].clone())); + mles.push(M_mles[j].clone()); } } diff --git a/latticefold/src/utils/sumcheck.rs b/latticefold/src/utils/sumcheck.rs index ecc83c94..cb1ef146 100644 --- a/latticefold/src/utils/sumcheck.rs +++ b/latticefold/src/utils/sumcheck.rs @@ -1,6 +1,6 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::{fmt::Display, marker::PhantomData}; -use lattirust_poly::polynomials::{ArithErrors, DenseMultilinearExtension, RefCounter}; +use lattirust_poly::polynomials::{ArithErrors, DenseMultilinearExtension}; use lattirust_ring::{OverField, Ring}; use thiserror::Error; @@ -53,7 +53,7 @@ impl> MLSumcheck { /// Both of these allow this sumcheck to be better used as a part of a larger protocol. pub fn prove_as_subprotocol( transcript: &mut T, - mles: &[RefCounter>], + mles: Vec>, nvars: usize, degree: usize, comb_fn: impl Fn(&[R]) -> R + Sync + Send, @@ -113,7 +113,7 @@ mod tests { use crate::ark_base::*; use crate::transcript::poseidon::PoseidonTranscript; use crate::utils::sumcheck::utils::{rand_poly, rand_poly_comb_fn}; - use crate::utils::sumcheck::{DenseMultilinearExtension, MLSumcheck, Proof, RefCounter}; + use crate::utils::sumcheck::{DenseMultilinearExtension, MLSumcheck, Proof}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; use ark_std::io::Cursor; use cyclotomic_rings::challenge_set::LatticefoldChallengeSet; @@ -123,11 +123,7 @@ mod tests { fn generate_sumcheck_proof( nvars: usize, mut rng: &mut (impl Rng + Sized), - ) -> ( - (Vec>>, usize), - R, - Proof, - ) + ) -> ((Vec>, usize), R, Proof) where R: SuitableRing, CS: LatticefoldChallengeSet, @@ -141,7 +137,7 @@ mod tests { let (proof, _) = MLSumcheck::prove_as_subprotocol( &mut transcript, - &poly_mles, + poly_mles.clone(), nvars, poly_degree, comb_fn, @@ -208,7 +204,7 @@ mod tests { let (proof, _) = MLSumcheck::prove_as_subprotocol( &mut transcript, - &poly_mles, + poly_mles, nvars, poly_degree, comb_fn, diff --git a/latticefold/src/utils/sumcheck/prover.rs b/latticefold/src/utils/sumcheck/prover.rs index 4789b671..31acdca9 100644 --- a/latticefold/src/utils/sumcheck/prover.rs +++ b/latticefold/src/utils/sumcheck/prover.rs @@ -2,10 +2,7 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::{cfg_into_iter, cfg_iter_mut, vec::Vec}; -use lattirust_poly::{ - mle::MultilinearExtension, - polynomials::{DenseMultilinearExtension, RefCounter}, -}; +use lattirust_poly::{mle::MultilinearExtension, polynomials::DenseMultilinearExtension}; use lattirust_ring::{OverField, Ring}; use super::{verifier::VerifierMsg, IPForMLSumcheck}; @@ -37,7 +34,7 @@ pub struct ProverState { impl IPForMLSumcheck { /// initialize the prover to argue for the sum of polynomial over {0,1}^`num_vars` pub fn prover_init( - mles: &[RefCounter>], + mles: Vec>, nvars: usize, degree: usize, ) -> ProverState { @@ -45,11 +42,6 @@ impl IPForMLSumcheck { panic!("Attempt to prove a constant.") } - // create a deep copy of all unique MLExtensions - let mles = ark_std::cfg_iter!(mles) - .map(|x| x.as_ref().clone()) - .collect(); - ProverState { randomness: Vec::with_capacity(nvars), mles, diff --git a/latticefold/src/utils/sumcheck/utils.rs b/latticefold/src/utils/sumcheck/utils.rs index 6ac7d54d..3bf65d16 100644 --- a/latticefold/src/utils/sumcheck/utils.rs +++ b/latticefold/src/utils/sumcheck/utils.rs @@ -27,7 +27,7 @@ pub fn rand_poly( rng: &mut impl RngCore, ) -> Result< ( - (Vec>>, usize), + (Vec>, usize), Vec<(R, Vec)>, R, ), @@ -42,6 +42,10 @@ pub fn rand_poly( let num_multiplicands = rng.gen_range(num_multiplicands_range.0..num_multiplicands_range.1); degree = num_multiplicands.max(degree); let (product, product_sum) = random_mle_list(nv, num_multiplicands, rng); + let product = product + .into_iter() + .map(|p| RefCounter::into_inner(p).unwrap()) + .collect::>(); let coefficient = R::rand(rng); mles.extend(product); @@ -92,13 +96,11 @@ pub fn eq_eval(x: &[R], y: &[R]) -> Result { /// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i)) /// over r, which is /// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i)) -pub fn build_eq_x_r( - r: &[R], -) -> Result>, ArithErrors> { +pub fn build_eq_x_r(r: &[R]) -> Result, ArithErrors> { let evals = build_eq_x_r_vec(r)?; let mle = DenseMultilinearExtension::from_evaluations_vec(r.len(), evals); - Ok(RefCounter::new(mle)) + Ok(mle) } /// This function build the eq(x, r) polynomial for any given r, and output the /// evaluation of eq(x, r) in its vector form. From 1b98f52bce272fef7012a93a5904b6f504ac9446 Mon Sep 17 00:00:00 2001 From: v0-e Date: Tue, 10 Dec 2024 16:58:15 +0000 Subject: [PATCH 2/3] Use `Vec::append()` in `g2` mle prep Co-authored-by: Ilia Vlasov --- latticefold/src/nifs/folding/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/latticefold/src/nifs/folding/utils.rs b/latticefold/src/nifs/folding/utils.rs index 587ff26e..110e6287 100644 --- a/latticefold/src/nifs/folding/utils.rs +++ b/latticefold/src/nifs/folding/utils.rs @@ -344,5 +344,5 @@ fn prepare_g2_i_mle_list( mles.push(beta_eq_x); f_hat_mles .into_iter() - .for_each(|fhms| fhms.into_iter().for_each(|fhm| mles.push(fhm))) + .for_each(|mut fhms| mles.append(&mut fhms)) } From 7af6fe74ce664a952c857d7e378d659ac5756614 Mon Sep 17 00:00:00 2001 From: v0-e Date: Tue, 10 Dec 2024 17:26:23 +0000 Subject: [PATCH 3/3] Remove mles from tests `generate_sumcheck_proof()` return --- latticefold/src/utils/sumcheck.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/latticefold/src/utils/sumcheck.rs b/latticefold/src/utils/sumcheck.rs index cb1ef146..3d4794a9 100644 --- a/latticefold/src/utils/sumcheck.rs +++ b/latticefold/src/utils/sumcheck.rs @@ -113,7 +113,7 @@ mod tests { use crate::ark_base::*; use crate::transcript::poseidon::PoseidonTranscript; use crate::utils::sumcheck::utils::{rand_poly, rand_poly_comb_fn}; - use crate::utils::sumcheck::{DenseMultilinearExtension, MLSumcheck, Proof}; + use crate::utils::sumcheck::{MLSumcheck, Proof}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; use ark_std::io::Cursor; use cyclotomic_rings::challenge_set::LatticefoldChallengeSet; @@ -123,7 +123,7 @@ mod tests { fn generate_sumcheck_proof( nvars: usize, mut rng: &mut (impl Rng + Sized), - ) -> ((Vec>, usize), R, Proof) + ) -> (usize, R, Proof) where R: SuitableRing, CS: LatticefoldChallengeSet, @@ -137,12 +137,12 @@ mod tests { let (proof, _) = MLSumcheck::prove_as_subprotocol( &mut transcript, - poly_mles.clone(), + poly_mles, nvars, poly_degree, comb_fn, ); - ((poly_mles, poly_degree), sum, proof) + (poly_degree, sum, proof) } fn test_sumcheck() @@ -154,7 +154,7 @@ mod tests { let nvars = 5; for _ in 0..20 { - let ((_, poly_degree), sum, proof) = generate_sumcheck_proof::(nvars, &mut rng); + let (poly_degree, sum, proof) = generate_sumcheck_proof::(nvars, &mut rng); let mut transcript: PoseidonTranscript = PoseidonTranscript::default(); let res =