Skip to content

Commit

Permalink
Implement AES for riscv64 using Zvkned vector crypto extensions (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanshade committed Jan 21, 2024
1 parent e84d3c9 commit fa5ff07
Show file tree
Hide file tree
Showing 15 changed files with 939 additions and 181 deletions.
3 changes: 0 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,3 @@ members = [
"twofish",
"threefish",
]

[profile.dev]
opt-level = 2
28 changes: 17 additions & 11 deletions aes/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,3 @@
#![cfg_attr(
all(
target_arch = "riscv64",
target_feature = "zknd",
target_feature = "zkne"
),
feature(riscv_ext_intrinsics, stdsimd)
)]

//! Pure Rust implementation of the [Advanced Encryption Standard][AES]
//! (AES, a.k.a. Rijndael).
//!
Expand Down Expand Up @@ -127,6 +118,18 @@
)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![warn(missing_docs, rust_2018_idioms)]
#![cfg_attr(
all(target_arch = "riscv64", target_feature = "experimental-zvkned",),
feature(riscv_ext_intrinsics, stdsimd)
)]
#![cfg_attr(
all(
target_arch = "riscv64",
target_feature = "zknd",
target_feature = "zkne"
),
feature(riscv_ext_intrinsics, stdsimd)
)]

#[cfg(feature = "hazmat")]
#[cfg_attr(docsrs, doc(cfg(feature = "hazmat")))]
Expand All @@ -141,9 +144,12 @@ cfg_if! {
mod armv8;
mod autodetect;
pub use autodetect::*;
} else if #[cfg(all(target_arch = "riscv64", target_feature = "v", riscv_rvv_zvkned))] {
mod riscv;
pub use riscv::rv64::vector::*;
} else if #[cfg(all(target_arch = "riscv64", target_feature = "zknd", target_feature = "zkne"))] {
mod riscv64;
pub use riscv64::*;
mod riscv;
pub use riscv::rv64::scalar::*;
} else if #[cfg(all(
any(target_arch = "x86", target_arch = "x86_64"),
not(aes_force_soft)
Expand Down
104 changes: 104 additions & 0 deletions aes/src/riscv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
pub(crate) mod rv64;

