From 7357e5e5bddc893da7bee2561be04b31107dde68 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Tue, 16 Jan 2024 03:18:01 -0500 Subject: [PATCH] More ergonomic usage of `batch_mul` with table (#757) * More ergonomic usage of `batch_mul` with table * Tweak --- ec/src/scalar_mul/mod.rs | 82 +++++++++++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 18 deletions(-) diff --git a/ec/src/scalar_mul/mod.rs b/ec/src/scalar_mul/mod.rs index bcf17843b..d67d769db 100644 --- a/ec/src/scalar_mul/mod.rs +++ b/ec/src/scalar_mul/mod.rs @@ -115,18 +115,47 @@ pub trait ScalarMul: /// ``` fn batch_mul(self, v: &[Self::ScalarField]) -> Vec { let table = BatchMulPreprocessing::new(self, v.len()); - self.batch_mul_with_preprocessing(v, &table) + Self::batch_mul_with_preprocessing(&table, v) } + /// Compute the vector v[0].G, v[1].G, ..., v[n-1].G, given: + /// - an element `g` + /// - a list `v` of n scalars + /// + /// This method allows the user to provide a precomputed table of multiples of `g`. + /// A more ergonomic way to call this would be to use [`BatchMulPreprocessing::batch_mul`]. + /// + /// # Example + /// ``` + /// use ark_std::{One, UniformRand}; + /// use ark_ec::pairing::Pairing; + /// use ark_test_curves::bls12_381::G1Projective as G; + /// use ark_test_curves::bls12_381::Fr; + /// use ark_ec::scalar_mul::*; + /// + /// // Compute G, s.G, s^2.G, ..., s^9.G + /// let mut rng = ark_std::test_rng(); + /// let max_degree = 10; + /// let s = Fr::rand(&mut rng); + /// let g = G::rand(&mut rng); + /// let mut powers_of_s = vec![Fr::one()]; + /// let mut cur = s; + /// for _ in 0..max_degree { + /// powers_of_s.push(cur); + /// cur *= &s; + /// } + /// let table = BatchMulPreprocessing::new(g, powers_of_s.len()); + /// let powers_of_g = G::batch_mul_with_preprocessing(&table, &powers_of_s); + /// let powers_of_g_2 = table.batch_mul(&powers_of_s); + /// let naive_powers_of_g: Vec = powers_of_s.iter().map(|e| g * e).collect(); + /// assert_eq!(powers_of_g, naive_powers_of_g); + /// assert_eq!(powers_of_g_2, naive_powers_of_g); + /// ``` fn batch_mul_with_preprocessing( - self, + table: &BatchMulPreprocessing, v: &[Self::ScalarField], - preprocessing: &BatchMulPreprocessing, ) -> Vec { - let result = cfg_iter!(v) - .map(|e| preprocessing.windowed_mul(e)) - .collect::>(); - Self::batch_convert_to_mul_base(&result) + table.batch_mul(v) } } @@ -137,25 +166,22 @@ pub trait ScalarMul: pub struct BatchMulPreprocessing { pub window: usize, pub max_scalar_size: usize, + pub max_num_scalars: usize, pub table: Vec>, } impl BatchMulPreprocessing { pub fn new(base: T, num_scalars: usize) -> Self { - let window = Self::compute_window_size(num_scalars); let scalar_size = T::ScalarField::MODULUS_BIT_SIZE as usize; - Self::with_window_and_scalar_size(base, window, scalar_size) - } - - fn compute_window_size(num_scalars: usize) -> usize { - if num_scalars < 32 { - 3 - } else { - ln_without_floats(num_scalars) - } + Self::with_num_scalars_and_scalar_size(base, num_scalars, scalar_size) } - pub fn with_window_and_scalar_size(base: T, window: usize, max_scalar_size: usize) -> Self { + pub fn with_num_scalars_and_scalar_size( + base: T, + max_num_scalars: usize, + max_scalar_size: usize, + ) -> Self { + let window = Self::compute_window_size(max_num_scalars); let in_window = 1 << window; let outerc = (max_scalar_size + window - 1) / window; let last_in_window = 1 << (max_scalar_size - (outerc - 1) * window); @@ -193,10 +219,30 @@ impl BatchMulPreprocessing { Self { window, max_scalar_size, + max_num_scalars, table, } } + pub fn compute_window_size(num_scalars: usize) -> usize { + if num_scalars < 32 { + 3 + } else { + ln_without_floats(num_scalars) + } + } + + pub fn batch_mul(&self, v: &[T::ScalarField]) -> Vec { + assert!( + v.len() <= self.max_num_scalars, + "number of scalars exceeds the maximum number of scalars supported by this table" + ); + let result = cfg_iter!(v) + .map(|e| self.windowed_mul(e)) + .collect::>(); + T::batch_convert_to_mul_base(&result) + } + fn windowed_mul(&self, scalar: &T::ScalarField) -> T { let outerc = (self.max_scalar_size + self.window - 1) / self.window; let modulus_size = T::ScalarField::MODULUS_BIT_SIZE as usize;