Skip to content

Commit

Permalink
feat: add keep_old_subkeys parameter for rekey operation
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugo Rosenkranz-Costa committed Dec 4, 2023
1 parent edcd885 commit 5687908
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 72 deletions.
2 changes: 2 additions & 0 deletions src/core/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ impl Covercrypt {
policy: &Policy,
msk: &mut MasterSecretKey,
mpk: &mut MasterPublicKey,
keep_old_subkeys: bool,
) -> Result<(), Error> {
rekey(
&mut *self.rng.lock().expect("Mutex lock failed!"),
msk,
mpk,
&policy.access_policy_to_partitions(access_policy, false)?,
keep_old_subkeys,
)
}

Expand Down
18 changes: 12 additions & 6 deletions src/core/primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ pub fn keygen(
let subkey = msk
.subkeys
.get_current_revision(partition)
.ok_or(Error::KeyError("Missing master subkey".to_string()))?;
.ok_or(Error::KeyError(
"Master secret key and Policy are not in sync.".to_string(),
))?;
subkeys.create_chain_with_single_value(partition.clone(), subkey.clone());
Ok::<_, Error>(())
})?;
Expand Down Expand Up @@ -404,10 +406,11 @@ pub fn rekey(
msk: &mut MasterSecretKey,
mpk: &mut MasterPublicKey,
partitions_to_rotate: &HashSet<Partition>,
keep_old_subkeys: bool,
) -> Result<(), Error> {
let h = R25519PublicKey::from(&msk.s);
for partition in partitions_to_rotate {
// write a `get_encryption`` function in a dedicated SecretSubkey struct?
// write a `get_encryption` function in a dedicated SecretSubkey struct?
let is_hybridized = EncryptionHint::new(
msk.subkeys
.get_current_revision(partition)
Expand All @@ -416,6 +419,10 @@ pub fn rekey(
);
let (public_subkey, secret_subkey) = create_subkey_pair(rng, &h, is_hybridized);
msk.subkeys.insert(partition.clone(), secret_subkey);
if !keep_old_subkeys {
// remove all older keys for a given partition
msk.subkeys.pop_tail(partition);
}

// update public subkey if partition is not read only
if mpk.subkeys.contains_key(partition) {
Expand Down Expand Up @@ -447,20 +454,19 @@ pub fn refresh(
usk.subkeys.retain_keys(msk.subkeys.keys().collect());

for (partition, user_chain) in usk.subkeys.iter_mut() {
let mut master_chain = msk.subkeys.iter_chain(partition);
let mut master_chain = msk.subkeys.iter_chain(partition).expect("at least one key");

// Remove all but the most recent subkey for this partition
if !keep_old_rights {
// find the most recent subkey between the master and user key
let master_first_key = master_chain.next().expect("have one key");
let master_first_key = master_chain.next().expect("at least one key");
if Some(master_first_key) != user_chain.head.as_ref().map(|item| &item.data) {
// new key
let new_element = Box::new(Element::new(master_first_key.clone()));
user_chain.head.replace(new_element);
}
// remove older keys if any and update length
let _ = user_chain.head.as_mut().expect("have one key").next.take();
user_chain.length = 1;
user_chain.pop_tail();
// skip to next partition
continue;
}
Expand Down
92 changes: 36 additions & 56 deletions src/data_struct/revision_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::{
},
fmt::Debug,
hash::Hash,
iter,
};

/// a `VersionedMap` stores linked lists.
Expand All @@ -17,7 +16,6 @@ where
V: Debug,
{
pub(crate) map: HashMap<K, LinkedList<V>>,
length: usize,
}

impl<K, V> RevisionMap<K, V>
Expand All @@ -28,23 +26,21 @@ where
pub fn new() -> Self {
Self {
map: HashMap::new(),
length: 0,
}
}

pub fn with_capacity(capacity: usize) -> Self {
Self {
map: HashMap::with_capacity(capacity),
length: 0,
}
}

pub fn len(&self) -> usize {
self.length
self.map.values().map(|chain| chain.len()).sum()
}

pub fn is_empty(&self) -> bool {
self.length == 0
self.len() == 0
}

pub fn nb_chains(&self) -> usize {
Expand All @@ -68,9 +64,6 @@ where

/// Inserts value at the front of the chain for a given key
pub fn insert(&mut self, key: K, value: V) {
// All branches will add an element in the map.
self.length += 1;

match self.map.entry(key) {
Entry::Occupied(entry) => Self::insert_in_chain(entry, value),
Entry::Vacant(entry) => Self::insert_new_chain(entry, value),
Expand Down Expand Up @@ -113,51 +106,32 @@ where

/// Iterates through all revisions of a given key starting with the more
/// recent one.
pub fn iter_chain<'a, Q>(&'a self, key: &Q) -> Box<dyn 'a + Iterator<Item = &V>>
pub fn iter_chain<Q>(&self, key: &Q) -> Option<impl Iterator<Item = &V>>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
match self.map.get(key) {
Some(chain) => Box::new(chain.iter()),
None => Box::new(iter::empty()),
}
self.map.get(key).map(LinkedList::iter)
}

/// Removes and returns an iterator over all revisions from a given key.
pub fn remove_chain<'a, Q>(&'a mut self, key: &Q) -> Box<dyn 'a + Iterator<Item = V>>
pub fn remove_chain<Q>(&mut self, key: &Q) -> Option<impl Iterator<Item = V>>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
match self.map.remove(key) {
Some(chain) => {
self.length -= chain.len();
Box::new(chain.into_iter())
}
None => Box::new(iter::empty()),
}
self.map.remove(key).map(LinkedList::into_iter)
}

/// Removes and returns the older revision from a given key.
pub fn remove_older_revision(&mut self, key: &K) -> Option<V> {
let Entry::Occupied(mut entry) = self.map.entry(key.clone()) else {
return None;
};
let chain = entry.get_mut();
let removed_entry = chain.pop_back();

// remove linked list if the last revision was removed
if chain.is_empty() {
entry.remove_entry();
}

// update map length
if removed_entry.is_some() {
self.length -= 1;
}

removed_entry
/// Removes and returns the older revisions from a given key.
pub fn pop_tail<Q>(&mut self, key: &Q) -> Option<impl Iterator<Item = V>>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.map
.get_mut(key)
.map(|chain| chain.split_off(1).into_iter())
}

pub fn retain_keys(&mut self, keys: HashSet<&K>) {
Expand All @@ -182,42 +156,48 @@ mod tests {
assert!(map.is_empty());

// Insertions
map.insert("Part1".to_string(), "Rotation1".to_string());
map.insert("Part1".to_string(), "Part1V1".to_string());
assert_eq!(map.map.len(), 1);
map.insert("Part1".to_string(), "Rotation2".to_string());
map.insert("Part1".to_string(), "Part1V2".to_string());
assert_eq!(map.len(), 2);
// the inner map only has 1 entry with 2 revisions
assert_eq!(map.map.len(), 1);

map.insert("Part2".to_string(), "New".to_string());
map.insert("Part2".to_string(), "Part2V1".to_string());
map.insert("Part2".to_string(), "Part2V2".to_string());
map.insert("Part2".to_string(), "Part2V3".to_string());
assert_eq!(map.map.len(), 2);
assert_eq!(map.len(), 5);

// Get
assert_eq!(map.get_current_revision("Part1").unwrap(), "Rotation2");
assert_eq!(map.get_current_revision("Part2").unwrap(), "New");
assert_eq!(map.get_current_revision("Part1").unwrap(), "Part1V2");
assert_eq!(map.get_current_revision("Part2").unwrap(), "Part2V3");
assert!(map.get_current_revision("Missing").is_none());

// Iterators
let vec: Vec<_> = map.iter().collect();
assert_eq!(vec.len(), 3);
assert_eq!(vec.len(), map.len());

let vec: Vec<_> = map.iter_chain("Part1").collect();
assert_eq!(vec, vec!["Rotation2", "Rotation1"]);
let vec: Vec<_> = map.iter_chain("Part1").unwrap().collect();
assert_eq!(vec, vec!["Part1V2", "Part1V1"]);

let keys_set = map.keys().collect::<HashSet<_>>();
assert!(keys_set.contains(&"Part1".to_string()));
assert!(keys_set.contains(&"Part2".to_string()));

// Remove values
assert_eq!(
map.remove_older_revision(&"Part2".to_string()).unwrap(),
"New"
);
assert_eq!(map.len(), 2);
let vec: Vec<_> = map.remove_chain("Part1").unwrap().collect();
assert_eq!(vec, vec!["Part1V2".to_string(), "Part1V1".to_string()]);
assert_eq!(map.len(), 3);
assert_eq!(map.map.len(), 1);

let vec: Vec<_> = map.remove_chain("Part1").collect();
assert_eq!(vec, vec!["Rotation2".to_string(), "Rotation1".to_string()]);
// Pop tail
let vec: Vec<_> = map.pop_tail("Part2").unwrap().collect();
assert_eq!(vec, vec!["Part2V2".to_string(), "Part2V1".to_string()]);
assert_eq!(map.len(), 1);
let vec: Vec<_> = map.remove_chain("Part2").unwrap().collect();
assert_eq!(vec, vec!["Part2V3".to_string()]);

assert!(map.is_empty());
}
}
24 changes: 14 additions & 10 deletions src/data_struct/revision_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,21 @@ use std::{
#[derive(Default, Debug, PartialEq, Eq)]
pub struct RevisionVec<K, T> {
chains: Vec<(K, RevisionList<T>)>,
length: usize,
}

impl<K, T> RevisionVec<K, T> {
pub fn new() -> Self {
Self {
chains: Vec::new(),
length: 0,
}
Self { chains: Vec::new() }
}

pub fn with_capacity(capacity: usize) -> Self {
Self {
chains: Vec::with_capacity(capacity),
length: 0,
}
}

pub fn len(&self) -> usize {
self.length
self.chains.iter().map(|(_, chain)| chain.len()).sum()
}

pub fn is_empty(&self) -> bool {
Expand All @@ -56,7 +51,6 @@ impl<K, T> RevisionVec<K, T> {
pub fn create_chain_with_single_value(&mut self, key: K, val: T) {
let mut new_chain = RevisionList::new();
new_chain.push_front(val);
self.length += 1;
self.chains.push((key, new_chain));
}

Expand All @@ -65,14 +59,12 @@ impl<K, T> RevisionVec<K, T> {
/// structure.
pub fn insert_new_chain(&mut self, key: K, new_chain: RevisionList<T>) {
if !new_chain.is_empty() {
self.length += new_chain.len();
self.chains.push((key, new_chain));
}
}

pub fn clear(&mut self) {
self.chains.clear();
self.length = self.chains.len();
}

pub fn retain_keys(&mut self, keys: HashSet<&K>)
Expand Down Expand Up @@ -190,6 +182,18 @@ impl<T> RevisionList<T> {
self.head.as_ref().map(|element| &element.data)
}

pub fn pop_tail(&mut self) -> RevisionListIter<T> {
self.length = self.head.as_ref().map_or(0, |_| 1);
match &self.head {
Some(head) => RevisionListIter {
current_element: &head.next,
},
None => RevisionListIter {
current_element: &None,
},
}
}

pub fn iter(&self) -> impl Iterator<Item = &T> {
RevisionListIter::new(self)
}
Expand Down

0 comments on commit 5687908

Please sign in to comment.