Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Dec 21, 2023
1 parent 1dddc5c commit 36f0d25
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ ref-cast = "1.0.20"

[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
pasta-msm = { git = "https://github.com/lurk-lab/pasta-msm", branch = "dev", version = "0.1.4" }
grumpkin-msm = { git = "https://github.com/lurk-lab/grumpkin-msm", branch = "dev", features = ["force-no-sort"] }
grumpkin-msm = { git = "https://github.com/lurk-lab/grumpkin-msm", branch = "dev", features = ["dont-implement-sort"] }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
halo2curves = { version = "0.5.0", features = ["bits", "derive_serde"] }
Expand Down
64 changes: 64 additions & 0 deletions src/provider/bn256_grumpkin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,26 @@ macro_rules! impl_traits {
self.to_repr().to_vec()
}
}

impl<G: DlogGroup> TranscriptReprTrait<G> for $name::Affine {
fn to_transcript_bytes(&self) -> Vec<u8> {
let (x, y, is_infinity_byte) = {
let coordinates = self.coordinates();
if coordinates.is_some().unwrap_u8() == 1 && ($name_curve_affine::identity() != *self) {
let c = coordinates.unwrap();
(*c.x(), *c.y(), u8::from(false))
} else {
($name::Base::zero(), $name::Base::zero(), u8::from(false))
}
};

x.to_repr()
.into_iter()
.chain(y.to_repr().into_iter())
.chain(std::iter::once(is_infinity_byte))
.collect()
}
}
};
}

Expand All @@ -195,3 +215,47 @@ impl_traits!(
"30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47",
"30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"
);

#[cfg(test)]
mod tests {
use ff::Field;
use rand::thread_rng;

use crate::provider::{
bn256_grumpkin::{bn256, grumpkin},
traits::DlogGroup,
util::msm::cpu_best_msm,
};

#[test]
fn test_bn256_msm_correctness() {
let npoints = 1usize << 16;
let points = bn256::Point::from_label(b"test", npoints);

let mut rng = thread_rng();
let scalars = (0..npoints)
.map(|_| bn256::Scalar::random(&mut rng))
.collect::<Vec<_>>();

let cpu_msm = cpu_best_msm(&scalars, &points);
let gpu_msm = bn256::Point::vartime_multiscalar_mul(&scalars, &points);

assert_eq!(cpu_msm, gpu_msm);
}

#[test]
fn test_grumpkin_msm_correctness() {
let npoints = 1usize << 16;
let points = grumpkin::Point::from_label(b"test", npoints);

let mut rng = thread_rng();
let scalars = (0..npoints)
.map(|_| grumpkin::Scalar::random(&mut rng))
.collect::<Vec<_>>();

let cpu_msm = cpu_best_msm(&scalars, &points);
let gpu_msm = grumpkin::Point::vartime_multiscalar_mul(&scalars, &points);

assert_eq!(cpu_msm, gpu_msm);
}
}
41 changes: 41 additions & 0 deletions src/provider/pasta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,44 @@ impl_traits!(
"40000000000000000000000000000000224698fc094cf91b992d30ed00000001",
"40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001"
);

#[cfg(test)]
mod tests {
use ff::Field;
use pasta_curves::{pallas, vesta};
use rand::thread_rng;

use crate::provider::{traits::DlogGroup, util::msm::cpu_best_msm};

#[test]
fn test_pallas_msm_correctness() {
let npoints = 1usize << 16;
let points = pallas::Point::from_label(b"test", npoints);

let mut rng = thread_rng();
let scalars = (0..npoints)
.map(|_| pallas::Scalar::random(&mut rng))
.collect::<Vec<_>>();

let cpu_msm = cpu_best_msm(&scalars, &points);
let gpu_msm = pallas::Point::vartime_multiscalar_mul(&scalars, &points);

assert_eq!(cpu_msm, gpu_msm);
}

#[test]
fn test_vesta_msm_correctness() {
let npoints = 1usize << 16;
let points = vesta::Point::from_label(b"test", npoints);

let mut rng = thread_rng();
let scalars = (0..npoints)
.map(|_| vesta::Scalar::random(&mut rng))
.collect::<Vec<_>>();

let cpu_msm = cpu_best_msm(&scalars, &points);
let gpu_msm = vesta::Point::vartime_multiscalar_mul(&scalars, &points);

assert_eq!(cpu_msm, gpu_msm);
}
}

0 comments on commit 36f0d25

Please sign in to comment.