Skip to content

Commit

Permalink
More ergonomic usage of batch_mul with table (arkworks-rs#757)
Browse files Browse the repository at this point in the history
* More ergonomic usage of `batch_mul` with table

* Tweak
  • Loading branch information
Pratyush authored Jan 16, 2024
1 parent d42b7bc commit 7357e5e
Showing 1 changed file with 64 additions and 18 deletions.
82 changes: 64 additions & 18 deletions ec/src/scalar_mul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,47 @@ pub trait ScalarMul:
/// ```
fn batch_mul(self, v: &[Self::ScalarField]) -> Vec<Self::MulBase> {
let table = BatchMulPreprocessing::new(self, v.len());
self.batch_mul_with_preprocessing(v, &table)
Self::batch_mul_with_preprocessing(&table, v)
}

/// Compute the vector v[0].G, v[1].G, ..., v[n-1].G, given:
/// - an element `g`
/// - a list `v` of n scalars
///
/// This method allows the user to provide a precomputed table of multiples of `g`.
/// A more ergonomic way to call this would be to use [`BatchMulPreprocessing::batch_mul`].
///
/// # Example
/// ```
/// use ark_std::{One, UniformRand};
/// use ark_ec::pairing::Pairing;
/// use ark_test_curves::bls12_381::G1Projective as G;
/// use ark_test_curves::bls12_381::Fr;
/// use ark_ec::scalar_mul::*;
///
/// // Compute G, s.G, s^2.G, ..., s^9.G
/// let mut rng = ark_std::test_rng();
/// let max_degree = 10;
/// let s = Fr::rand(&mut rng);
/// let g = G::rand(&mut rng);
/// let mut powers_of_s = vec![Fr::one()];
/// let mut cur = s;
/// for _ in 0..max_degree {
/// powers_of_s.push(cur);
/// cur *= &s;
/// }
/// let table = BatchMulPreprocessing::new(g, powers_of_s.len());
/// let powers_of_g = G::batch_mul_with_preprocessing(&table, &powers_of_s);
/// let powers_of_g_2 = table.batch_mul(&powers_of_s);
/// let naive_powers_of_g: Vec<G> = powers_of_s.iter().map(|e| g * e).collect();
/// assert_eq!(powers_of_g, naive_powers_of_g);
/// assert_eq!(powers_of_g_2, naive_powers_of_g);
/// ```
fn batch_mul_with_preprocessing(
self,
table: &BatchMulPreprocessing<Self>,
v: &[Self::ScalarField],
preprocessing: &BatchMulPreprocessing<Self>,
) -> Vec<Self::MulBase> {
let result = cfg_iter!(v)
.map(|e| preprocessing.windowed_mul(e))
.collect::<Vec<_>>();
Self::batch_convert_to_mul_base(&result)
table.batch_mul(v)
}
}

Expand All @@ -137,25 +166,22 @@ pub trait ScalarMul:
pub struct BatchMulPreprocessing<T: ScalarMul> {
pub window: usize,
pub max_scalar_size: usize,
pub max_num_scalars: usize,
pub table: Vec<Vec<T::MulBase>>,
}

impl<T: ScalarMul> BatchMulPreprocessing<T> {
pub fn new(base: T, num_scalars: usize) -> Self {
let window = Self::compute_window_size(num_scalars);
let scalar_size = T::ScalarField::MODULUS_BIT_SIZE as usize;
Self::with_window_and_scalar_size(base, window, scalar_size)
}

fn compute_window_size(num_scalars: usize) -> usize {
if num_scalars < 32 {
3
} else {
ln_without_floats(num_scalars)
}
Self::with_num_scalars_and_scalar_size(base, num_scalars, scalar_size)
}

pub fn with_window_and_scalar_size(base: T, window: usize, max_scalar_size: usize) -> Self {
pub fn with_num_scalars_and_scalar_size(
base: T,
max_num_scalars: usize,
max_scalar_size: usize,
) -> Self {
let window = Self::compute_window_size(max_num_scalars);
let in_window = 1 << window;
let outerc = (max_scalar_size + window - 1) / window;
let last_in_window = 1 << (max_scalar_size - (outerc - 1) * window);
Expand Down Expand Up @@ -193,10 +219,30 @@ impl<T: ScalarMul> BatchMulPreprocessing<T> {
Self {
window,
max_scalar_size,
max_num_scalars,
table,
}
}

pub fn compute_window_size(num_scalars: usize) -> usize {
if num_scalars < 32 {
3
} else {
ln_without_floats(num_scalars)
}
}

pub fn batch_mul(&self, v: &[T::ScalarField]) -> Vec<T::MulBase> {
assert!(
v.len() <= self.max_num_scalars,
"number of scalars exceeds the maximum number of scalars supported by this table"
);
let result = cfg_iter!(v)
.map(|e| self.windowed_mul(e))
.collect::<Vec<_>>();
T::batch_convert_to_mul_base(&result)
}

fn windowed_mul(&self, scalar: &T::ScalarField) -> T {
let outerc = (self.max_scalar_size + self.window - 1) / self.window;
let modulus_size = T::ScalarField::MODULUS_BIT_SIZE as usize;
Expand Down

0 comments on commit 7357e5e

Please sign in to comment.