diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 310e179..c1d2c25 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -102,7 +102,29 @@ jobs: run: cargo test --locked --release --no-run --workspace # TODO: add fuzz tests -# TODO: add benchmarks + + benchmarks: + name: Run benchmarks + runs-on: ubuntu-20.04 + steps: + - name: Checkout sources + uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Install nightly toolchain + uses: dtolnay/rust-toolchain@nightly + + - name: Smoke-test benchmark program + run: cargo run -p rustls-mbedcrypto-provider --release --locked --example bench + env: + # Ensure benchmark does not take too long time + BENCH_MULTIPLIER: 0.3 + + - name: Run micro-benchmarks + run: cargo bench --locked --all-features + env: + RUSTFLAGS: --cfg=bench docs: name: Check for documentation errors diff --git a/Cargo.lock b/Cargo.lock index 7235ebb..16fdf6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,6 +77,12 @@ version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" +[[package]] +name = "bencher" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dfdb4953a096c551ce9ace855a604d702e6e62d77fac690575ae347571717f5" + [[package]] name = "bindgen" version = "0.65.1" @@ -676,6 +682,7 @@ dependencies = [ name = "rustls-mbedcrypto-provider" version = "0.0.1-alpha.1" dependencies = [ + "bencher", "bit-vec 0.6.3", "env_logger", "log", diff --git a/rustls-mbedcrypto-provider/Cargo.toml b/rustls-mbedcrypto-provider/Cargo.toml index a015620..027a888 100644 --- a/rustls-mbedcrypto-provider/Cargo.toml +++ b/rustls-mbedcrypto-provider/Cargo.toml @@ -43,6 +43,7 @@ webpki-roots = "0.26.0" rustls-pemfile = "2" env_logger = "0.10" log = { version = "0.4.4" } +bencher = "0.1.5" [features] default = ["logging", "tls12"] @@ -57,3 +58,10 @@ path = "examples/client.rs" [[example]] name = "bench" path = "examples/internal/bench.rs" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[package.metadata.cargo_check_external_types] +allowed_external_types = ["rustls_pki_types::*"] diff --git a/rustls-mbedcrypto-provider/examples/internal/bench.rs b/rustls-mbedcrypto-provider/examples/internal/bench.rs index d774c7f..f4dec9d 100644 --- a/rustls-mbedcrypto-provider/examples/internal/bench.rs +++ b/rustls-mbedcrypto-provider/examples/internal/bench.rs @@ -1,638 +1,5 @@ -/* Copyright (c) Fortanix, Inc. - * - * This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this - * file, You can obtain one at http://mozilla.org/MPL/2.0/. - */ - -// This program does assorted benchmarking of rustls. -// -// Note: we don't use any of the standard 'cargo bench', 'test::Bencher', -// etc. because it's unstable at the time of writing. - -use std::env; -use std::fs; -use std::io::{self, Read, Write}; -use std::ops::Deref; -use std::ops::DerefMut; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use pki_types::{CertificateDer, PrivateKeyDer}; - -use rustls::client::Resumption; -use rustls::crypto::ring::{cipher_suite, Ticketer}; -use rustls::crypto::CryptoProvider; -use rustls::server::{NoServerSessionStorage, ServerSessionMemoryCache, WebPkiClientVerifier}; -use rustls::RootCertStore; -use rustls::{ClientConfig, ClientConnection}; -use rustls::{ConnectionCommon, SideData}; -use rustls::{ServerConfig, ServerConnection}; - -fn duration_nanos(d: Duration) -> f64 { - (d.as_secs() as f64) + f64::from(d.subsec_nanos()) / 1e9 -} - -fn _bench(count: usize, name: &'static str, f_setup: Fsetup, f_test: Ftest) -where - Fsetup: Fn() -> S, - Ftest: Fn(S), -{ - let mut times = Vec::new(); - - for _ in 0..count { - let state = f_setup(); - let start = Instant::now(); - f_test(state); - times.push(duration_nanos(Instant::now().duration_since(start))); - } - - println!("{}", name); - println!("{:?}", times); -} - -fn time(mut f: F) -> f64 -where - F: FnMut(), -{ - let start = Instant::now(); - f(); - let end = Instant::now(); - duration_nanos(end.duration_since(start)) -} - -fn transfer(left: &mut L, right: &mut R, expect_data: Option) -> f64 -where - L: DerefMut + Deref>, - R: DerefMut + Deref>, - LS: SideData, - RS: SideData, -{ - let mut tls_buf = [0u8; 262144]; - let mut read_time = 0f64; - let mut data_left = expect_data; - let mut data_buf = [0u8; 8192]; - - loop { - let mut sz = 0; - - while left.wants_write() { - let written = left - .write_tls(&mut tls_buf[sz..].as_mut()) - .unwrap(); - if written == 0 { - break; - } - - sz += written; - } - - if sz == 0 { - return read_time; - } - - let mut offs = 0; - loop { - let start = Instant::now(); - match right.read_tls(&mut tls_buf[offs..sz].as_ref()) { - Ok(read) => { - right.process_new_packets().unwrap(); - offs += read; - } - Err(err) => { - panic!("error on transfer {}..{}: {}", offs, sz, err); - } - } - - if let Some(left) = &mut data_left { - loop { - let sz = match right.reader().read(&mut data_buf) { - Ok(sz) => sz, - Err(err) if err.kind() == io::ErrorKind::WouldBlock => break, - Err(err) => panic!("failed to read data: {}", err), - }; - - *left -= sz; - if *left == 0 { - break; - } - } - } - - let end = Instant::now(); - read_time += duration_nanos(end.duration_since(start)); - if sz == offs { - break; - } - } - } -} - -#[derive(PartialEq, Clone, Copy)] -enum ClientAuth { - No, - Yes, -} - -#[derive(PartialEq, Clone, Copy)] -enum ResumptionParam { - No, - SessionId, - Tickets, -} - -impl ResumptionParam { - fn label(&self) -> &'static str { - match *self { - Self::No => "no-resume", - Self::SessionId => "sessionid", - Self::Tickets => "tickets", - } - } -} - -// copied from tests/api.rs -#[derive(PartialEq, Clone, Copy, Debug)] -enum KeyType { - Rsa, - Ecdsa, - Ed25519, -} - -struct BenchmarkParam { - key_type: KeyType, - ciphersuite: rustls::SupportedCipherSuite, - version: &'static rustls::SupportedProtocolVersion, -} - -impl BenchmarkParam { - const fn new( - key_type: KeyType, - ciphersuite: rustls::SupportedCipherSuite, - version: &'static rustls::SupportedProtocolVersion, - ) -> Self { - Self { key_type, ciphersuite, version } - } -} - -static ALL_BENCHMARKS: &[BenchmarkParam] = &[ - #[cfg(feature = "tls12")] - BenchmarkParam::new( - KeyType::Rsa, - cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - &rustls::version::TLS12, - ), - #[cfg(feature = "tls12")] - BenchmarkParam::new( - KeyType::Ecdsa, - cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - &rustls::version::TLS12, - ), - #[cfg(feature = "tls12")] - BenchmarkParam::new( - KeyType::Rsa, - cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - &rustls::version::TLS12, - ), - #[cfg(feature = "tls12")] - BenchmarkParam::new( - KeyType::Rsa, - cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - &rustls::version::TLS12, - ), - #[cfg(feature = "tls12")] - BenchmarkParam::new( - KeyType::Rsa, - cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - &rustls::version::TLS12, - ), - #[cfg(feature = "tls12")] - BenchmarkParam::new( - KeyType::Ecdsa, - cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - &rustls::version::TLS12, - ), - #[cfg(feature = "tls12")] - BenchmarkParam::new( - KeyType::Ecdsa, - cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - &rustls::version::TLS12, - ), - BenchmarkParam::new( - KeyType::Rsa, - cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, - &rustls::version::TLS13, - ), - BenchmarkParam::new(KeyType::Rsa, cipher_suite::TLS13_AES_256_GCM_SHA384, &rustls::version::TLS13), - BenchmarkParam::new(KeyType::Rsa, cipher_suite::TLS13_AES_128_GCM_SHA256, &rustls::version::TLS13), - BenchmarkParam::new( - KeyType::Ecdsa, - cipher_suite::TLS13_AES_128_GCM_SHA256, - &rustls::version::TLS13, - ), - BenchmarkParam::new( - KeyType::Ed25519, - cipher_suite::TLS13_AES_128_GCM_SHA256, - &rustls::version::TLS13, - ), -]; - -impl KeyType { - fn path_for(&self, part: &str) -> String { - match self { - Self::Rsa => format!("test-ca/rsa/{}", part), - Self::Ecdsa => format!("test-ca/ecdsa/{}", part), - Self::Ed25519 => format!("test-ca/eddsa/{}", part), - } - } - - fn get_chain(&self) -> Vec> { - rustls_pemfile::certs(&mut io::BufReader::new( - fs::File::open(self.path_for("end.fullchain")).unwrap(), - )) - .map(|result| result.unwrap()) - .collect() - } - - fn get_key(&self) -> PrivateKeyDer<'static> { - rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(fs::File::open(self.path_for("end.key")).unwrap())) - .next() - .unwrap() - .unwrap() - .into() - } - - fn get_client_chain(&self) -> Vec> { - rustls_pemfile::certs(&mut io::BufReader::new( - fs::File::open(self.path_for("client.fullchain")).unwrap(), - )) - .map(|result| result.unwrap()) - .collect() - } - - fn get_client_key(&self) -> PrivateKeyDer<'static> { - rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(fs::File::open(self.path_for("client.key")).unwrap())) - .next() - .unwrap() - .unwrap() - .into() - } -} - -fn make_server_config( - params: &BenchmarkParam, - client_auth: ClientAuth, - resume: ResumptionParam, - max_fragment_size: Option, -) -> ServerConfig { - let client_auth = match client_auth { - ClientAuth::Yes => { - let roots = params.key_type.get_chain(); - let mut client_auth_roots = RootCertStore::empty(); - for root in roots { - client_auth_roots.add(root).unwrap(); - } - WebPkiClientVerifier::builder(client_auth_roots.into()) - .build() - .unwrap() - } - ClientAuth::No => WebPkiClientVerifier::no_client_auth(), - }; - - let mut cfg = ServerConfig::builder_with_provider(rustls_mbedcrypto_provider::mbedtls_crypto_provider().into()) - .with_protocol_versions(&[params.version]) - .unwrap() - .with_client_cert_verifier(client_auth) - .with_single_cert(params.key_type.get_chain(), params.key_type.get_key()) - .expect("bad certs/private key?"); - - if resume == ResumptionParam::SessionId { - cfg.session_storage = ServerSessionMemoryCache::new(128); - } else if resume == ResumptionParam::Tickets { - cfg.ticketer = Ticketer::new().unwrap(); - } else { - cfg.session_storage = Arc::new(NoServerSessionStorage {}); - } - - cfg.max_fragment_size = max_fragment_size; - cfg -} - -fn make_client_config(params: &BenchmarkParam, clientauth: ClientAuth, resume: ResumptionParam) -> ClientConfig { - let mut root_store = RootCertStore::empty(); - let mut rootbuf = io::BufReader::new(fs::File::open(params.key_type.path_for("ca.cert")).unwrap()); - root_store.add_parsable_certificates(rustls_pemfile::certs(&mut rootbuf).map(|result| result.unwrap())); - - let cfg = ClientConfig::builder_with_provider( - CryptoProvider { - cipher_suites: vec![params.ciphersuite], - ..rustls_mbedcrypto_provider::mbedtls_crypto_provider() - } - .into(), - ) - .with_protocol_versions(&[params.version]) - .unwrap() - .with_root_certificates(root_store); - - let mut cfg = if clientauth == ClientAuth::Yes { - cfg.with_client_auth_cert(params.key_type.get_client_chain(), params.key_type.get_client_key()) - .unwrap() - } else { - cfg.with_no_client_auth() - }; - - if resume != ResumptionParam::No { - cfg.resumption = Resumption::in_memory_sessions(128); - } else { - cfg.resumption = Resumption::disabled(); - } - - cfg -} - -fn apply_work_multiplier(work: u64) -> u64 { - let mul = match env::var("BENCH_MULTIPLIER") { - Ok(val) => val - .parse::() - .expect("invalid BENCH_MULTIPLIER value"), - Err(_) => 1., - }; - - ((work as f64) * mul).round() as u64 -} - -fn bench_handshake(params: &BenchmarkParam, clientauth: ClientAuth, resume: ResumptionParam) { - let client_config = Arc::new(make_client_config(params, clientauth, resume)); - let server_config = Arc::new(make_server_config(params, clientauth, resume, None)); - - assert!(params.ciphersuite.version() == params.version); - - let rounds = apply_work_multiplier(if resume == ResumptionParam::No { 512 } else { 4096 }); - let mut client_time = 0f64; - let mut server_time = 0f64; - - for _ in 0..rounds { - let server_name = "localhost".try_into().unwrap(); - let mut client = ClientConnection::new(Arc::clone(&client_config), server_name).unwrap(); - let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); - - server_time += time(|| { - transfer(&mut client, &mut server, None); - }); - client_time += time(|| { - transfer(&mut server, &mut client, None); - }); - server_time += time(|| { - transfer(&mut client, &mut server, None); - }); - client_time += time(|| { - transfer(&mut server, &mut client, None); - }); - } - - println!( - "handshakes\t{:?}\t{:?}\t{:?}\tclient\t{}\t{}\t{:.2}\thandshake/s", - params.version, - params.key_type, - params.ciphersuite.suite(), - if clientauth == ClientAuth::Yes { - "mutual" - } else { - "server-auth" - }, - resume.label(), - (rounds as f64) / client_time - ); - println!( - "handshakes\t{:?}\t{:?}\t{:?}\tserver\t{}\t{}\t{:.2}\thandshake/s", - params.version, - params.key_type, - params.ciphersuite.suite(), - if clientauth == ClientAuth::Yes { - "mutual" - } else { - "server-auth" - }, - resume.label(), - (rounds as f64) / server_time - ); -} - -fn do_handshake_step(client: &mut ClientConnection, server: &mut ServerConnection) -> bool { - if server.is_handshaking() || client.is_handshaking() { - transfer(client, server, None); - transfer(server, client, None); - true - } else { - false - } -} - -fn do_handshake(client: &mut ClientConnection, server: &mut ServerConnection) { - while do_handshake_step(client, server) {} -} - -fn bench_bulk(params: &BenchmarkParam, plaintext_size: u64, max_fragment_size: Option) { - let client_config = Arc::new(make_client_config(params, ClientAuth::No, ResumptionParam::No)); - let server_config = Arc::new(make_server_config( - params, - ClientAuth::No, - ResumptionParam::No, - max_fragment_size, - )); - - let server_name = "localhost".try_into().unwrap(); - let mut client = ClientConnection::new(client_config, server_name).unwrap(); - client.set_buffer_limit(None); - let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); - server.set_buffer_limit(None); - - do_handshake(&mut client, &mut server); - - let buf = vec![0; plaintext_size as usize]; - let total_data = apply_work_multiplier(if plaintext_size < 8192 { - 64 * 1024 * 1024 - } else { - 1024 * 1024 * 1024 - }); - let rounds = total_data / plaintext_size; - let mut time_send = 0f64; - let mut time_recv = 0f64; - - for _ in 0..rounds { - time_send += time(|| { - server.writer().write_all(&buf).unwrap(); - }); - - time_recv += transfer(&mut server, &mut client, Some(buf.len())); - } - - let mfs_str = format!( - "max_fragment_size:{}", - max_fragment_size - .map(|v| v.to_string()) - .unwrap_or_else(|| "default".to_string()) - ); - let total_mbs = ((plaintext_size * rounds) as f64) / (1024. * 1024.); - println!( - "bulk\t{:?}\t{:?}\t{}\tsend\t{:.2}\tMB/s", - params.version, - params.ciphersuite.suite(), - mfs_str, - total_mbs / time_send - ); - println!( - "bulk\t{:?}\t{:?}\t{}\trecv\t{:.2}\tMB/s", - params.version, - params.ciphersuite.suite(), - mfs_str, - total_mbs / time_recv - ); -} - -fn bench_memory(params: &BenchmarkParam, conn_count: u64) { - let client_config = Arc::new(make_client_config(params, ClientAuth::No, ResumptionParam::No)); - let server_config = Arc::new(make_server_config(params, ClientAuth::No, ResumptionParam::No, None)); - - // The target here is to end up with conn_count post-handshake - // server and client sessions. - let conn_count = (conn_count / 2) as usize; - let mut servers = Vec::with_capacity(conn_count); - let mut clients = Vec::with_capacity(conn_count); - - for _i in 0..conn_count { - servers.push(ServerConnection::new(Arc::clone(&server_config)).unwrap()); - let server_name = "localhost".try_into().unwrap(); - clients.push(ClientConnection::new(Arc::clone(&client_config), server_name).unwrap()); - } - - for _step in 0..5 { - for (client, server) in clients - .iter_mut() - .zip(servers.iter_mut()) - { - do_handshake_step(client, server); - } - } - - for client in clients.iter_mut() { - client - .writer() - .write_all(&[0u8; 1024]) - .unwrap(); - } - - for (client, server) in clients - .iter_mut() - .zip(servers.iter_mut()) - { - transfer(client, server, Some(1024)); - } -} - -fn lookup_matching_benches(name: &str) -> Vec<&BenchmarkParam> { - let r: Vec<&BenchmarkParam> = ALL_BENCHMARKS - .iter() - .filter(|params| format!("{:?}", params.ciphersuite.suite()).to_lowercase() == name.to_lowercase()) - .collect(); - - if r.is_empty() { - panic!("unknown suite {:?}", name); - } - - r -} - -fn selected_tests(mut args: env::Args) { - let mode = args - .next() - .expect("first argument must be mode"); - - match mode.as_ref() { - "bulk" => match args.next() { - Some(suite) => { - let len = args - .next() - .map(|arg| { - arg.parse::() - .expect("3rd arg must be plaintext size integer") - }) - .unwrap_or(1048576); - let mfs = args.next().map(|arg| { - arg.parse::() - .expect("4th arg must be max_fragment_size integer") - }); - for param in lookup_matching_benches(&suite).iter() { - bench_bulk(param, len, mfs); - } - } - None => { - panic!("bulk needs ciphersuite argument"); - } - }, - - "handshake" | "handshake-resume" | "handshake-ticket" => match args.next() { - Some(suite) => { - let resume = if mode == "handshake" { - ResumptionParam::No - } else if mode == "handshake-resume" { - ResumptionParam::SessionId - } else { - ResumptionParam::Tickets - }; - - for param in lookup_matching_benches(&suite).iter() { - bench_handshake(param, ClientAuth::No, resume); - } - } - None => { - panic!("handshake* needs ciphersuite argument"); - } - }, - - "memory" => match args.next() { - Some(suite) => { - let count = args - .next() - .map(|arg| { - arg.parse::() - .expect("3rd arg must be connection count integer") - }) - .unwrap_or(1000000); - for param in lookup_matching_benches(&suite).iter() { - bench_memory(param, count); - } - } - None => { - panic!("memory needs ciphersuite argument"); - } - }, - - _ => { - panic!("unsupported mode {:?}", mode); - } - } -} - -fn all_tests() { - for test in ALL_BENCHMARKS.iter() { - bench_bulk(test, 1024 * 1024, None); - bench_bulk(test, 1024 * 1024, Some(10000)); - bench_handshake(test, ClientAuth::No, ResumptionParam::No); - bench_handshake(test, ClientAuth::Yes, ResumptionParam::No); - bench_handshake(test, ClientAuth::No, ResumptionParam::SessionId); - bench_handshake(test, ClientAuth::Yes, ResumptionParam::SessionId); - bench_handshake(test, ClientAuth::No, ResumptionParam::Tickets); - bench_handshake(test, ClientAuth::Yes, ResumptionParam::Tickets); - } -} +mod bench_impl; fn main() { - let mut args = env::args(); - if args.len() > 1 { - args.next(); - selected_tests(args); - } else { - all_tests(); - } + bench_impl::main(); } diff --git a/rustls-mbedcrypto-provider/examples/internal/bench_impl.rs b/rustls-mbedcrypto-provider/examples/internal/bench_impl.rs new file mode 100644 index 0000000..1f9e8f2 --- /dev/null +++ b/rustls-mbedcrypto-provider/examples/internal/bench_impl.rs @@ -0,0 +1,633 @@ +// This program does assorted benchmarking of rustls. +// +// Note: we don't use any of the standard 'cargo bench', 'test::Bencher', +// etc. because it's unstable at the time of writing. + +use std::env; +use std::fs; +use std::io::{self, Read, Write}; +use std::ops::Deref; +use std::ops::DerefMut; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use pki_types::{CertificateDer, PrivateKeyDer}; + +use rustls::client::Resumption; +use rustls::crypto::ring::Ticketer; +use rustls::crypto::CryptoProvider; +use rustls::server::{NoServerSessionStorage, ServerSessionMemoryCache, WebPkiClientVerifier}; +use rustls::RootCertStore; +use rustls::{ClientConfig, ClientConnection}; +use rustls::{ConnectionCommon, SideData}; +use rustls::{ServerConfig, ServerConnection}; +use rustls_mbedcrypto_provider::cipher_suite; +use rustls_mbedcrypto_provider::mbedtls_crypto_provider as default_provider; + +pub fn main() { + let mut args = std::env::args(); + if args.len() > 1 { + args.next(); + selected_tests(args); + } else { + all_tests(); + } +} + +fn duration_nanos(d: Duration) -> f64 { + (d.as_secs() as f64) + f64::from(d.subsec_nanos()) / 1e9 +} + +fn _bench(count: usize, name: &'static str, f_setup: Fsetup, f_test: Ftest) +where + Fsetup: Fn() -> S, + Ftest: Fn(S), +{ + let mut times = Vec::new(); + + for _ in 0..count { + let state = f_setup(); + let start = Instant::now(); + f_test(state); + times.push(duration_nanos(Instant::now().duration_since(start))); + } + + println!("{}", name); + println!("{:?}", times); +} + +fn time(mut f: F) -> f64 +where + F: FnMut(), +{ + let start = Instant::now(); + f(); + let end = Instant::now(); + duration_nanos(end.duration_since(start)) +} + +fn transfer(left: &mut L, right: &mut R, expect_data: Option) -> f64 +where + L: DerefMut + Deref>, + R: DerefMut + Deref>, + LS: SideData, + RS: SideData, +{ + let mut tls_buf = [0u8; 262144]; + let mut read_time = 0f64; + let mut data_left = expect_data; + let mut data_buf = [0u8; 8192]; + + loop { + let mut sz = 0; + + while left.wants_write() { + let written = left + .write_tls(&mut tls_buf[sz..].as_mut()) + .unwrap(); + if written == 0 { + break; + } + + sz += written; + } + + if sz == 0 { + return read_time; + } + + let mut offs = 0; + loop { + let start = Instant::now(); + match right.read_tls(&mut tls_buf[offs..sz].as_ref()) { + Ok(read) => { + right.process_new_packets().unwrap(); + offs += read; + } + Err(err) => { + panic!("error on transfer {}..{}: {}", offs, sz, err); + } + } + + if let Some(left) = &mut data_left { + loop { + let sz = match right.reader().read(&mut data_buf) { + Ok(sz) => sz, + Err(err) if err.kind() == io::ErrorKind::WouldBlock => break, + Err(err) => panic!("failed to read data: {}", err), + }; + + *left -= sz; + if *left == 0 { + break; + } + } + } + + let end = Instant::now(); + read_time += duration_nanos(end.duration_since(start)); + if sz == offs { + break; + } + } + } +} + +#[derive(PartialEq, Clone, Copy)] +enum ClientAuth { + No, + Yes, +} + +#[derive(PartialEq, Clone, Copy)] +enum ResumptionParam { + No, + SessionId, + Tickets, +} + +impl ResumptionParam { + fn label(&self) -> &'static str { + match *self { + Self::No => "no-resume", + Self::SessionId => "sessionid", + Self::Tickets => "tickets", + } + } +} + +// copied from tests/api.rs +#[derive(PartialEq, Clone, Copy, Debug)] +enum KeyType { + Rsa, + Ecdsa, + // Ed25519 is not supported by *mbedtls* + // Ed25519, +} + +struct BenchmarkParam { + key_type: KeyType, + ciphersuite: rustls::SupportedCipherSuite, + version: &'static rustls::SupportedProtocolVersion, +} + +impl BenchmarkParam { + const fn new( + key_type: KeyType, + ciphersuite: rustls::SupportedCipherSuite, + version: &'static rustls::SupportedProtocolVersion, + ) -> Self { + Self { key_type, ciphersuite, version } + } +} + +static ALL_BENCHMARKS: &[BenchmarkParam] = &[ + #[cfg(feature = "tls12")] + BenchmarkParam::new( + KeyType::Rsa, + cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + &rustls::version::TLS12, + ), + #[cfg(feature = "tls12")] + BenchmarkParam::new( + KeyType::Ecdsa, + cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + &rustls::version::TLS12, + ), + #[cfg(feature = "tls12")] + BenchmarkParam::new( + KeyType::Rsa, + cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + &rustls::version::TLS12, + ), + #[cfg(feature = "tls12")] + BenchmarkParam::new( + KeyType::Rsa, + cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + &rustls::version::TLS12, + ), + #[cfg(feature = "tls12")] + BenchmarkParam::new( + KeyType::Rsa, + cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + &rustls::version::TLS12, + ), + #[cfg(feature = "tls12")] + BenchmarkParam::new( + KeyType::Ecdsa, + cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + &rustls::version::TLS12, + ), + #[cfg(feature = "tls12")] + BenchmarkParam::new( + KeyType::Ecdsa, + cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + &rustls::version::TLS12, + ), + BenchmarkParam::new( + KeyType::Rsa, + cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, + &rustls::version::TLS13, + ), + BenchmarkParam::new(KeyType::Rsa, cipher_suite::TLS13_AES_256_GCM_SHA384, &rustls::version::TLS13), + BenchmarkParam::new(KeyType::Rsa, cipher_suite::TLS13_AES_128_GCM_SHA256, &rustls::version::TLS13), + BenchmarkParam::new( + KeyType::Ecdsa, + cipher_suite::TLS13_AES_128_GCM_SHA256, + &rustls::version::TLS13, + ), + // Ed25519 is not supported by *mbedtls* + // BenchmarkParam::new( + // KeyType::Ed25519, + // cipher_suite::TLS13_AES_128_GCM_SHA256, + // &rustls::version::TLS13, + // ), +]; + +impl KeyType { + fn path_for(&self, part: &str) -> String { + match self { + Self::Rsa => format!("test-ca/rsa/{}", part), + Self::Ecdsa => format!("test-ca/ecdsa/{}", part), + // Ed25519 is not supported by *mbedtls* + // Self::Ed25519 => format!("test-ca/eddsa/{}", part), + } + } + + fn get_chain(&self) -> Vec> { + rustls_pemfile::certs(&mut io::BufReader::new( + fs::File::open(self.path_for("end.fullchain")).unwrap(), + )) + .map(|result| result.unwrap()) + .collect() + } + + fn get_key(&self) -> PrivateKeyDer<'static> { + rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(fs::File::open(self.path_for("end.key")).unwrap())) + .next() + .unwrap() + .unwrap() + .into() + } + + fn get_client_chain(&self) -> Vec> { + rustls_pemfile::certs(&mut io::BufReader::new( + fs::File::open(self.path_for("client.fullchain")).unwrap(), + )) + .map(|result| result.unwrap()) + .collect() + } + + fn get_client_key(&self) -> PrivateKeyDer<'static> { + rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(fs::File::open(self.path_for("client.key")).unwrap())) + .next() + .unwrap() + .unwrap() + .into() + } +} + +fn make_server_config( + params: &BenchmarkParam, + client_auth: ClientAuth, + resume: ResumptionParam, + max_fragment_size: Option, +) -> ServerConfig { + let provider = Arc::new(default_provider()); + let client_auth = match client_auth { + ClientAuth::Yes => { + let roots = params.key_type.get_chain(); + let mut client_auth_roots = RootCertStore::empty(); + for root in roots { + client_auth_roots.add(root).unwrap(); + } + WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone()) + .build() + .unwrap() + } + ClientAuth::No => WebPkiClientVerifier::no_client_auth(), + }; + + let mut cfg = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[params.version]) + .unwrap() + .with_client_cert_verifier(client_auth) + .with_single_cert(params.key_type.get_chain(), params.key_type.get_key()) + .expect("bad certs/private key?"); + + if resume == ResumptionParam::SessionId { + cfg.session_storage = ServerSessionMemoryCache::new(128); + } else if resume == ResumptionParam::Tickets { + cfg.ticketer = Ticketer::new().unwrap(); + } else { + cfg.session_storage = Arc::new(NoServerSessionStorage {}); + } + + cfg.max_fragment_size = max_fragment_size; + cfg +} + +fn make_client_config(params: &BenchmarkParam, clientauth: ClientAuth, resume: ResumptionParam) -> ClientConfig { + let mut root_store = RootCertStore::empty(); + let mut rootbuf = io::BufReader::new(fs::File::open(params.key_type.path_for("ca.cert")).unwrap()); + root_store.add_parsable_certificates(rustls_pemfile::certs(&mut rootbuf).map(|result| result.unwrap())); + + let cfg = ClientConfig::builder_with_provider( + CryptoProvider { cipher_suites: vec![params.ciphersuite], ..default_provider() }.into(), + ) + .with_protocol_versions(&[params.version]) + .unwrap() + .with_root_certificates(root_store); + + let mut cfg = if clientauth == ClientAuth::Yes { + cfg.with_client_auth_cert(params.key_type.get_client_chain(), params.key_type.get_client_key()) + .unwrap() + } else { + cfg.with_no_client_auth() + }; + + if resume != ResumptionParam::No { + cfg.resumption = Resumption::in_memory_sessions(128); + } else { + cfg.resumption = Resumption::disabled(); + } + + cfg +} + +fn apply_work_multiplier(work: u64) -> u64 { + let mul = match env::var("BENCH_MULTIPLIER") { + Ok(val) => val + .parse::() + .expect("invalid BENCH_MULTIPLIER value"), + Err(_) => 1., + }; + + ((work as f64) * mul).round() as u64 +} + +fn bench_handshake(params: &BenchmarkParam, clientauth: ClientAuth, resume: ResumptionParam) { + let client_config = Arc::new(make_client_config(params, clientauth, resume)); + let server_config = Arc::new(make_server_config(params, clientauth, resume, None)); + + assert!(params.ciphersuite.version() == params.version); + + let rounds = apply_work_multiplier(if resume == ResumptionParam::No { 512 } else { 4096 }); + let mut client_time = 0f64; + let mut server_time = 0f64; + + for _ in 0..rounds { + let server_name = "localhost".try_into().unwrap(); + let mut client = ClientConnection::new(Arc::clone(&client_config), server_name).unwrap(); + let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); + + server_time += time(|| { + transfer(&mut client, &mut server, None); + }); + client_time += time(|| { + transfer(&mut server, &mut client, None); + }); + server_time += time(|| { + transfer(&mut client, &mut server, None); + }); + client_time += time(|| { + transfer(&mut server, &mut client, None); + }); + } + + println!( + "handshakes\t{:?}\t{:?}\t{:?}\tclient\t{}\t{}\t{:.2}\thandshake/s", + params.version, + params.key_type, + params.ciphersuite.suite(), + if clientauth == ClientAuth::Yes { + "mutual" + } else { + "server-auth" + }, + resume.label(), + (rounds as f64) / client_time + ); + println!( + "handshakes\t{:?}\t{:?}\t{:?}\tserver\t{}\t{}\t{:.2}\thandshake/s", + params.version, + params.key_type, + params.ciphersuite.suite(), + if clientauth == ClientAuth::Yes { + "mutual" + } else { + "server-auth" + }, + resume.label(), + (rounds as f64) / server_time + ); +} + +fn do_handshake_step(client: &mut ClientConnection, server: &mut ServerConnection) -> bool { + if server.is_handshaking() || client.is_handshaking() { + transfer(client, server, None); + transfer(server, client, None); + true + } else { + false + } +} + +fn do_handshake(client: &mut ClientConnection, server: &mut ServerConnection) { + while do_handshake_step(client, server) {} +} + +fn bench_bulk(params: &BenchmarkParam, plaintext_size: u64, max_fragment_size: Option) { + let client_config = Arc::new(make_client_config(params, ClientAuth::No, ResumptionParam::No)); + let server_config = Arc::new(make_server_config( + params, + ClientAuth::No, + ResumptionParam::No, + max_fragment_size, + )); + + let server_name = "localhost".try_into().unwrap(); + let mut client = ClientConnection::new(client_config, server_name).unwrap(); + client.set_buffer_limit(None); + let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); + server.set_buffer_limit(None); + + do_handshake(&mut client, &mut server); + + let buf = vec![0; plaintext_size as usize]; + let total_data = apply_work_multiplier(if plaintext_size < 8192 { + 64 * 1024 * 1024 + } else { + 1024 * 1024 * 1024 + }); + let rounds = total_data / plaintext_size; + let mut time_send = 0f64; + let mut time_recv = 0f64; + + for _ in 0..rounds { + time_send += time(|| { + server.writer().write_all(&buf).unwrap(); + }); + + time_recv += transfer(&mut server, &mut client, Some(buf.len())); + } + + let mfs_str = format!( + "max_fragment_size:{}", + max_fragment_size + .map(|v| v.to_string()) + .unwrap_or_else(|| "default".to_string()) + ); + let total_mbs = ((plaintext_size * rounds) as f64) / (1024. * 1024.); + println!( + "bulk\t{:?}\t{:?}\t{}\tsend\t{:.2}\tMB/s", + params.version, + params.ciphersuite.suite(), + mfs_str, + total_mbs / time_send + ); + println!( + "bulk\t{:?}\t{:?}\t{}\trecv\t{:.2}\tMB/s", + params.version, + params.ciphersuite.suite(), + mfs_str, + total_mbs / time_recv + ); +} + +fn bench_memory(params: &BenchmarkParam, conn_count: u64) { + let client_config = Arc::new(make_client_config(params, ClientAuth::No, ResumptionParam::No)); + let server_config = Arc::new(make_server_config(params, ClientAuth::No, ResumptionParam::No, None)); + + // The target here is to end up with conn_count post-handshake + // server and client sessions. + let conn_count = (conn_count / 2) as usize; + let mut servers = Vec::with_capacity(conn_count); + let mut clients = Vec::with_capacity(conn_count); + + for _i in 0..conn_count { + servers.push(ServerConnection::new(Arc::clone(&server_config)).unwrap()); + let server_name = "localhost".try_into().unwrap(); + clients.push(ClientConnection::new(Arc::clone(&client_config), server_name).unwrap()); + } + + for _step in 0..5 { + for (client, server) in clients + .iter_mut() + .zip(servers.iter_mut()) + { + do_handshake_step(client, server); + } + } + + for client in clients.iter_mut() { + client + .writer() + .write_all(&[0u8; 1024]) + .unwrap(); + } + + for (client, server) in clients + .iter_mut() + .zip(servers.iter_mut()) + { + transfer(client, server, Some(1024)); + } +} + +fn lookup_matching_benches(name: &str) -> Vec<&BenchmarkParam> { + let r: Vec<&BenchmarkParam> = ALL_BENCHMARKS + .iter() + .filter(|params| format!("{:?}", params.ciphersuite.suite()).to_lowercase() == name.to_lowercase()) + .collect(); + + if r.is_empty() { + panic!("unknown suite {:?}", name); + } + + r +} + +fn selected_tests(mut args: env::Args) { + let mode = args + .next() + .expect("first argument must be mode"); + + match mode.as_ref() { + "bulk" => match args.next() { + Some(suite) => { + let len = args + .next() + .map(|arg| { + arg.parse::() + .expect("3rd arg must be plaintext size integer") + }) + .unwrap_or(1048576); + let mfs = args.next().map(|arg| { + arg.parse::() + .expect("4th arg must be max_fragment_size integer") + }); + for param in lookup_matching_benches(&suite).iter() { + bench_bulk(param, len, mfs); + } + } + None => { + panic!("bulk needs ciphersuite argument"); + } + }, + + "handshake" | "handshake-resume" | "handshake-ticket" => match args.next() { + Some(suite) => { + let resume = if mode == "handshake" { + ResumptionParam::No + } else if mode == "handshake-resume" { + ResumptionParam::SessionId + } else { + ResumptionParam::Tickets + }; + + for param in lookup_matching_benches(&suite).iter() { + bench_handshake(param, ClientAuth::No, resume); + } + } + None => { + panic!("handshake* needs ciphersuite argument"); + } + }, + + "memory" => match args.next() { + Some(suite) => { + let count = args + .next() + .map(|arg| { + arg.parse::() + .expect("3rd arg must be connection count integer") + }) + .unwrap_or(1000000); + for param in lookup_matching_benches(&suite).iter() { + bench_memory(param, count); + } + } + None => { + panic!("memory needs ciphersuite argument"); + } + }, + + _ => { + panic!("unsupported mode {:?}", mode); + } + } +} + +fn all_tests() { + for test in ALL_BENCHMARKS.iter() { + bench_bulk(test, 1024 * 1024, None); + bench_bulk(test, 1024 * 1024, Some(10000)); + bench_handshake(test, ClientAuth::No, ResumptionParam::No); + bench_handshake(test, ClientAuth::Yes, ResumptionParam::No); + bench_handshake(test, ClientAuth::No, ResumptionParam::SessionId); + bench_handshake(test, ClientAuth::Yes, ResumptionParam::SessionId); + bench_handshake(test, ClientAuth::No, ResumptionParam::Tickets); + bench_handshake(test, ClientAuth::Yes, ResumptionParam::Tickets); + } +} diff --git a/rustls-mbedcrypto-provider/src/error.rs b/rustls-mbedcrypto-provider/src/error.rs index 0094290..746ffd8 100644 --- a/rustls-mbedcrypto-provider/src/error.rs +++ b/rustls-mbedcrypto-provider/src/error.rs @@ -5,9 +5,11 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -use alloc::format; +use alloc::sync::Arc; -/// Convert a [`mbedtls::Error`] to a [`rustls::Error::General`] error. -pub(crate) fn mbedtls_err_to_rustls_general_error(err: mbedtls::Error) -> rustls::Error { - rustls::Error::General(format!("Got mbedtls error: {}", err)) +use rustls::OtherError; + +/// Convert a [`mbedtls::Error`] to a [`rustls::Error::Other`] error. +pub(crate) fn mbedtls_err_to_rustls_error(err: mbedtls::Error) -> rustls::Error { + OtherError(Arc::new(err)).into() } diff --git a/rustls-mbedcrypto-provider/src/hash.rs b/rustls-mbedcrypto-provider/src/hash.rs index 25c15d5..1b74fa9 100644 --- a/rustls-mbedcrypto-provider/src/hash.rs +++ b/rustls-mbedcrypto-provider/src/hash.rs @@ -155,3 +155,47 @@ pub(crate) fn hash(hash_algo: &'static Algorithm, data: &[u8]) -> Vec { } } } + +#[cfg(bench)] +mod benchmarks { + + #[bench] + fn bench_sha_256_hash(b: &mut test::Bencher) { + bench_hash(b, &super::SHA256); + } + + #[bench] + fn bench_sha_384_hash(b: &mut test::Bencher) { + bench_hash(b, &super::SHA384); + } + + #[bench] + fn bench_sha_256_hash_multi_parts(b: &mut test::Bencher) { + bench_hash_multi_parts(b, &super::SHA256); + } + + #[bench] + fn bench_sha_384_hash_multi_parts(b: &mut test::Bencher) { + bench_hash_multi_parts(b, &super::SHA384); + } + + fn bench_hash(b: &mut test::Bencher, hash: &super::Hash) { + use super::hash::Hash; + let input = [123u8; 1024 * 16]; + b.iter(|| { + test::black_box(hash.hash(&input)); + }); + } + + fn bench_hash_multi_parts(b: &mut test::Bencher, hash: &super::Hash) { + use super::hash::Hash; + let input = [123u8; 1024 * 16]; + b.iter(|| { + let mut ctx = hash.start(); + for i in 0..16 { + test::black_box(ctx.update(&input[i * 1024..(i + 1) * 1024])); + } + test::black_box(ctx.finish()) + }); + } +} diff --git a/rustls-mbedcrypto-provider/src/hmac.rs b/rustls-mbedcrypto-provider/src/hmac.rs index 5a39468..0bb6e0c 100644 --- a/rustls-mbedcrypto-provider/src/hmac.rs +++ b/rustls-mbedcrypto-provider/src/hmac.rs @@ -10,7 +10,6 @@ use alloc::boxed::Box; use alloc::vec; use alloc::vec::Vec; use rustls::crypto; -use std::sync::Mutex; /// HMAC using SHA-256. pub(crate) static HMAC_SHA256: Hmac = Hmac(&super::hash::MBED_SHA_256); @@ -22,7 +21,7 @@ pub(crate) struct Hmac(&'static super::hash::Algorithm); impl crypto::hmac::Hmac for Hmac { fn with_key(&self, key: &[u8]) -> Box { - Box::new(HmacContext(MbedHmacContext::new(self.0, key))) + Box::new(HmacKey(MbedHmacKey::new(self.0, key))) } fn hash_output_len(&self) -> usize { @@ -30,17 +29,17 @@ impl crypto::hmac::Hmac for Hmac { } } -struct HmacContext(MbedHmacContext); +struct HmacKey(MbedHmacKey); -impl crypto::hmac::Key for HmacContext { +impl crypto::hmac::Key for HmacKey { fn sign_concat(&self, first: &[u8], middle: &[&[u8]], last: &[u8]) -> crypto::hmac::Tag { - let mut ctx = MbedHmacContext::new(self.0.hmac_algo, &self.0.key); + let mut ctx = self.0.starts(); ctx.update(first); for m in middle { ctx.update(m); } ctx.update(last); - crypto::hmac::Tag::new(&ctx.finalize()) + crypto::hmac::Tag::new(&ctx.finish()) } fn tag_len(&self) -> usize { @@ -48,54 +47,48 @@ impl crypto::hmac::Key for HmacContext { } } -struct MbedHmacContext { - state: Mutex, +struct MbedHmacKey { hmac_algo: &'static super::hash::Algorithm, - key: Vec, + /// use [`crypto::hmac::Tag`] for saving key material, since they have same max size. + key: crypto::hmac::Tag, } -impl MbedHmacContext { +impl MbedHmacKey { pub(crate) fn new(hmac_algo: &'static super::hash::Algorithm, key: &[u8]) -> Self { - Self { - hmac_algo, - state: Mutex::new(mbedtls::hash::Hmac::new(hmac_algo.hash_type, key).expect("input validated")), - key: key.to_vec(), + Self { hmac_algo, key: crypto::hmac::Tag::new(key) } + } + + pub(crate) fn starts(&self) -> MbedHmacContext { + MbedHmacContext { + hmac_algo: self.hmac_algo, + ctx: mbedtls::hash::Hmac::new(self.hmac_algo.hash_type, self.key.as_ref()).expect("input validated"), } } +} + +struct MbedHmacContext { + hmac_algo: &'static super::hash::Algorithm, + ctx: mbedtls::hash::Hmac, +} +impl MbedHmacContext { /// Since the trait does not provider a way to return error, empty vector is returned when getting error from `mbedtls`. - pub(crate) fn finalize(self) -> Vec { - match self.state.into_inner() { - Ok(ctx) => { - let mut out = vec![0u8; self.hmac_algo.output_len]; - match ctx.finish(&mut out) { - Ok(_) => out, - Err(_err) => { - error!("Failed to finalize hmac, mbedtls error: {:?}", _err); - vec![] - } - } - } + pub(crate) fn finish(self) -> Vec { + let mut out = vec![0u8; self.hmac_algo.output_len]; + match self.ctx.finish(&mut out) { + Ok(_) => out, Err(_err) => { - error!("Failed to get lock, error: {:?}", _err); + error!("Failed to finish hmac, mbedtls error: {:?}", _err); vec![] } } } pub(crate) fn update(&mut self, data: &[u8]) { - if data.is_empty() { - return; - } - match self.state.lock().as_mut() { - Ok(ctx) => match ctx.update(data) { - Ok(_) => {} - Err(_err) => { - error!("Failed to update hmac, mbedtls error: {:?}", _err); - } - }, + match self.ctx.update(data) { + Ok(_) => {} Err(_err) => { - error!("Failed to get lock, error: {:?}", _err); + error!("Failed to update hmac, mbedtls error: {:?}", _err); } } } diff --git a/rustls-mbedcrypto-provider/src/kx.rs b/rustls-mbedcrypto-provider/src/kx.rs index dbcd47f..3ce7971 100644 --- a/rustls-mbedcrypto-provider/src/kx.rs +++ b/rustls-mbedcrypto-provider/src/kx.rs @@ -5,17 +5,17 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +use std::sync::OnceLock; + use super::agreement; -use crate::error::mbedtls_err_to_rustls_general_error; +use crate::error::mbedtls_err_to_rustls_error; use alloc::boxed::Box; use alloc::fmt; use alloc::format; -use alloc::vec; use alloc::vec::Vec; use crypto::SupportedKxGroup; use mbedtls::{ - bignum::Mpi, ecp::EcPoint, pk::{EcGroup, Pk as PkMbed}, }; @@ -42,30 +42,14 @@ impl fmt::Debug for KxGroup { impl SupportedKxGroup for KxGroup { fn start(&self) -> Result, Error> { - let mut pk = PkMbed::generate_ec( - &mut super::rng::rng_new().ok_or(rustls::crypto::GetRandomFailed)?, - self.agreement_algorithm.group_id, - ) - .map_err(|err| rustls::Error::General(format!("Encountered error when generating ec key, mbedtls error: {}", err)))?; - - fn get_key_pair(pk: &mut PkMbed, kx_group: &KxGroup) -> Result { - let group = EcGroup::new(kx_group.agreement_algorithm.group_id)?; - let pub_key = pk - .ec_public()? - .to_binary(&group, false)?; - let priv_key = pk.ec_private()?.to_binary()?; - Ok(KeyExchange { - name: kx_group.name, - agreement_algorithm: kx_group.agreement_algorithm, - priv_key, - pub_key, - }) - } - - match get_key_pair(&mut pk, self) { - Ok(group) => Ok(Box::new(group)), - Err(err) => Err(rustls::Error::General(format!("Unexpected mbedtls error: {}", err))), - } + let priv_key = generate_ec_key(self.agreement_algorithm.group_id)?; + + Ok(Box::new(KeyExchange { + name: self.name, + agreement_algorithm: self.agreement_algorithm, + priv_key, + pub_key: OnceLock::new(), + })) } fn name(&self) -> NamedGroup { @@ -73,6 +57,12 @@ impl SupportedKxGroup for KxGroup { } } +#[inline] +fn generate_ec_key(group_id: mbedtls::pk::EcGroupId) -> Result { + PkMbed::generate_ec(&mut super::rng::rng_new().ok_or(rustls::crypto::GetRandomFailed)?, group_id) + .map_err(|err| rustls::Error::General(format!("Got error when generating ec key, mbedtls error: {}", err))) +} + /// Ephemeral ECDH on curve25519 (see RFC7748) pub static X25519: &dyn SupportedKxGroup = &KxGroup { name: NamedGroup::X25519, agreement_algorithm: &agreement::X25519 }; @@ -97,54 +87,119 @@ struct KeyExchange { name: NamedGroup, /// The corresponding [`agreement::Algorithm`] agreement_algorithm: &'static agreement::Algorithm, - /// Binary format [`Mpi`] - priv_key: Vec, - /// Binary format [`EcPoint`] without compression - pub_key: Vec, + /// Private key + priv_key: PkMbed, + /// Public key in binary format [`EcPoint`] without compression + pub_key: OnceLock>, +} + +impl KeyExchange { + fn get_pub_key(&self) -> mbedtls::Result> { + let group = EcGroup::new(self.agreement_algorithm.group_id)?; + self.priv_key + .ec_public()? + .to_binary(&group, false) + } } impl crypto::ActiveKeyExchange for KeyExchange { /// Completes the key exchange, given the peer's public key. - fn complete(self: Box, peer_public_key: &[u8]) -> Result { - // Get private key from self data + fn complete(mut self: Box, peer_public_key: &[u8]) -> Result { let group_id = self.agreement_algorithm.group_id; - let ec_group = EcGroup::new(group_id).map_err(mbedtls_err_to_rustls_general_error)?; - let private_key = Mpi::from_binary(&self.priv_key).map_err(mbedtls_err_to_rustls_general_error)?; - let mut sk = - PkMbed::private_from_ec_components(ec_group.clone(), private_key).map_err(mbedtls_err_to_rustls_general_error)?; if peer_public_key.len() != self.agreement_algorithm.public_key_len { - return Err(Error::General(format!( - "Failed to validate {:?} comping peer public key, invalid length", - group_id - ))); + return Err(rustls::PeerMisbehaved::InvalidKeyShare.into()); } - let public_point = - EcPoint::from_binary_no_compress(&ec_group, peer_public_key).map_err(mbedtls_err_to_rustls_general_error)?; - let peer_pk = PkMbed::public_from_ec_components(ec_group, public_point).map_err(mbedtls_err_to_rustls_general_error)?; - - let mut shared_secret = vec![ - 0u8; - self.agreement_algorithm - .max_signature_len - ]; - let len = sk + + let peer_pk = parse_peer_public_key(group_id, peer_public_key).map_err(mbedtls_err_to_rustls_error)?; + + let mut shared_key = [0u8; mbedtls::pk::ECDSA_MAX_LEN]; + let shared_key = &mut shared_key[..self + .agreement_algorithm + .max_signature_len]; + let len = self + .priv_key .agree( &peer_pk, - &mut shared_secret, + shared_key, &mut super::rng::rng_new().ok_or(rustls::crypto::GetRandomFailed)?, ) - .map_err(mbedtls_err_to_rustls_general_error)?; - Ok(crypto::SharedSecret::from(&shared_secret[..len])) + .map_err(mbedtls_err_to_rustls_error)?; + Ok(crypto::SharedSecret::from(&shared_key[..len])) } - /// Return the group being used. + /// Return the public key being used. fn pub_key(&self) -> &[u8] { - &self.pub_key + self.pub_key + .get_or_init(|| self.get_pub_key().unwrap_or_default()) } - /// Return the public key being used. + /// Return the group being used. fn group(&self) -> NamedGroup { self.name } } + +#[inline] +fn parse_peer_public_key(group_id: mbedtls::pk::EcGroupId, peer_public_key: &[u8]) -> Result { + let ec_group = EcGroup::new(group_id)?; + let public_point = EcPoint::from_binary_no_compress(&ec_group, peer_public_key)?; + PkMbed::public_from_ec_components(ec_group, public_point) +} + +#[cfg(bench)] +mod benchmarks { + + #[bench] + fn bench_ecdh_p256(b: &mut test::Bencher) { + bench_any(b, super::SECP256R1); + } + + #[bench] + fn bench_ecdh_p384(b: &mut test::Bencher) { + bench_any(b, super::SECP384R1); + } + + #[bench] + fn bench_ecdh_p521(b: &mut test::Bencher) { + bench_any(b, super::SECP521R1); + } + + #[bench] + fn bench_x25519(b: &mut test::Bencher) { + bench_any(b, super::X25519); + } + + fn bench_any(b: &mut test::Bencher, kxg: &dyn super::SupportedKxGroup) { + b.iter(|| { + let akx = kxg.start().unwrap(); + let pub_key = akx.pub_key().to_vec(); + test::black_box(akx.complete(&pub_key).unwrap()); + }); + } + + #[bench] + fn bench_ecdh_p256_start(b: &mut test::Bencher) { + let kxg = super::SECP256R1; + b.iter(|| { + test::black_box(kxg.start().unwrap()); + }); + } + + #[bench] + fn bench_ecdh_p256_gen_private_key(b: &mut test::Bencher) { + b.iter(|| { + test::black_box(super::generate_ec_key(mbedtls::pk::EcGroupId::SecP256R1).unwrap()); + }); + } + + #[bench] + fn bench_ecdh_p256_parse_peer_pub_key(b: &mut test::Bencher) { + let kxg = super::SECP256R1; + let akx = kxg.start().unwrap(); + let pub_key = akx.pub_key().to_vec(); + b.iter(|| { + test::black_box(super::parse_peer_public_key(mbedtls::pk::EcGroupId::SecP256R1, &pub_key).unwrap()); + }); + } +} diff --git a/rustls-mbedcrypto-provider/src/lib.rs b/rustls-mbedcrypto-provider/src/lib.rs index 9a75c2d..4361939 100644 --- a/rustls-mbedcrypto-provider/src/lib.rs +++ b/rustls-mbedcrypto-provider/src/lib.rs @@ -48,16 +48,24 @@ // Enable documentation for all features on docs.rs #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] #![cfg_attr(bench, feature(test))] -#![cfg_attr(not(test), no_std)] +// TODO: enable this once we support use mbedtls without `std` +// #![cfg_attr(not(test), no_std)] extern crate alloc; + // This `extern crate` plus the `#![no_std]` attribute changes the default prelude from // `std::prelude` to `core::prelude`. That forces one to _explicitly_ import (`use`) everything that // is in `std::prelude` but not in `core::prelude`. This helps maintain no-std support as even // developers that are not interested in, or aware of, no-std support and / or that never run // `cargo build --no-default-features` locally will get errors when they rely on `std::prelude` API. -#[cfg(not(test))] -extern crate std; +// TODO: enable this once we support use mbedtls without `std` +// #[cfg(not(test))] +// extern crate std; + +// Import `test` sysroot crate for `Bencher` definitions. +#[cfg(bench)] +#[allow(unused_extern_crates)] +extern crate test; // log for logging (optional). #[cfg(feature = "logging")] diff --git a/rustls-mbedcrypto-provider/src/tls12.rs b/rustls-mbedcrypto-provider/src/tls12.rs index 00b2349..832ec68 100644 --- a/rustls-mbedcrypto-provider/src/tls12.rs +++ b/rustls-mbedcrypto-provider/src/tls12.rs @@ -5,7 +5,7 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -use crate::error::mbedtls_err_to_rustls_general_error; +use crate::error::mbedtls_err_to_rustls_error; use alloc::boxed::Box; use alloc::vec::Vec; use mbedtls::cipher::raw::{CipherId, CipherMode, CipherType}; @@ -217,11 +217,11 @@ impl MessageDecrypter for GcmMessageDecrypter { let dec_key = self.dec_key.as_ref(); let cipher = Cipher::::new(CipherId::Aes, CipherMode::GCM, (dec_key.len() * 8) as _) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; let cipher = cipher .set_key_iv(dec_key, &nonce) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; let tag_offset = payload .len() @@ -237,7 +237,7 @@ impl MessageDecrypter for GcmMessageDecrypter { | mbedtls::Error::ChachapolyAuthFailed | mbedtls::Error::CipherAuthFailed | mbedtls::Error::GcmAuthFailed => rustls::Error::DecryptError, - _ => mbedtls_err_to_rustls_general_error(err), + _ => mbedtls_err_to_rustls_error(err), })?; if plain_len > MAX_FRAGMENT_LEN { return Err(Error::PeerSentOversizedRecord); @@ -263,10 +263,10 @@ impl MessageEncrypter for GcmMessageEncrypter { let enc_key = self.enc_key.as_ref(); let cipher = Cipher::::new(CipherId::Aes, CipherMode::GCM, (enc_key.len() * 8) as _) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; let cipher = cipher .set_key_iv(enc_key, &nonce) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; cipher .encrypt_auth_inplace(&aad, &mut payload[GCM_EXPLICIT_NONCE_LEN..], &mut tag) @@ -275,7 +275,7 @@ impl MessageEncrypter for GcmMessageEncrypter { | mbedtls::Error::ChachapolyAuthFailed | mbedtls::Error::CipherAuthFailed | mbedtls::Error::GcmAuthFailed => rustls::Error::EncryptError, - _ => mbedtls_err_to_rustls_general_error(err), + _ => mbedtls_err_to_rustls_error(err), })?; payload.extend(tag); @@ -324,11 +324,11 @@ impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter { CipherMode::CHACHAPOLY, (dec_key.len() * 8) as _, ) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; let cipher = cipher .set_key_iv(dec_key, &nonce) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; let tag_offset = payload .len() @@ -344,7 +344,7 @@ impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter { | mbedtls::Error::ChachapolyAuthFailed | mbedtls::Error::CipherAuthFailed | mbedtls::Error::GcmAuthFailed => rustls::Error::DecryptError, - _ => mbedtls_err_to_rustls_general_error(err), + _ => mbedtls_err_to_rustls_error(err), })?; if plain_len > MAX_FRAGMENT_LEN { @@ -371,11 +371,11 @@ impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter { CipherMode::CHACHAPOLY, (enc_key.len() * 8) as _, ) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; let cipher = cipher .set_key_iv(enc_key, &nonce) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; cipher .encrypt_auth_inplace(&aad, &mut payload, &mut tag) @@ -384,7 +384,7 @@ impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter { | mbedtls::Error::ChachapolyAuthFailed | mbedtls::Error::CipherAuthFailed | mbedtls::Error::GcmAuthFailed => rustls::Error::EncryptError, - _ => mbedtls_err_to_rustls_general_error(err), + _ => mbedtls_err_to_rustls_error(err), })?; payload.extend(tag); diff --git a/rustls-mbedcrypto-provider/src/tls13.rs b/rustls-mbedcrypto-provider/src/tls13.rs index 793c8a7..0ed1966 100644 --- a/rustls-mbedcrypto-provider/src/tls13.rs +++ b/rustls-mbedcrypto-provider/src/tls13.rs @@ -6,7 +6,7 @@ */ use super::aead; -use crate::error::mbedtls_err_to_rustls_general_error; +use crate::error::mbedtls_err_to_rustls_error; use alloc::boxed::Box; use alloc::string::String; use alloc::vec::Vec; @@ -120,11 +120,11 @@ impl MessageEncrypter for Tls13MessageEncrypter { self.aead_algorithm.cipher_mode, (enc_key.len() * 8) as _, ) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; let cipher = cipher .set_key_iv(enc_key, &nonce) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; cipher .encrypt_auth_inplace(&aad, &mut payload, &mut tag) @@ -133,7 +133,7 @@ impl MessageEncrypter for Tls13MessageEncrypter { | mbedtls::Error::ChachapolyAuthFailed | mbedtls::Error::CipherAuthFailed | mbedtls::Error::GcmAuthFailed => rustls::Error::EncryptError, - _ => mbedtls_err_to_rustls_general_error(err), + _ => mbedtls_err_to_rustls_error(err), })?; payload.extend(tag); @@ -165,11 +165,11 @@ impl MessageDecrypter for Tls13MessageDecrypter { self.aead_algorithm.cipher_mode, (dec_key.len() * 8) as _, ) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; let cipher = cipher .set_key_iv(dec_key, &nonce) - .map_err(mbedtls_err_to_rustls_general_error)?; + .map_err(mbedtls_err_to_rustls_error)?; let tag_offset = payload .len() @@ -185,7 +185,7 @@ impl MessageDecrypter for Tls13MessageDecrypter { | mbedtls::Error::ChachapolyAuthFailed | mbedtls::Error::CipherAuthFailed | mbedtls::Error::GcmAuthFailed => rustls::Error::DecryptError, - _ => mbedtls_err_to_rustls_general_error(err), + _ => mbedtls_err_to_rustls_error(err), })?; payload.truncate(plain_len); msg.into_tls13_unpadded_message()