#[cfg(test)]
mod test {
use hex_literal::hex;

pub(crate) const AES128_KEY: [u8; 16] = hex!("2b7e151628aed2a6abf7158809cf4f3c");
pub(crate) const AES128_EXP_KEYS: [[u8; 16]; 11] = [
AES128_KEY,
hex!("a0fafe1788542cb123a339392a6c7605"),
hex!("f2c295f27a96b9435935807a7359f67f"),
hex!("3d80477d4716fe3e1e237e446d7a883b"),
hex!("ef44a541a8525b7fb671253bdb0bad00"),
hex!("d4d1c6f87c839d87caf2b8bc11f915bc"),
hex!("6d88a37a110b3efddbf98641ca0093fd"),
hex!("4e54f70e5f5fc9f384a64fb24ea6dc4f"),
hex!("ead27321b58dbad2312bf5607f8d292f"),
hex!("ac7766f319fadc2128d12941575c006e"),
hex!("d014f9a8c9ee2589e13f0cc8b6630ca6"),
];
pub(crate) const AES128_EXP_INVKEYS: [[u8; 16]; 11] = [
AES128_KEY,
hex!("2b3708a7f262d405bc3ebdbf4b617d62"),
hex!("cc7505eb3e17d1ee82296c51c9481133"),
hex!("7c1f13f74208c219c021ae480969bf7b"),
hex!("90884413d280860a12a128421bc89739"),
hex!("6ea30afcbc238cf6ae82a4b4b54a338d"),
hex!("6efcd876d2df54807c5df034c917c3b9"),
hex!("12c07647c01f22c7bc42d2f37555114a"),
hex!("df7d925a1f62b09da320626ed6757324"),
hex!("0c7b5a631319eafeb0398890664cfbb4"),
hex!("d014f9a8c9ee2589e13f0cc8b6630ca6"),
];

pub(crate) const AES192_KEY: [u8; 24] =
hex!("8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b");
pub(crate) const AES192_EXP_KEYS: [[u8; 16]; 13] = [
hex!("8e73b0f7da0e6452c810f32b809079e5"),
hex!("62f8ead2522c6b7bfe0c91f72402f5a5"),
hex!("ec12068e6c827f6b0e7a95b95c56fec2"),
hex!("4db7b4bd69b5411885a74796e92538fd"),
hex!("e75fad44bb095386485af05721efb14f"),
hex!("a448f6d94d6dce24aa326360113b30e6"),
hex!("a25e7ed583b1cf9a27f939436a94f767"),
hex!("c0a69407d19da4e1ec1786eb6fa64971"),
hex!("485f703222cb8755e26d135233f0b7b3"),
hex!("40beeb282f18a2596747d26b458c553e"),
hex!("a7e1466c9411f1df821f750aad07d753"),
hex!("ca4005388fcc5006282d166abc3ce7b5"),
hex!("e98ba06f448c773c8ecc720401002202"),
];
pub(crate) const AES192_EXP_INVKEYS: [[u8; 16]; 13] = [
hex!("8e73b0f7da0e6452c810f32b809079e5"),
hex!("9eb149c479d69c5dfeb4a27ceab6d7fd"),
hex!("659763e78c817087123039436be6a51e"),
hex!("41b34544ab0592b9ce92f15e421381d9"),
hex!("5023b89a3bc51d84d04b19377b4e8b8e"),
hex!("b5dc7ad0f7cffb09a7ec43939c295e17"),
hex!("c5ddb7f8be933c760b4f46a6fc80bdaf"),
hex!("5b6cfe3cc745a02bf8b9a572462a9904"),
hex!("4d65dfa2b1e5620dea899c312dcc3c1a"),
hex!("f3b42258b59ebb5cf8fb64fe491e06f3"),
hex!("a3979ac28e5ba6d8e12cc9e654b272ba"),
hex!("ac491644e55710b746c08a75c89b2cad"),
hex!("e98ba06f448c773c8ecc720401002202"),
];

pub(crate) const AES256_KEY: [u8; 32] =
hex!("603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4");
pub(crate) const AES256_EXP_KEYS: [[u8; 16]; 15] = [
hex!("603deb1015ca71be2b73aef0857d7781"),
hex!("1f352c073b6108d72d9810a30914dff4"),
hex!("9ba354118e6925afa51a8b5f2067fcde"),
hex!("a8b09c1a93d194cdbe49846eb75d5b9a"),
hex!("d59aecb85bf3c917fee94248de8ebe96"),
hex!("b5a9328a2678a647983122292f6c79b3"),
hex!("812c81addadf48ba24360af2fab8b464"),
hex!("98c5bfc9bebd198e268c3ba709e04214"),
hex!("68007bacb2df331696e939e46c518d80"),
hex!("c814e20476a9fb8a5025c02d59c58239"),
hex!("de1369676ccc5a71fa2563959674ee15"),
hex!("5886ca5d2e2f31d77e0af1fa27cf73c3"),
hex!("749c47ab18501ddae2757e4f7401905a"),
hex!("cafaaae3e4d59b349adf6acebd10190d"),
hex!("fe4890d1e6188d0b046df344706c631e"),
];
pub(crate) const AES256_EXP_INVKEYS: [[u8; 16]; 15] = [
hex!("603deb1015ca71be2b73aef0857d7781"),
hex!("8ec6bff6829ca03b9e49af7edba96125"),
hex!("42107758e9ec98f066329ea193f8858b"),
hex!("4a7459f9c8e8f9c256a156bc8d083799"),
hex!("6c3d632985d1fbd9e3e36578701be0f3"),
hex!("54fb808b9c137949cab22ff547ba186c"),
hex!("25ba3c22a06bc7fb4388a28333934270"),
hex!("d669a7334a7ade7a80c8f18fc772e9e3"),
hex!("c440b289642b757227a3d7f114309581"),
hex!("32526c367828b24cf8e043c33f92aa20"),
hex!("34ad1e4450866b367725bcc763152946"),
hex!("b668b621ce40046d36a047ae0932ed8e"),
hex!("57c96cf6074f07c0706abb07137f9241"),
hex!("ada23f4963e23b2455427c8a5c709104"),
hex!("fe4890d1e6188d0b046df344706c631e"),
];
}
8 changes: 8 additions & 0 deletions aes/src/riscv/rv64.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#[cfg(all(
target_arch = "riscv64",
target_feature = "zknd",
target_feature = "zkne"
))]
pub(crate) mod scalar;
#[cfg(all(target_arch = "riscv64", target_feature = "v", riscv_rvv_zvkned))]
pub(crate) mod vector;
4 changes: 2 additions & 2 deletions aes/src/riscv64.rs → aes/src/riscv/rv64/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
//! https://github.com/openssl/openssl/blob/master/crypto/aes/asm/aes-riscv64-zkn.pl
mod encdec;
mod expand;
pub(crate) mod expand;
#[cfg(test)]
mod test_expand;
pub(crate) mod test_expand;

