diff --git a/src/core/api.rs b/src/core/api.rs index 8db3b307..895a5c9b 100644 --- a/src/core/api.rs +++ b/src/core/api.rs @@ -81,12 +81,14 @@ impl Covercrypt { policy: &Policy, msk: &mut MasterSecretKey, mpk: &mut MasterPublicKey, + keep_old_subkeys: bool, ) -> Result<(), Error> { rekey( &mut *self.rng.lock().expect("Mutex lock failed!"), msk, mpk, &policy.access_policy_to_partitions(access_policy, false)?, + keep_old_subkeys, ) } diff --git a/src/core/primitives.rs b/src/core/primitives.rs index a17cf2b6..f80fd907 100644 --- a/src/core/primitives.rs +++ b/src/core/primitives.rs @@ -221,7 +221,9 @@ pub fn keygen( let subkey = msk .subkeys .get_current_revision(partition) - .ok_or(Error::KeyError("Missing master subkey".to_string()))?; + .ok_or(Error::KeyError( + "Master secret key and Policy are not in sync.".to_string(), + ))?; subkeys.create_chain_with_single_value(partition.clone(), subkey.clone()); Ok::<_, Error>(()) })?; @@ -404,10 +406,11 @@ pub fn rekey( msk: &mut MasterSecretKey, mpk: &mut MasterPublicKey, partitions_to_rotate: &HashSet, + keep_old_subkeys: bool, ) -> Result<(), Error> { let h = R25519PublicKey::from(&msk.s); for partition in partitions_to_rotate { - // write a `get_encryption`` function in a dedicated SecretSubkey struct? + // write a `get_encryption` function in a dedicated SecretSubkey struct? let is_hybridized = EncryptionHint::new( msk.subkeys .get_current_revision(partition) @@ -416,6 +419,10 @@ pub fn rekey( ); let (public_subkey, secret_subkey) = create_subkey_pair(rng, &h, is_hybridized); msk.subkeys.insert(partition.clone(), secret_subkey); + if !keep_old_subkeys { + // remove all older keys for a given partition + msk.subkeys.pop_tail(partition); + } // update public subkey if partition is not read only if mpk.subkeys.contains_key(partition) { @@ -447,20 +454,19 @@ pub fn refresh( usk.subkeys.retain_keys(msk.subkeys.keys().collect()); for (partition, user_chain) in usk.subkeys.iter_mut() { - let mut master_chain = msk.subkeys.iter_chain(partition); + let mut master_chain = msk.subkeys.iter_chain(partition).expect("at least one key"); // Remove all but the most recent subkey for this partition if !keep_old_rights { // find the most recent subkey between the master and user key - let master_first_key = master_chain.next().expect("have one key"); + let master_first_key = master_chain.next().expect("at least one key"); if Some(master_first_key) != user_chain.head.as_ref().map(|item| &item.data) { // new key let new_element = Box::new(Element::new(master_first_key.clone())); user_chain.head.replace(new_element); } // remove older keys if any and update length - let _ = user_chain.head.as_mut().expect("have one key").next.take(); - user_chain.length = 1; + user_chain.pop_tail(); // skip to next partition continue; } diff --git a/src/data_struct/revision_map.rs b/src/data_struct/revision_map.rs index e1f5a572..7e439df5 100644 --- a/src/data_struct/revision_map.rs +++ b/src/data_struct/revision_map.rs @@ -6,7 +6,6 @@ use std::{ }, fmt::Debug, hash::Hash, - iter, }; /// a `VersionedMap` stores linked lists. @@ -17,7 +16,6 @@ where V: Debug, { pub(crate) map: HashMap>, - length: usize, } impl RevisionMap @@ -28,23 +26,21 @@ where pub fn new() -> Self { Self { map: HashMap::new(), - length: 0, } } pub fn with_capacity(capacity: usize) -> Self { Self { map: HashMap::with_capacity(capacity), - length: 0, } } pub fn len(&self) -> usize { - self.length + self.map.values().map(|chain| chain.len()).sum() } pub fn is_empty(&self) -> bool { - self.length == 0 + self.len() == 0 } pub fn nb_chains(&self) -> usize { @@ -68,9 +64,6 @@ where /// Inserts value at the front of the chain for a given key pub fn insert(&mut self, key: K, value: V) { - // All branches will add an element in the map. - self.length += 1; - match self.map.entry(key) { Entry::Occupied(entry) => Self::insert_in_chain(entry, value), Entry::Vacant(entry) => Self::insert_new_chain(entry, value), @@ -113,51 +106,32 @@ where /// Iterates through all revisions of a given key starting with the more /// recent one. - pub fn iter_chain<'a, Q>(&'a self, key: &Q) -> Box> + pub fn iter_chain(&self, key: &Q) -> Option> where K: Borrow, Q: Hash + Eq + ?Sized, { - match self.map.get(key) { - Some(chain) => Box::new(chain.iter()), - None => Box::new(iter::empty()), - } + self.map.get(key).map(LinkedList::iter) } /// Removes and returns an iterator over all revisions from a given key. - pub fn remove_chain<'a, Q>(&'a mut self, key: &Q) -> Box> + pub fn remove_chain(&mut self, key: &Q) -> Option> where K: Borrow, Q: Hash + Eq + ?Sized, { - match self.map.remove(key) { - Some(chain) => { - self.length -= chain.len(); - Box::new(chain.into_iter()) - } - None => Box::new(iter::empty()), - } + self.map.remove(key).map(LinkedList::into_iter) } - /// Removes and returns the older revision from a given key. - pub fn remove_older_revision(&mut self, key: &K) -> Option { - let Entry::Occupied(mut entry) = self.map.entry(key.clone()) else { - return None; - }; - let chain = entry.get_mut(); - let removed_entry = chain.pop_back(); - - // remove linked list if the last revision was removed - if chain.is_empty() { - entry.remove_entry(); - } - - // update map length - if removed_entry.is_some() { - self.length -= 1; - } - - removed_entry + /// Removes and returns the older revisions from a given key. + pub fn pop_tail(&mut self, key: &Q) -> Option> + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + self.map + .get_mut(key) + .map(|chain| chain.split_off(1).into_iter()) } pub fn retain_keys(&mut self, keys: HashSet<&K>) { @@ -182,42 +156,48 @@ mod tests { assert!(map.is_empty()); // Insertions - map.insert("Part1".to_string(), "Rotation1".to_string()); + map.insert("Part1".to_string(), "Part1V1".to_string()); assert_eq!(map.map.len(), 1); - map.insert("Part1".to_string(), "Rotation2".to_string()); + map.insert("Part1".to_string(), "Part1V2".to_string()); assert_eq!(map.len(), 2); // the inner map only has 1 entry with 2 revisions assert_eq!(map.map.len(), 1); - map.insert("Part2".to_string(), "New".to_string()); + map.insert("Part2".to_string(), "Part2V1".to_string()); + map.insert("Part2".to_string(), "Part2V2".to_string()); + map.insert("Part2".to_string(), "Part2V3".to_string()); assert_eq!(map.map.len(), 2); + assert_eq!(map.len(), 5); // Get - assert_eq!(map.get_current_revision("Part1").unwrap(), "Rotation2"); - assert_eq!(map.get_current_revision("Part2").unwrap(), "New"); + assert_eq!(map.get_current_revision("Part1").unwrap(), "Part1V2"); + assert_eq!(map.get_current_revision("Part2").unwrap(), "Part2V3"); assert!(map.get_current_revision("Missing").is_none()); // Iterators let vec: Vec<_> = map.iter().collect(); - assert_eq!(vec.len(), 3); + assert_eq!(vec.len(), map.len()); - let vec: Vec<_> = map.iter_chain("Part1").collect(); - assert_eq!(vec, vec!["Rotation2", "Rotation1"]); + let vec: Vec<_> = map.iter_chain("Part1").unwrap().collect(); + assert_eq!(vec, vec!["Part1V2", "Part1V1"]); let keys_set = map.keys().collect::>(); assert!(keys_set.contains(&"Part1".to_string())); assert!(keys_set.contains(&"Part2".to_string())); // Remove values - assert_eq!( - map.remove_older_revision(&"Part2".to_string()).unwrap(), - "New" - ); - assert_eq!(map.len(), 2); + let vec: Vec<_> = map.remove_chain("Part1").unwrap().collect(); + assert_eq!(vec, vec!["Part1V2".to_string(), "Part1V1".to_string()]); + assert_eq!(map.len(), 3); assert_eq!(map.map.len(), 1); - let vec: Vec<_> = map.remove_chain("Part1").collect(); - assert_eq!(vec, vec!["Rotation2".to_string(), "Rotation1".to_string()]); + // Pop tail + let vec: Vec<_> = map.pop_tail("Part2").unwrap().collect(); + assert_eq!(vec, vec!["Part2V2".to_string(), "Part2V1".to_string()]); + assert_eq!(map.len(), 1); + let vec: Vec<_> = map.remove_chain("Part2").unwrap().collect(); + assert_eq!(vec, vec!["Part2V3".to_string()]); + assert!(map.is_empty()); } } diff --git a/src/data_struct/revision_vec.rs b/src/data_struct/revision_vec.rs index 66dcf96e..d31f4a5d 100644 --- a/src/data_struct/revision_vec.rs +++ b/src/data_struct/revision_vec.rs @@ -20,26 +20,21 @@ use std::{ #[derive(Default, Debug, PartialEq, Eq)] pub struct RevisionVec { chains: Vec<(K, RevisionList)>, - length: usize, } impl RevisionVec { pub fn new() -> Self { - Self { - chains: Vec::new(), - length: 0, - } + Self { chains: Vec::new() } } pub fn with_capacity(capacity: usize) -> Self { Self { chains: Vec::with_capacity(capacity), - length: 0, } } pub fn len(&self) -> usize { - self.length + self.chains.iter().map(|(_, chain)| chain.len()).sum() } pub fn is_empty(&self) -> bool { @@ -56,7 +51,6 @@ impl RevisionVec { pub fn create_chain_with_single_value(&mut self, key: K, val: T) { let mut new_chain = RevisionList::new(); new_chain.push_front(val); - self.length += 1; self.chains.push((key, new_chain)); } @@ -65,14 +59,12 @@ impl RevisionVec { /// structure. pub fn insert_new_chain(&mut self, key: K, new_chain: RevisionList) { if !new_chain.is_empty() { - self.length += new_chain.len(); self.chains.push((key, new_chain)); } } pub fn clear(&mut self) { self.chains.clear(); - self.length = self.chains.len(); } pub fn retain_keys(&mut self, keys: HashSet<&K>) @@ -190,6 +182,18 @@ impl RevisionList { self.head.as_ref().map(|element| &element.data) } + pub fn pop_tail(&mut self) -> RevisionListIter { + self.length = self.head.as_ref().map_or(0, |_| 1); + match &self.head { + Some(head) => RevisionListIter { + current_element: &head.next, + }, + None => RevisionListIter { + current_element: &None, + }, + } + } + pub fn iter(&self) -> impl Iterator { RevisionListIter::new(self) }