Skip to content

Commit

Permalink
Merge pull request #4 from Justuxs/zkcrypto-msm
Browse files Browse the repository at this point in the history
Zkcrypto msm
  • Loading branch information
Justuxs authored Nov 23, 2023
2 parents 79cb783 + ae9fcb2 commit dcf675e
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 12 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions zkcrypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rand = { version = "0.8.5", optional = true }
libc = { version = "0.2.148", default-features = false }
rayon = { version = "1.8.0", optional = true }
subtle = "2.5.0"
byteorder = "1.5.0"

[dev-dependencies]
criterion = "0.5.1"
Expand Down
20 changes: 9 additions & 11 deletions zkcrypto/src/fft_g1.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
use crate::consts::G1_GENERATOR;
use crate::kzg_proofs::FFTSettings;
use crate::kzg_types::{ZFr as BlstFr, ZG1};

use crate::kzg_types::{ZFr, ZG1};
use crate::multiscalar_mul::msm_variable_base;
use kzg::{Fr as KzgFr, G1Mul};
use kzg::{FFTG1, G1};
use std::ops::MulAssign;

pub fn g1_linear_combination(out: &mut ZG1, points: &[ZG1], scalars: &[BlstFr], len: usize) {
*out = ZG1::default();
for i in 0..len {
let tmp = points[i].mul(&scalars[i]);
*out = out.add_or_dbl(&tmp);
}
#[warn(unused_variables)]
pub fn g1_linear_combination(out: &mut ZG1, points: &[ZG1], scalars: &[ZFr], _len: usize) {
let g1 = msm_variable_base(points, scalars);
out.proj = g1
}
pub fn make_data(data: usize) -> Vec<ZG1> {
let mut vec = Vec::new();
Expand Down Expand Up @@ -46,7 +44,7 @@ impl FFTG1<ZG1> for FFTSettings {
fft_g1_fast(&mut ret, data, 1, roots, stride, 1);

if inverse {
let inv_fr_len = BlstFr::from_u64(data.len() as u64).inverse();
let inv_fr_len = ZFr::from_u64(data.len() as u64).inverse();
ret[..data.len()]
.iter_mut()
.for_each(|f| f.proj.mul_assign(&inv_fr_len.fr));
Expand All @@ -59,7 +57,7 @@ pub fn fft_g1_slow(
ret: &mut [ZG1],
data: &[ZG1],
stride: usize,
roots: &[BlstFr],
roots: &[ZFr],
roots_stride: usize,
_width: usize,
) {
Expand All @@ -78,7 +76,7 @@ pub fn fft_g1_fast(
ret: &mut [ZG1],
data: &[ZG1],
stride: usize,
roots: &[BlstFr],
roots: &[ZFr],
roots_stride: usize,
_width: usize,
) {
Expand Down
17 changes: 17 additions & 0 deletions zkcrypto/src/kzg_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ impl ZFr {
pub fn to_blst_fr(&self) -> blst_fr {
pc_fr_into_blst_fr(self.fr)
}

pub fn converter(points: &[ZFr]) -> Vec<Scalar> {
let mut result = Vec::new();

for zg1 in points {
result.push(zg1.fr);
}
result
}
}

impl KzgFr for ZFr {
Expand Down Expand Up @@ -300,6 +309,14 @@ impl ZG1 {
proj: G1Projective::from(&p),
}
}
pub fn converter(points: &[ZG1]) -> Vec<G1Projective> {
let mut result = Vec::new();

for zg1 in points {
result.push(zg1.proj);
}
result
}
}

impl From<blst_p1> for ZG1 {
Expand Down
2 changes: 1 addition & 1 deletion zkcrypto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ pub mod fft_g1;
pub mod fk20_proofs;
pub mod kzg_proofs;
pub mod kzg_types;
mod multiscalar_mul;
pub mod poly;
pub mod recover;
pub mod utils;
pub mod zero_poly;

trait Eq<T> {
fn equals(&self, other: &T) -> bool;
}
Expand Down
194 changes: 194 additions & 0 deletions zkcrypto/src/multiscalar_mul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
//! Multiscalar multiplication implementation using pippenger algorithm.
// use dusk_bytes::Serializable;

// use alloc::vec::*;

use crate::kzg_types::{ZFr, ZG1};
use bls12_381::{G1Projective, Scalar};

pub fn divn(mut scalar: Scalar, mut n: u32) -> Scalar {
if n >= 256 {
return Scalar::from(0);
}

while n >= 64 {
let mut t = 0;
for i in scalar.0.iter_mut().rev() {
core::mem::swap(&mut t, i);
}
n -= 64;
}

if n > 0 {
let mut t = 0;
for i in scalar.0.iter_mut().rev() {
let t2 = *i << (64 - n);
*i >>= n;
*i |= t;
t = t2;
}
}

scalar
}

/// Performs a Variable Base Multiscalar Multiplication.
#[allow(clippy::needless_collect)]
pub fn msm_variable_base(points_zg1: &[ZG1], zfrscalars: &[ZFr]) -> G1Projective {
let g1_projective_vec = ZG1::converter(points_zg1);
let points = g1_projective_vec.as_slice();

let scalars_vec = ZFr::converter(zfrscalars);
let scalars = scalars_vec.as_slice();

#[cfg(feature = "parallel")]
use rayon::prelude::*;

let c = if scalars.len() < 32 {
3
} else {
ln_without_floats(scalars.len()) + 2
};

let num_bits = 255usize;
let fr_one = Scalar::one();

let zero = G1Projective::identity();
let window_starts: Vec<_> = (0..num_bits).step_by(c).collect();

#[cfg(feature = "parallel")]
let window_starts_iter = window_starts.into_par_iter();
#[cfg(not(feature = "parallel"))]
let window_starts_iter = window_starts.into_iter();

// Each window is of size `c`.
// We divide up the bits 0..num_bits into windows of size `c`, and
// in parallel process each such window.
let window_sums: Vec<_> = window_starts_iter
.map(|w_start| {
let mut res = zero;
// We don't need the "zero" bucket, so we only have 2^c - 1 buckets
let mut buckets = vec![zero; (1 << c) - 1];
scalars
.iter()
.zip(points)
.filter(|(s, _)| !(*s == &Scalar::zero()))
.for_each(|(&scalar, base)| {
if scalar == fr_one {
// We only process unit scalars once in the first window.
if w_start == 0 {
res = res.add(base);
}
} else {
let mut scalar = Scalar::montgomery_reduce(
scalar.0[0],
scalar.0[1],
scalar.0[2],
scalar.0[3],
0,
0,
0,
0,
);

// We right-shift by w_start, thus getting rid of the
// lower bits.
scalar = divn(scalar, w_start as u32);
// We mod the remaining bits by the window size.
let scalar = scalar.0[0] % (1 << c);

// If the scalar is non-zero, we update the corresponding
// bucket.
// (Recall that `buckets` doesn't have a zero bucket.)
if scalar != 0 {
buckets[(scalar - 1) as usize] =
buckets[(scalar - 1) as usize].add(base);
}
}
});

let mut running_sum = G1Projective::identity();
for b in buckets.into_iter().rev() {
running_sum += b;
res += &running_sum;
}

res
})
.collect();

// We store the sum for the lowest window.
let lowest = *window_sums.first().unwrap();
// We're traversing windows from high to low.
window_sums[1..]
.iter()
.rev()
.fold(zero, |mut total, sum_i| {
total += sum_i;
for _ in 0..c {
total = total.double();
}
total
})
+ lowest
}

fn ln_without_floats(a: usize) -> usize {
// log2(a) * ln(2)
(log2(a) * 69 / 100) as usize
}
fn log2(x: usize) -> u32 {
if x <= 1 {
return 0;
}

let n = x.leading_zeros();
core::mem::size_of::<usize>() as u32 * 8 - n
}

/*
mod tests {
#[allow(unused_imports)]
use super::*;
#[test]
fn pippenger_test() {
// Reuse points across different tests
let mut n = 512;
let x = Scalar::from(2128506u64).invert().unwrap();
let y = Scalar::from(4443282u64).invert().unwrap();
let points = (0..n)
.map(|i| G1Projective::generator() * Scalar::from(1 + i as u64))
.collect::<Vec<_>>();
let scalars = (0..n)
.map(|i| x + (Scalar::from(i as u64) * y))
.collect::<Vec<_>>(); // fast way to make ~random but deterministic scalars
let premultiplied: Vec<G1Projective> = scalars
.iter()
.zip(points.iter())
.map(|(sc, pt)| pt * sc)
.collect();
while n > 0 {
let scalars = &scalars[0..n];
let points = &points[0..n];
let control: G1Projective = premultiplied[0..n].iter().sum();
let subject = pippenger(
points.to_owned().into_iter(),
scalars.to_owned().into_iter(),
);
assert_eq!(subject, control);
n = n / 2;
}
}
#[test]
fn msm_variable_base_test() {
let points = vec![G1Affine::generator()];
let scalars = vec![Scalar::from(100u64)];
let premultiplied = G1Projective::generator() * Scalar::from(100u64);
let subject = msm_variable_base(&points, &scalars);
assert_eq!(subject, premultiplied);
}
}
*/

0 comments on commit dcf675e

Please sign in to comment.