From edcd885cee62f266bed7236b05b90b9f7bcf311c Mon Sep 17 00:00:00 2001 From: Hugo Rosenkranz-Costa Date: Mon, 4 Dec 2023 11:22:42 +0100 Subject: [PATCH] refacto: simplify ways to insert new chains in `RevisionVec` --- src/core/api.rs | 4 +- src/core/primitives.rs | 54 +++++++++---------- src/core/serialization.rs | 14 ++--- src/data_struct/revision_vec.rs | 95 +++++++++++++++------------------ src/test_utils/mod.rs | 6 +-- 5 files changed, 82 insertions(+), 91 deletions(-) diff --git a/src/core/api.rs b/src/core/api.rs index ea9bb101..8db3b307 100644 --- a/src/core/api.rs +++ b/src/core/api.rs @@ -103,11 +103,11 @@ impl Covercrypt { access_policy: &AccessPolicy, policy: &Policy, ) -> Result { - Ok(keygen( + keygen( &mut *self.rng.lock().expect("Mutex lock failed!"), msk, &policy.access_policy_to_partitions(access_policy, true)?, - )) + ) } /// Refreshes the user key according to the given master key and user diff --git a/src/core/primitives.rs b/src/core/primitives.rs index 793efa12..a17cf2b6 100644 --- a/src/core/primitives.rs +++ b/src/core/primitives.rs @@ -212,19 +212,19 @@ pub fn keygen( rng: &mut impl CryptoRngCore, msk: &MasterSecretKey, decryption_set: &HashSet, -) -> UserSecretKey { +) -> Result { let a = R25519PrivateKey::new(rng); let b = &(&msk.s - &(&a * &msk.s1)) / &msk.s2; // Use the last key for each partitions in the decryption set - // TODO: error out if missing partitions? - let subkeys: RevisionVec<_, _> = decryption_set - .iter() - .filter_map(|partition| { - msk.subkeys - .get_current_revision(partition) - .map(|subkey| (partition.clone(), subkey.clone())) - }) - .collect(); + let mut subkeys = RevisionVec::with_capacity(decryption_set.len()); + decryption_set.iter().try_for_each(|partition| { + let subkey = msk + .subkeys + .get_current_revision(partition) + .ok_or(Error::KeyError("Missing master subkey".to_string()))?; + subkeys.create_chain_with_single_value(partition.clone(), subkey.clone()); + Ok::<_, Error>(()) + })?; let mut usk = UserSecretKey { a, @@ -233,7 +233,7 @@ pub fn keygen( kmac: None, }; usk.kmac = compute_user_key_kmac(msk, &usk); - usk + Ok(usk) } /// Generates a `Covercrypt` encapsulation of a random symmetric key. @@ -447,24 +447,26 @@ pub fn refresh( usk.subkeys.retain_keys(msk.subkeys.keys().collect()); for (partition, user_chain) in usk.subkeys.iter_mut() { - let mut master_chain = msk.subkeys.iter_chain(&partition); + let mut master_chain = msk.subkeys.iter_chain(partition); + // Remove all but the most recent subkey for this partition if !keep_old_rights { - // find the most recent key between the master and user key - let master_first_key = master_chain.next().expect("at least one key"); + // find the most recent subkey between the master and user key + let master_first_key = master_chain.next().expect("have one key"); if Some(master_first_key) != user_chain.head.as_ref().map(|item| &item.data) { // new key let new_element = Box::new(Element::new(master_first_key.clone())); user_chain.head.replace(new_element); } - // remove older keys if any - let _ = user_chain.head.as_mut().unwrap().next.take(); - // skip to next chain + // remove older keys if any and update length + let _ = user_chain.head.as_mut().expect("have one key").next.take(); + user_chain.length = 1; + // skip to next partition continue; } // 1 - add new master subkeys in user key if any - let user_first_key = user_chain.head.take().expect("at least one key"); + let user_first_key = user_chain.head.take().expect("have one key"); let mut insertion_cursor = &mut user_chain.head; let mut updated_chain_length = 0; @@ -508,8 +510,6 @@ pub fn refresh( #[cfg(test)] mod tests { - use std::iter; - use cosmian_crypto_core::{ bytes_ser_de::Serializable, reexport::rand_core::SeedableRng, CsRng, }; @@ -563,8 +563,8 @@ mod tests { assert!(dev_secret_subkeys.unwrap().0.is_none()); // Generate user secret keys. - let mut dev_usk = keygen(&mut rng, &msk, &users_set[0]); - let admin_usk = keygen(&mut rng, &msk, &users_set[1]); + let mut dev_usk = keygen(&mut rng, &msk, &users_set[0])?; + let admin_usk = keygen(&mut rng, &msk, &users_set[1])?; // Encapsulate key for the admin target set. let (sym_key, encapsulation) = encaps(&mut rng, &mpk, &admin_target_set).unwrap(); @@ -645,7 +645,7 @@ mod tests { ); // Client is able to decapsulate. - let client_usk = keygen(&mut rng, &msk, &HashSet::from([client_partition])); + let client_usk = keygen(&mut rng, &msk, &HashSet::from([client_partition]))?; let res0 = decaps(&client_usk, &new_encapsulation); match res0 { Err(err) => panic!("Client should be able to decapsulate: {err:?}"), @@ -726,7 +726,7 @@ mod tests { &mut rng, &msk, &HashSet::from([partition_1.clone(), partition_2.clone()]), - ); + )?; // now remove partition 1 and add partition 4 let partition_4 = Partition(b"4".to_vec()); @@ -797,16 +797,16 @@ mod tests { // setup scheme let (msk, _) = setup(&mut rng, &partitions_set); // create a user key with access to partition 1 and 2 - let mut usk = keygen(&mut rng, &msk, &HashSet::from([partition_1, partition_2])); + let mut usk = keygen(&mut rng, &msk, &HashSet::from([partition_1, partition_2]))?; assert!(verify_user_key_kmac(&msk, &usk).is_ok()); let bytes = usk.serialize()?; let usk_ = UserSecretKey::deserialize(&bytes)?; assert!(verify_user_key_kmac(&msk, &usk_).is_ok()); - usk.subkeys.insert_new_chain( + usk.subkeys.create_chain_with_single_value( Partition(b"3".to_vec()), - iter::once((None, R25519PrivateKey::new(&mut rng))), + (None, R25519PrivateKey::new(&mut rng)), ); // KMAC verify will fail after modifying the user key assert!(verify_user_key_kmac(&msk, &usk).is_err()); diff --git a/src/core/serialization.rs b/src/core/serialization.rs index 609dc863..4c43b13c 100644 --- a/src/core/serialization.rs +++ b/src/core/serialization.rs @@ -15,7 +15,7 @@ use crate::{ Encapsulation, KeyEncapsulation, MasterPublicKey, MasterSecretKey, UserSecretKey, SYM_KEY_LENGTH, }, - data_struct::{RevisionMap, RevisionVec}, + data_struct::{RevisionList, RevisionMap, RevisionVec}, CleartextHeader, EncryptedHeader, Error, }; @@ -186,7 +186,7 @@ impl Serializable for UserSecretKey { for (partition, chain) in self.subkeys.iter() { length += to_leb128_len(partition.len()) + partition.len(); length += to_leb128_len(chain.len()); - for (_, (sk_i, _)) in chain.iter() { + for (sk_i, _) in chain.iter() { length += serialize_len_option!(sk_i, _value, KYBER_INDCPA_SECRETKEYBYTES); } } @@ -202,7 +202,7 @@ impl Serializable for UserSecretKey { n += ser.write_vec(partition)?; // iterate through all subkeys in the chain n += ser.write_leb128_u64(chain.len() as u64)?; - for (_, (sk_i, x_i)) in chain.iter() { + for (sk_i, x_i) in chain.iter() { serialize_option!(ser, n, sk_i, value, ser.write_array(value)); n += ser.write_array(&x_i.to_bytes())?; } @@ -222,14 +222,14 @@ impl Serializable for UserSecretKey { let partition = Partition::from(de.read_vec()?); // read all keys forming a chain and inserting them all at once. let n_keys = ::try_from(de.read_leb128_u64()?)?; - let it = (0..n_keys) + let new_chain: Result, _> = (0..n_keys) .map(|_| { let sk_i = deserialize_option!(de, KyberSecretKey(de.read_array()?)); let x_i = de.read_array::<{ R25519PrivateKey::LENGTH }>()?; Ok::<_, Self::Error>((sk_i, R25519PrivateKey::try_from_bytes(x_i)?)) }) - .filter_map(Result::ok); - subkeys.insert_new_chain(partition, it); + .collect(); + subkeys.insert_new_chain(partition, new_chain?); } let kmac = de.read_array::<{ KMAC_LENGTH }>().ok(); @@ -455,7 +455,7 @@ mod tests { assert_eq!(mpk, mpk_, "Wrong `PublicKey` derserialization."); // Check Covercrypt `UserSecretKey` serialization. - let usk = keygen(&mut rng, &msk, &user_set); + let usk = keygen(&mut rng, &msk, &user_set)?; let bytes = usk.serialize()?; assert_eq!(bytes.len(), usk.length(), "Wrong user secret key size"); let usk_ = UserSecretKey::deserialize(&bytes)?; diff --git a/src/data_struct/revision_vec.rs b/src/data_struct/revision_vec.rs index 5b927184..66dcf96e 100644 --- a/src/data_struct/revision_vec.rs +++ b/src/data_struct/revision_vec.rs @@ -1,7 +1,6 @@ use std::{ collections::{HashSet, VecDeque}, hash::Hash, - iter, }; /// a `RevisionVec` stores for each entry a linked list of versions. @@ -20,7 +19,7 @@ use std::{ // TODO: does index matter for Eq compare? #[derive(Default, Debug, PartialEq, Eq)] pub struct RevisionVec { - chains: Vec>, + chains: Vec<(K, RevisionList)>, length: usize, } @@ -51,12 +50,23 @@ impl RevisionVec { self.chains.len() } - /// Insert new chain entries in order of arrival - pub fn insert_new_chain(&mut self, key: K, iter: impl Iterator) { - let new_chain = RevisionList::from_iter(key, iter); + /// Creates and insert a new chain with a single value. + /// /!\ Adding multiple chains with the same key will corrupt the data + /// structure. + pub fn create_chain_with_single_value(&mut self, key: K, val: T) { + let mut new_chain = RevisionList::new(); + new_chain.push_front(val); + self.length += 1; + self.chains.push((key, new_chain)); + } + + /// Inserts a new chain with a corresponding key. + /// /!\ Adding multiple chains with the same key will corrupt the data + /// structure. + pub fn insert_new_chain(&mut self, key: K, new_chain: RevisionList) { if !new_chain.is_empty() { self.length += new_chain.len(); - self.chains.push(new_chain); + self.chains.push((key, new_chain)); } } @@ -69,29 +79,28 @@ impl RevisionVec { where K: Hash + Eq, { - self.chains - .retain(|chain: &RevisionList| keys.contains(&chain.key)); + self.chains.retain(|(key, _)| keys.contains(key)); } /// Returns an iterator over each key-chains pair - pub fn iter(&self) -> impl Iterator)> { - self.chains.iter().map(|chain| (&chain.key, chain)) + pub fn iter(&self) -> impl Iterator)> { + self.chains.iter().map(|(key, chain)| (key, chain)) } /// Returns an iterator over each key-chains pair that allow modifying chain - pub fn iter_mut(&mut self) -> impl Iterator)> + pub fn iter_mut(&mut self) -> impl Iterator)> where K: Clone, { - self.chains - .iter_mut() - .map(|chain| (chain.key.clone(), chain)) + self.chains.iter_mut().map(|(ref key, chain)| (key, chain)) } /// Iterates through all versions of all entries /// Returns the key and value for each entry. pub fn flat_iter(&self) -> impl Iterator { - self.chains.iter().flat_map(|chain| chain.iter()) + self.chains + .iter() + .flat_map(|(key, chain)| chain.iter().map(move |val| (key, val))) } pub fn bfs(&self) -> BfsIterator { @@ -110,7 +119,7 @@ impl<'a, T> BfsIterator<'a, T> { queue: revision_vec .chains .iter() - .filter_map(|chain| Some(chain.head.as_ref()?.as_ref())) + .filter_map(|(_, chain)| Some(chain.head.as_ref()?.as_ref())) .collect(), } } @@ -130,20 +139,6 @@ impl<'a, T> Iterator for BfsIterator<'a, T> { } } -/// Create `RevisionVec` from an iterator, each element will be inserted in a -/// different chain. Use `insert_new_chain` to collect an iterator inside the -/// same chain. -impl FromIterator<(K, T)> for RevisionVec { - fn from_iter>(iter: I) -> Self { - let iterator = iter.into_iter(); - let mut vec = Self::with_capacity(iterator.size_hint().0); - for (key, item) in iterator { - vec.insert_new_chain(key, iter::once(item)) - } - vec - } -} - #[derive(Default, Debug, PartialEq, Eq)] pub struct Element { pub(crate) data: T, @@ -160,16 +155,14 @@ impl Element { } #[derive(Default, Debug, PartialEq, Eq)] -pub struct RevisionList { - key: K, +pub struct RevisionList { pub(crate) length: usize, pub(crate) head: Option>>, } -impl RevisionList { - pub fn new(key: K) -> Self { +impl RevisionList { + pub fn new() -> Self { Self { - key, length: 0, head: None, } @@ -197,29 +190,31 @@ impl RevisionList { self.head.as_ref().map(|element| &element.data) } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { RevisionListIter::new(self) } +} +impl FromIterator for RevisionList { /// Creates a `RevisionList` from an iterator by inserting elements in the /// order of arrival: first item in the iterator will end up at the front. - pub fn from_iter(key: K, mut iter: impl Iterator) -> Self { - if let Some(first_element) = iter.next() { + fn from_iter>(iter: I) -> Self { + let mut iterator = iter.into_iter(); + if let Some(first_element) = iterator.next() { let mut length = 1; let mut head = Some(Box::new(Element::new(first_element))); - let mut current_element = head.as_mut().expect("next element was inserted above"); - for next_item in iter { + let mut current_element = head.as_mut().expect("element was inserted above"); + for next_item in iterator { current_element.next = Some(Box::new(Element::new(next_item))); current_element = current_element .next .as_mut() - .expect("next element was inserted above"); + .expect("element was inserted above"); length += 1; } - Self { key, length, head } + Self { length, head } } else { Self { - key, length: 0, head: None, } @@ -227,27 +222,25 @@ impl RevisionList { } } -pub struct RevisionListIter<'a, K, T> { - key: &'a K, +pub struct RevisionListIter<'a, T> { current_element: &'a Option>>, } -impl<'a, K, T> RevisionListIter<'a, K, T> { - pub fn new(rev_list: &'a RevisionList) -> Self { +impl<'a, T> RevisionListIter<'a, T> { + pub fn new(rev_list: &'a RevisionList) -> Self { Self { - key: &rev_list.key, current_element: &rev_list.head, } } } -impl<'a, K, T> Iterator for RevisionListIter<'a, K, T> { - type Item = (&'a K, &'a T); +impl<'a, T> Iterator for RevisionListIter<'a, T> { + type Item = &'a T; fn next(&mut self) -> Option { let element = self.current_element.as_ref()?; self.current_element = &element.next; - Some((self.key, &element.data)) + Some(&element.data) } } diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs index dd2d5e0a..35199863 100644 --- a/src/test_utils/mod.rs +++ b/src/test_utils/mod.rs @@ -36,8 +36,6 @@ pub fn policy() -> Result { #[cfg(test)] mod tests { - use std::iter; - use cosmian_crypto_core::bytes_ser_de::Serializable; use super::*; @@ -188,9 +186,9 @@ mod tests { // try to modify the user key and refresh let part = Partition::from(vec![1, 6]); - usk.subkeys.insert_new_chain( + usk.subkeys.create_chain_with_single_value( part.clone(), - iter::once(msk.subkeys.get_current_revision(&part).unwrap().clone()), + msk.subkeys.get_current_revision(&part).unwrap().clone(), ); assert!(cover_crypt .refresh_user_secret_key(&mut usk, &msk, false)