use self::{
encdec::{decrypt1, decrypt8, encrypt1, encrypt8},
Expand Down
File renamed without changes.
14 changes: 10 additions & 4 deletions aes/src/riscv64/expand.rs → aes/src/riscv/rv64/scalar/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use core::{
};

// TODO(silvanshade): `WORDS` should be an associated constant once support for that is stable.
pub(super) struct KeyScheduleState<const WORDS: usize, const ROUNDS: usize> {
pub(crate) struct KeyScheduleState<const WORDS: usize, const ROUNDS: usize> {
data: [u64; WORDS],
expanded_keys: [MaybeUninit<RoundKey>; ROUNDS],
}
Expand All @@ -19,6 +19,7 @@ impl KeyScheduleState<2, 11> {
unsafe { data[0].write(user_key.add(0).read_unaligned()) };
unsafe { data[1].write(user_key.add(1).read_unaligned()) };
let mut state = Self {
// SAFETY: `data` is fully initialized.
data: unsafe { transmute(data) },
expanded_keys: unsafe { MaybeUninit::uninit().assume_init() },
};
Expand All @@ -43,7 +44,7 @@ impl KeyScheduleState<2, 11> {
}

#[inline(always)]
pub(super) fn expand_key(user_key: &[u8; 16]) -> RoundKeys<11> {
pub(crate) fn expand_key(user_key: &[u8; 16]) -> RoundKeys<11> {
let mut state = Self::load(user_key);
state.one_key_rounds::<0>();
state.one_key_rounds::<1>();
Expand All @@ -55,6 +56,7 @@ impl KeyScheduleState<2, 11> {
state.one_key_rounds::<7>();
state.one_key_rounds::<8>();
state.one_key_rounds::<9>();
// SAFETY: `state.expanded_keys` is fully initialized.
unsafe { transmute(state.expanded_keys) }
}
}
Expand All @@ -68,6 +70,7 @@ impl KeyScheduleState<3, 13> {
unsafe { data[1].write(user_key.add(1).read_unaligned()) };
unsafe { data[2].write(user_key.add(2).read_unaligned()) };
let mut state = Self {
// SAFETY: `data` is fully initialized.
data: unsafe { transmute(data) },
expanded_keys: unsafe { MaybeUninit::uninit().assume_init() },
};
Expand Down Expand Up @@ -115,7 +118,7 @@ impl KeyScheduleState<3, 13> {
}

#[inline(always)]
pub(super) fn expand_key(user_key: &[u8; 24]) -> RoundKeys<13> {
pub(crate) fn expand_key(user_key: &[u8; 24]) -> RoundKeys<13> {
let mut state = Self::load(user_key);
state.one_and_one_half_key_rounds::<0>();
state.one_and_one_half_key_rounds::<1>();
Expand All @@ -125,6 +128,7 @@ impl KeyScheduleState<3, 13> {
state.one_and_one_half_key_rounds::<5>();
state.one_and_one_half_key_rounds::<6>();
state.one_key_rounds::<7>();
// SAFETY: `state.expanded_keys` is fully initialized.
unsafe { transmute(state.expanded_keys) }
}
}
Expand All @@ -139,6 +143,7 @@ impl KeyScheduleState<4, 15> {
unsafe { data[2].write(user_key.add(2).read_unaligned()) };
unsafe { data[3].write(user_key.add(3).read_unaligned()) };
let mut state = Self {
// SAFETY: `data` is fully initialized.
data: unsafe { transmute(data) },
expanded_keys: unsafe { MaybeUninit::uninit().assume_init() },
};
Expand Down Expand Up @@ -185,7 +190,7 @@ impl KeyScheduleState<4, 15> {
}

#[inline(always)]
pub(super) fn expand_key(user_key: &[u8; 32]) -> RoundKeys<15> {
pub(crate) fn expand_key(user_key: &[u8; 32]) -> RoundKeys<15> {
let mut state = Self::load(user_key);
state.two_key_rounds::<0>();
state.two_key_rounds::<1>();
Expand All @@ -194,6 +199,7 @@ impl KeyScheduleState<4, 15> {
state.two_key_rounds::<4>();
state.two_key_rounds::<5>();
state.one_key_rounds::<6>();
// SAFETY: `state.expanded_keys` is fully initialized.
unsafe { transmute(state.expanded_keys) }
}
}
Expand Down
62 changes: 62 additions & 0 deletions aes/src/riscv/rv64/scalar/test_expand.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use super::{inv_expanded_keys, KeyScheduleState, RoundKey, RoundKeys};
use crate::riscv::test::*;

fn load_expanded_keys<const N: usize>(input: [[u8; 16]; N]) -> RoundKeys<N> {
let mut output = [RoundKey::from(<[u64; 2]>::default()); N];
for (src, dst) in input.iter().zip(output.iter_mut()) {
let ptr = src.as_ptr().cast::<u64>();
dst[0] = unsafe { ptr.add(0).read_unaligned() };
dst[1] = unsafe { ptr.add(1).read_unaligned() };
}
output
}

pub(crate) fn store_expanded_keys<const N: usize>(input: RoundKeys<N>) -> [[u8; 16]; N] {
let mut output = [[0u8; 16]; N];
for (src, dst) in input.iter().zip(output.iter_mut()) {
let b0 = src[0].to_ne_bytes();
let b1 = src[1].to_ne_bytes();
dst[00..08].copy_from_slice(&b0);
dst[08..16].copy_from_slice(&b1);
}
output
}

#[test]
fn aes128_key_expansion() {
let ek = KeyScheduleState::<2, 11>::expand_key(&AES128_KEY);
assert_eq!(store_expanded_keys(ek), AES128_EXP_KEYS);
}

#[test]
fn aes128_key_expansion_inv() {
let mut ek = load_expanded_keys(AES128_EXP_KEYS);
inv_expanded_keys(&mut ek);
assert_eq!(store_expanded_keys(ek), AES128_EXP_INVKEYS);
}

#[test]
fn aes192_key_expansion() {
let ek = KeyScheduleState::<3, 13>::expand_key(&AES192_KEY);
assert_eq!(store_expanded_keys(ek), AES192_EXP_KEYS);
}

#[test]
fn aes192_key_expansion_inv() {
let mut ek = load_expanded_keys(AES192_EXP_KEYS);
inv_expanded_keys(&mut ek);
assert_eq!(store_expanded_keys(ek), AES192_EXP_INVKEYS);
}

#[test]
fn aes256_key_expansion() {
let ek = KeyScheduleState::<4, 15>::expand_key(&AES256_KEY);
assert_eq!(store_expanded_keys(ek), AES256_EXP_KEYS);
}

#[test]
fn aes256_key_expansion_inv() {
let mut ek = load_expanded_keys(AES256_EXP_KEYS);
inv_expanded_keys(&mut ek);
assert_eq!(store_expanded_keys(ek), AES256_EXP_INVKEYS);
}
Loading

0 comments on commit fa5ff07

Please sign in to comment.