Skip to content

Commit

Permalink
mlkem incr. shared secret output
Browse files Browse the repository at this point in the history
  • Loading branch information
franziskuskiefer committed Jan 30, 2025
1 parent 6209660 commit 52d5492
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 61 deletions.
18 changes: 14 additions & 4 deletions libcrux-ml-kem/src/ind_cca/incremental.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ pub(crate) fn encapsulate1<
>(
pk1: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
) -> (Ciphertext1<C1_SIZE>, EncapsState<K, Vector>) {
) -> (Ciphertext1<C1_SIZE>, EncapsState<K, Vector>, [u8; 32]) {
let hashed = encaps_prepare::<K, Hasher>(&randomness, &pk1.hash);
let (shared_secret, pseudorandomness) = hashed.split_at(SHARED_SECRET_SIZE);

Expand All @@ -207,11 +207,14 @@ pub(crate) fn encapsulate1<

let state = EncapsState {
randomness,
shared_secret: shared_secret.try_into().unwrap(),
r_as_ntt,
error2,
};
(Ciphertext1 { value: ciphertext }, state)
(
Ciphertext1 { value: ciphertext },
state,
shared_secret.try_into().unwrap(),
)
}

pub(crate) fn encapsulate1_serialized<
Expand All @@ -230,8 +233,14 @@ pub(crate) fn encapsulate1_serialized<
pk1: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
state: &mut [u8],
shared_secret: &mut [u8],
) -> Result<Ciphertext1<C1_SIZE>, Error> {
let (ct1, encaps_state) = encapsulate1::<
debug_assert!(shared_secret.len() >= SHARED_SECRET_SIZE);
if shared_secret.len() < SHARED_SECRET_SIZE {
return Err(Error::InvalidOutputLength);
}

let (ct1, encaps_state, ss) = encapsulate1::<
K,
CIPHERTEXT_SIZE,
C1_SIZE,
Expand All @@ -247,6 +256,7 @@ pub(crate) fn encapsulate1_serialized<

// Write out the state
encaps_state.to_bytes(state)?;
shared_secret[..SHARED_SECRET_SIZE].copy_from_slice(&ss);

// Return the ciphertext
Ok(ct1)
Expand Down
25 changes: 15 additions & 10 deletions libcrux-ml-kem/src/ind_cca/incremental/multiplexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,13 @@ pub(crate) mod alloc {
>(
public_key_part: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
) -> (Ciphertext1<C1_SIZE>, Box<dyn State>) {
) -> (
Ciphertext1<C1_SIZE>,
Box<dyn State>,
[u8; SHARED_SECRET_SIZE],
) {
if libcrux_platform::simd256_support() {
let (c, s) = encapsulate1_avx2::<
let (c, s, ss) = encapsulate1_avx2::<
K,
CIPHERTEXT_SIZE,
C1_SIZE,
Expand All @@ -124,9 +128,9 @@ pub(crate) mod alloc {
ETA2,
ETA2_RANDOMNESS_SIZE,
>(public_key_part, randomness);
(c, Box::new(s))
(c, Box::new(s), ss)
} else if libcrux_platform::simd128_support() {
let (c, s) = encapsulate1_neon::<
let (c, s, ss) = encapsulate1_neon::<
K,
CIPHERTEXT_SIZE,
C1_SIZE,
Expand All @@ -137,9 +141,9 @@ pub(crate) mod alloc {
ETA2,
ETA2_RANDOMNESS_SIZE,
>(public_key_part, randomness);
(c, Box::new(s))
(c, Box::new(s), ss)
} else {
let (c, s) = portable::encapsulate1::<
let (c, s, ss) = portable::encapsulate1::<
K,
CIPHERTEXT_SIZE,
C1_SIZE,
Expand All @@ -150,7 +154,7 @@ pub(crate) mod alloc {
ETA2,
ETA2_RANDOMNESS_SIZE,
>(public_key_part, randomness);
(c, Box::new(s))
(c, Box::new(s), ss)
}
}

Expand Down Expand Up @@ -355,6 +359,7 @@ pub(crate) fn encapsulate1<
public_key_part: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
state: &mut [u8],
shared_secret: &mut [u8],
) -> Result<Ciphertext1<C1_SIZE>, Error> {
if libcrux_platform::simd256_support() {
encapsulate1_serialized_avx2::<
Expand All @@ -367,7 +372,7 @@ pub(crate) fn encapsulate1<
ETA1_RANDOMNESS_SIZE,
ETA2,
ETA2_RANDOMNESS_SIZE,
>(public_key_part, randomness, state)
>(public_key_part, randomness, state, shared_secret)
} else if libcrux_platform::simd128_support() {
encapsulate1_serialized_neon::<
K,
Expand All @@ -379,7 +384,7 @@ pub(crate) fn encapsulate1<
ETA1_RANDOMNESS_SIZE,
ETA2,
ETA2_RANDOMNESS_SIZE,
>(public_key_part, randomness, state)
>(public_key_part, randomness, state, shared_secret)
} else {
portable::encapsulate1_serialized::<
K,
Expand All @@ -391,7 +396,7 @@ pub(crate) fn encapsulate1<
ETA1_RANDOMNESS_SIZE,
ETA2,
ETA2_RANDOMNESS_SIZE,
>(public_key_part, randomness, state)
>(public_key_part, randomness, state, shared_secret)
}
}

Expand Down
37 changes: 11 additions & 26 deletions libcrux-ml-kem/src/ind_cca/incremental/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ impl<const LEN: usize> Ciphertext2<LEN> {

/// The incremental state for encapsulate.
pub struct EncapsState<const K: usize, Vector: Operations> {
pub(super) shared_secret: MlKemSharedSecret,
pub(super) r_as_ntt: [PolynomialRingElement<Vector>; K],
pub(super) error2: PolynomialRingElement<Vector>,
pub(super) randomness: [u8; 32],
Expand All @@ -165,29 +164,25 @@ pub struct EncapsState<const K: usize, Vector: Operations> {
impl<const K: usize, Vector: Operations> EncapsState<K, Vector> {
/// Get the number of bytes, required for the state.
pub const fn num_bytes() -> usize {
SHARED_SECRET_SIZE
+ vec_len_bytes::<K, Vector>()
+ PolynomialRingElement::<Vector>::num_bytes()
+ 32
vec_len_bytes::<K, Vector>() + PolynomialRingElement::<Vector>::num_bytes() + 32
}

#[allow(dead_code)]
/// Get the state as bytes
pub fn to_bytes(self, out: &mut [u8]) -> Result<(), Error> {
debug_assert!(out.len() >= Self::num_bytes());
if out.len() < Self::num_bytes() {
pub fn to_bytes(self, state: &mut [u8]) -> Result<(), Error> {
debug_assert!(state.len() >= Self::num_bytes());
if state.len() < Self::num_bytes() {
return Err(Error::InvalidOutputLength);
}

out[..SHARED_SECRET_SIZE].copy_from_slice(&self.shared_secret);
let mut offset = SHARED_SECRET_SIZE;

vec_to_bytes(&self.r_as_ntt, &mut out[offset..]);
let mut offset = 0;
vec_to_bytes(&self.r_as_ntt, &mut state[offset..]);
offset += vec_len_bytes::<K, Vector>();

self.error2.to_bytes(&mut out[offset..]);
self.error2.to_bytes(&mut state[offset..]);
offset += PolynomialRingElement::<Vector>::num_bytes();

out[offset..offset + 32].copy_from_slice(&self.randomness);
state[offset..offset + 32].copy_from_slice(&self.randomness);

Ok(())
}
Expand All @@ -199,10 +194,7 @@ impl<const K: usize, Vector: Operations> EncapsState<K, Vector> {
return Err(Error::InvalidInputLength);
}

let mut shared_secret = [0u8; SHARED_SECRET_SIZE];
shared_secret.copy_from_slice(&bytes[..SHARED_SECRET_SIZE]);
let mut offset = SHARED_SECRET_SIZE;

let mut offset = 0;
let mut r_as_ntt = from_fn(|_| PolynomialRingElement::<Vector>::ZERO());
vec_from_bytes(&bytes[offset..], &mut r_as_ntt);
offset += vec_len_bytes::<K, Vector>();
Expand All @@ -214,7 +206,6 @@ impl<const K: usize, Vector: Operations> EncapsState<K, Vector> {
randomness.copy_from_slice(&bytes[offset..offset + 32]);

Ok(Self {
shared_secret,
r_as_ntt,
error2,
randomness,
Expand All @@ -225,18 +216,12 @@ impl<const K: usize, Vector: Operations> EncapsState<K, Vector> {
/// Trait container for multiplexing over platform dependent [`EncapsState`].
pub trait State {
fn as_any(&self) -> &dyn Any;

/// Get the shared secret.
fn shared_secret(&self) -> &[u8];
}

impl<const K: usize, Vector: Operations + 'static> State for EncapsState<K, Vector> {
fn as_any(&self) -> &dyn Any {
self
}

fn shared_secret(&self) -> &[u8] {
&self.shared_secret
}
}

// === Implementations
Expand Down
18 changes: 11 additions & 7 deletions libcrux-ml-kem/src/mlkem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ macro_rules! impl_incr_key_size {
pub const fn encaps_state_len() -> usize {
// Because const generics are too limited, we compute it here from scratch.

// shared secret
SHARED_SECRET_SIZE
// r_as_ntt
+ RANK * 16 * 32
RANK * 16 * 32
// error2
+ 16 * 32
// randomness
Expand Down Expand Up @@ -86,7 +84,7 @@ macro_rules! impl_incr_key_size {
pub fn encapsulate1(
public_key_part: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
) -> (Ciphertext1, Box<dyn State>) {
) -> (Ciphertext1, Box<dyn State>, [u8; SHARED_SECRET_SIZE]) {
multiplexing::alloc::encapsulate1::<
RANK,
CPA_PKE_CIPHERTEXT_SIZE,
Expand Down Expand Up @@ -171,6 +169,7 @@ macro_rules! impl_incr_key_size {
pk1: &[u8],
randomness: [u8; SHARED_SECRET_SIZE],
state: &mut [u8],
shared_secret: &mut [u8],
) -> Result<Ciphertext1, Error> {
let public_key_part = PublicKey1::try_from(&pk1 as &[u8])?;

Expand All @@ -184,7 +183,7 @@ macro_rules! impl_incr_key_size {
ETA1_RANDOMNESS_SIZE,
ETA2,
ETA2_RANDOMNESS_SIZE,
>(&public_key_part, randomness, state)
>(&public_key_part, randomness, state, shared_secret)
}

/// Encapsulate the second part of the ciphertext.
Expand Down Expand Up @@ -318,7 +317,11 @@ macro_rules! impl_incr_platform {
>(
public_key_part: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
) -> (Ciphertext1<C1_SIZE>, EncapsState<K, $vector>) {
) -> (
Ciphertext1<C1_SIZE>,
EncapsState<K, $vector>,
[u8; SHARED_SECRET_SIZE],
) {
super::encapsulate1::<
K,
CIPHERTEXT_SIZE,
Expand Down Expand Up @@ -348,6 +351,7 @@ macro_rules! impl_incr_platform {
public_key_part: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
state: &mut [u8],
shared_secret: &mut [u8],
) -> Result<Ciphertext1<C1_SIZE>, Error> {
super::encapsulate1_serialized::<
K,
Expand All @@ -361,7 +365,7 @@ macro_rules! impl_incr_platform {
ETA2_RANDOMNESS_SIZE,
$vector,
$hash,
>(public_key_part, randomness, state)
>(public_key_part, randomness, state, shared_secret)
}

pub(crate) fn encapsulate2<
Expand Down
4 changes: 2 additions & 2 deletions libcrux-ml-kem/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ fn to_i16_array<Vector: Operations>(re: PolynomialRingElement<Vector>, out: &mut
}

#[inline(always)]
#[hax_lib::requires(VECTORS_IN_RING_ELEMENT * 16 *2 <= a.len())]
#[hax_lib::requires(VECTORS_IN_RING_ELEMENT * 16 *2 <= bytes.len())]
fn from_bytes<Vector: Operations>(bytes: &[u8]) -> PolynomialRingElement<Vector> {
let mut result = ZERO();
for i in 0..VECTORS_IN_RING_ELEMENT {
Expand Down Expand Up @@ -328,7 +328,7 @@ impl<Vector: Operations> PolynomialRingElement<Vector> {
}

#[inline(always)]
#[requires(VECTORS_IN_RING_ELEMENT * 16 * 2 <= a.len())]
#[requires(VECTORS_IN_RING_ELEMENT * 16 * 2 <= bytes.len())]
pub(crate) fn from_bytes(bytes: &[u8]) -> Self {
from_bytes(bytes)
}
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-kem/src/vector/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ pub(super) fn from_bytes(array: &[u8]) -> SIMD256Vector {
}

#[inline(always)]
#[hax_lib::requires(array.len() >= 32)]
#[hax_lib::requires(bytes.len() >= 32)]
pub(super) fn to_bytes(x: SIMD256Vector, bytes: &mut [u8]) {
mm256_storeu_si256_u8(&mut bytes[0..32], x.elements)
}
2 changes: 1 addition & 1 deletion libcrux-ml-kem/src/vector/portable/vector_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub(super) fn from_bytes(array: &[u8]) -> PortableVector {
}

#[inline(always)]
#[hax_lib::requires(array.len() >= 32)]
#[hax_lib::requires(bytes.len() >= 32)]
pub(super) fn to_bytes(x: PortableVector, bytes: &mut [u8]) {
for i in 0..FIELD_ELEMENTS_IN_VECTOR {
bytes[2 * i] = (x.elements[i] >> 8) as u8;
Expand Down
6 changes: 2 additions & 4 deletions libcrux-ml-kem/tests/nistkats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ macro_rules! impl_kats {

let (ct1, ct2, incremental_shared_secret) = {
let pk1 = incremental::PublicKey1::try_from(&pk1_bytes as &[u8]).unwrap();
let (ct1, state) = incremental::alloc::encapsulate1(&pk1, kat.encapsulation_seed);
let (ct1, state, ss) = incremental::alloc::encapsulate1(&pk1, kat.encapsulation_seed);

assert!(incremental::validate_pk(&pk1, &pk2_bytes).is_ok());

Expand All @@ -188,9 +188,7 @@ macro_rules! impl_kats {
// platform dependent.
let ct2 = incremental::alloc::encapsulate2(state.as_ref(), &pk2_bytes).unwrap();

let mut shared_secret = [0u8; 32];
shared_secret.copy_from_slice(state.shared_secret());
(ct1, ct2, shared_secret)
(ct1, ct2, ss)
};

// Decapsulate
Expand Down
17 changes: 11 additions & 6 deletions libcrux-ml-kem/tests/self.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,19 @@ macro_rules! impl_consistency_incremental {
let encaps_randomness = random_array();
let (ct1, ct2, shared_secret) = {
let pk1 = PublicKey1::try_from(&pk1_bytes as &[u8]).unwrap();
let (ct1, state) = alloc::encapsulate1(&pk1, encaps_randomness);
let (ct1, state, dyn_ss) = alloc::encapsulate1(&pk1, encaps_randomness);
debug_assert_eq!(ct1.value.len(), Ciphertext1::len());

// encaps1 with serialized state
let mut serialized_state = [0u8; encaps_state_len()];
let ct12 =
encapsulate1(&pk1_bytes, encaps_randomness, &mut serialized_state).unwrap();
let mut shared_secret_serialized = [0u8; SHARED_SECRET_SIZE];
let ct12 = encapsulate1(
&pk1_bytes,
encaps_randomness,
&mut serialized_state,
&mut shared_secret_serialized,
)
.unwrap();
assert_eq!(ct1.value, ct12.value);

// Check the public key for consistency.
Expand All @@ -182,9 +188,8 @@ macro_rules! impl_consistency_incremental {
let ct22 = encapsulate2(&serialized_state, &pk2_bytes).unwrap();
assert_eq!(ct2.value, ct22.value);

let mut shared_secret = [0u8; 32];
shared_secret.copy_from_slice(state.shared_secret());
(ct1, ct2, shared_secret)
assert_eq!(dyn_ss, shared_secret_serialized);
(ct1, ct2, dyn_ss)
};

// The initiator decapsulates the two ciphertexts.
Expand Down

0 comments on commit 52d5492

Please sign in to comment.