Skip to content

Commit

Permalink
ACVP uninlined inner functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jschneider-bensch committed Oct 31, 2024
1 parent 076fef6 commit 30e98da
Showing 1 changed file with 150 additions and 118 deletions.
268 changes: 150 additions & 118 deletions libcrux-ml-dsa/tests/acvp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ struct ResultPromptTestGroup {

#[test]
fn keygen() {
use libcrux_ml_dsa::*;

let prompts: Prompts<KeyGenPromptTestGroup> = read("keygen", "prompt.json");
assert!(prompts.algorithm == "ML-DSA");
assert!(prompts.revision == "FIPS204");
Expand All @@ -83,37 +81,50 @@ fn keygen() {
eprintln!("{parameter_set}");

for test in kat.tests {
eprintln!(" {}", test.tcId);
fn check<const VK_LEN: usize, const SK_LEN: usize>(
keys: MLDSAKeyPair<VK_LEN, SK_LEN>,
result: &KeyGenResult,
) {
assert_eq!(result.pk, keys.verification_key.as_slice());
assert_eq!(result.sk, keys.signing_key.as_slice());
}

let expected_result = results
.testGroups
.iter()
.find(|tg| tg.tgId == kat.tgId)
.unwrap()
.tests
.iter()
.find(|t| t.tcId == test.tcId)
.unwrap();

match parameter_set.as_str() {
"ML-DSA-44" => check(ml_dsa_44::generate_key_pair(test.seed), expected_result),

"ML-DSA-65" => check(ml_dsa_65::generate_key_pair(test.seed), expected_result),

"ML-DSA-87" => check(ml_dsa_87::generate_key_pair(test.seed), expected_result),
_ => unimplemented!(),
}
keygen_inner(test, &results, kat.tgId, &parameter_set);
}
}
}

#[inline(never)]
#[allow(non_snake_case)]
fn keygen_inner(
test: KeyGenPrompt,
results: &Results<ResultPromptTestGroup>,
tgId: usize,
parameter_set: &String,
) {
use libcrux_ml_dsa::*;
eprintln!(" {}", test.tcId);
#[inline(never)]
fn check<const VK_LEN: usize, const SK_LEN: usize>(
keys: MLDSAKeyPair<VK_LEN, SK_LEN>,
result: &KeyGenResult,
) {
assert_eq!(result.pk, keys.verification_key.as_slice());
assert_eq!(result.sk, keys.signing_key.as_slice());
}

let expected_result = results
.testGroups
.iter()
.find(|tg| tg.tgId == tgId)
.unwrap()
.tests
.iter()
.find(|t| t.tcId == test.tcId)
.unwrap();

match parameter_set.as_str() {
"ML-DSA-44" => check(ml_dsa_44::generate_key_pair(test.seed), expected_result),

"ML-DSA-65" => check(ml_dsa_65::generate_key_pair(test.seed), expected_result),

"ML-DSA-87" => check(ml_dsa_87::generate_key_pair(test.seed), expected_result),
_ => unimplemented!(),
}
}

fn read<T: DeserializeOwned>(variant: &str, file: &str) -> T {
let katfile_path = Path::new("tests")
.join("kats")
Expand All @@ -128,8 +139,6 @@ fn read<T: DeserializeOwned>(variant: &str, file: &str) -> T {

#[test]
fn siggen() {
use libcrux_ml_dsa::*;

let prompts: Prompts<SigGenPromptTestGroup> = read("siggen", "prompt.json");
assert!(prompts.algorithm == "ML-DSA");
assert!(prompts.revision == "FIPS204");
Expand All @@ -148,59 +157,69 @@ fn siggen() {
eprintln!("{parameter_set}");

for test in kat.tests {
eprintln!(" {}", test.tcId);
let expected_result = results
.testGroups
.iter()
.find(|tg| tg.tgId == kat.tgId)
.unwrap()
.tests
.iter()
.find(|t| t.tcId == test.tcId)
.unwrap();

let Randomness(rnd) = test.rnd.unwrap_or(Randomness([0u8; 32]));

match parameter_set.as_str() {
"ML-DSA-44" => {
let signature = ml_dsa_44::sign_internal(
&MLDSASigningKey(test.sk.try_into().unwrap()),
&test.message,
rnd,
)
.unwrap();
assert_eq!(signature.as_slice(), expected_result.signature);
}

"ML-DSA-65" => {
let signature = ml_dsa_65::sign_internal(
&MLDSASigningKey(test.sk.try_into().unwrap()),
&test.message,
rnd,
)
.unwrap();
assert_eq!(signature.as_slice(), expected_result.signature);
}

"ML-DSA-87" => {
let signature = ml_dsa_87::sign_internal(
&MLDSASigningKey(test.sk.try_into().unwrap()),
&test.message,
rnd,
)
.unwrap();
assert_eq!(signature.as_slice(), expected_result.signature);
}
_ => unimplemented!(),
}
siggen_inner(test, &results, kat.tgId, &parameter_set);
}
}
}

#[test]
fn sigver() {
#[inline(never)]
#[allow(non_snake_case)]
fn siggen_inner(
test: SigGenTest,
results: &Results<ResultSigGenTestGroup>,
tgId: usize,
parameter_set: &String,
) {
use libcrux_ml_dsa::*;
eprintln!(" {}", test.tcId);
let expected_result = results
.testGroups
.iter()
.find(|tg| tg.tgId == tgId)
.unwrap()
.tests
.iter()
.find(|t| t.tcId == test.tcId)
.unwrap();

let Randomness(rnd) = test.rnd.unwrap_or(Randomness([0u8; 32]));

match parameter_set.as_str() {
"ML-DSA-44" => {
let signature = ml_dsa_44::sign_internal(
&MLDSASigningKey(test.sk.try_into().unwrap()),
&test.message,
rnd,
)
.unwrap();
assert_eq!(signature.as_slice(), expected_result.signature);
}

"ML-DSA-65" => {
let signature = ml_dsa_65::sign_internal(
&MLDSASigningKey(test.sk.try_into().unwrap()),
&test.message,
rnd,
)
.unwrap();
assert_eq!(signature.as_slice(), expected_result.signature);
}

"ML-DSA-87" => {
let signature = ml_dsa_87::sign_internal(
&MLDSASigningKey(test.sk.try_into().unwrap()),
&test.message,
rnd,
)
.unwrap();
assert_eq!(signature.as_slice(), expected_result.signature);
}
_ => unimplemented!(),
}
}

#[test]
fn sigver() {
let prompts: Prompts<SigVerPromptTestGroup> = read("sigver", "prompt.json");
assert!(prompts.algorithm == "ML-DSA");
assert!(prompts.revision == "FIPS204");
Expand All @@ -219,47 +238,60 @@ fn sigver() {
eprintln!("{parameter_set}");

for test in kat.tests {
eprintln!(" {}", test.tcId);
let expected_result = results
.testGroups
.iter()
.find(|tg| tg.tgId == kat.tgId)
.unwrap()
.tests
.iter()
.find(|t| t.tcId == test.tcId)
.unwrap();

match parameter_set.as_str() {
"ML-DSA-44" => {
let valid = ml_dsa_44::verify_internal(
&MLDSAVerificationKey(kat.pk.clone().try_into().unwrap()),
&test.message,
&MLDSASignature(test.signature.try_into().unwrap()),
);
assert_eq!(valid.is_ok(), expected_result.testPassed);
}

"ML-DSA-65" => {
let valid = ml_dsa_65::verify_internal(
&MLDSAVerificationKey(kat.pk.clone().try_into().unwrap()),
&test.message,
&MLDSASignature(test.signature.try_into().unwrap()),
);
assert_eq!(valid.is_ok(), expected_result.testPassed);
}

"ML-DSA-87" => {
let valid = ml_dsa_87::verify_internal(
&MLDSAVerificationKey(kat.pk.clone().try_into().unwrap()),
&test.message,
&MLDSASignature(test.signature.try_into().unwrap()),
);
assert_eq!(valid.is_ok(), expected_result.testPassed);
}
_ => unimplemented!(),
}
sigver_inner(test, &results, kat.tgId, &kat.pk, &parameter_set);
}
}
}

#[inline(never)]
#[allow(non_snake_case)]
fn sigver_inner(
test: SigVerTest,
results: &Results<ResultSigVerTestGroup>,
tgId: usize,
pk: &[u8],
parameter_set: &String,
) {
use libcrux_ml_dsa::*;
eprintln!(" {}", test.tcId);
let expected_result = results
.testGroups
.iter()
.find(|tg| tg.tgId == tgId)
.unwrap()
.tests
.iter()
.find(|t| t.tcId == test.tcId)
.unwrap();

match parameter_set.as_str() {
"ML-DSA-44" => {
let valid = ml_dsa_44::verify_internal(
&MLDSAVerificationKey(pk.to_owned().try_into().unwrap()),
&test.message,
&MLDSASignature(test.signature.try_into().unwrap()),
);
assert_eq!(valid.is_ok(), expected_result.testPassed);
}

"ML-DSA-65" => {
let valid = ml_dsa_65::verify_internal(
&MLDSAVerificationKey(pk.to_owned().try_into().unwrap()),
&test.message,
&MLDSASignature(test.signature.try_into().unwrap()),
);
assert_eq!(valid.is_ok(), expected_result.testPassed);
}

"ML-DSA-87" => {
let valid = ml_dsa_87::verify_internal(
&MLDSAVerificationKey(pk.to_owned().try_into().unwrap()),
&test.message,
&MLDSASignature(test.signature.try_into().unwrap()),
);
assert_eq!(valid.is_ok(), expected_result.testPassed);
}
_ => unimplemented!(),
}
}

Expand Down

0 comments on commit 30e98da

Please sign in to comment.