Skip to content

Commit

Permalink
Merge pull request #142 from ngrok/bob/session-host-certs
Browse files Browse the repository at this point in the history
add root_cas
  • Loading branch information
bobzilladev authored May 21, 2024
2 parents 49a5d16 + e4991ae commit ffc45ab
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 20 deletions.
2 changes: 1 addition & 1 deletion ngrok/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ngrok"
version = "0.14.0-pre.12"
version = "0.14.0-pre.13"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "The ngrok agent SDK"
Expand Down
1 change: 1 addition & 0 deletions ngrok/examples/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ async fn main() -> anyhow::Result<()> {
let sess = ngrok::Session::builder()
.authtoken_from_env()
.metadata("Online in One Line")
// .root_cas("trusted")?
.connect()
.await?;

Expand Down
54 changes: 54 additions & 0 deletions ngrok/src/online_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,60 @@ async fn verify_upstream_tls() -> Result<(), Error> {
Ok(())
}

#[cfg_attr(not(feature = "online-tests"), ignore)]
#[test]
async fn session_root_cas() -> Result<(), Error> {
// host cannot validate cert
let resp = Session::builder()
.authtoken_from_env()
.root_cas("host")?
.connect()
.await;
assert!(resp.is_err());
let err_str = resp.err().unwrap().to_string();
tracing::debug!(?err_str);
assert!(err_str.contains("tls")); // tls issue

// default of 'trusted' cannot validate the marketing site
let resp = Session::builder()
.authtoken_from_env()
.server_addr("ngrok.com:443")?
.connect()
.await;
assert!(resp.is_err());
let err_str = resp.err().unwrap().to_string();
tracing::debug!(?err_str);
assert!(err_str.contains("tls")); // tls issue

// "host" certs can validate the marketing site's let's encrypt cert
let resp = Session::builder()
.authtoken_from_env()
.root_cas("host")?
.server_addr("ngrok.com:443")?
.connect()
.await;
assert!(resp.is_err());
let err_str = resp.err().unwrap().to_string();
tracing::debug!(?err_str);
assert!(!err_str.contains("tls")); // not a tls problem

// use the trusted cert, this should connect
Session::builder()
.authtoken_from_env()
.root_cas("trusted")?
.connect()
.await?;

// use the default cert, this should connect
Session::builder()
.authtoken_from_env()
.root_cas("assets/ngrok.ca.crt")?
.connect()
.await?;

Ok(())
}

#[cfg_attr(not(feature = "online-tests"), ignore)]
#[test]
async fn session_ca_cert() -> Result<(), Error> {
Expand Down
42 changes: 41 additions & 1 deletion ngrok/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use futures::{
use futures_rustls::rustls::{
self,
pki_types,
RootCertStore,
};
use hyper::{
client::HttpConnector,
Expand All @@ -37,7 +38,10 @@ use hyper_proxy::{
};
use muxado::heartbeat::HeartbeatConfig;
pub use muxado::heartbeat::HeartbeatHandler;
use once_cell::sync::OnceCell;
use once_cell::sync::{
Lazy,
OnceCell,
};
use regex::Regex;
use rustls_pemfile::Item;
use thiserror::Error;
Expand Down Expand Up @@ -530,6 +534,25 @@ impl SessionBuilder {
Ok(self)
}

/// Sets the file path to a default certificate in PEM format to validate ngrok Session TLS connections.
/// Setting to "trusted" is the default, using the ngrok CA certificate.
/// Setting to "host" will verify using the certificates on the host operating system.
/// A client config set via tls_config after calling root_cas will override this value.
///
/// Corresponds to the [root_cas parameter in the ngrok docs]
///
/// [root_cas parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#root_cas
pub fn root_cas(&mut self, root_cas: impl Into<String>) -> Result<&mut Self, io::Error> {
match root_cas.into().clone().as_str() {
"trusted" => self.ca_cert = None,
"host" => self.tls_config = Some(host_certs_tls_config().map_err(|e| e.kind())?),
v => {
std::fs::read(v).map(|root_cas| self.ca_cert = Some(Bytes::from(root_cas)))?;
}
}
Ok(self)
}

/// Sets the default certificate in PEM format to validate ngrok Session TLS connections.
/// A client config set via tls_config will override this value.
///
Expand Down Expand Up @@ -1007,6 +1030,23 @@ impl Session {
}
}

pub(crate) fn host_certs_tls_config() -> Result<rustls::ClientConfig, &'static io::Error> {
// The root certificate store, lazily loaded once.
static ROOT_STORE: Lazy<Result<RootCertStore, io::Error>> = Lazy::new(|| {
let der_certs = rustls_native_certs::load_native_certs()?
.into_iter()
.collect::<Vec<_>>();
let mut root_store = RootCertStore::empty();
root_store.add_parsable_certificates(der_certs);
Ok(root_store)
});

let root_store = ROOT_STORE.as_ref()?;
Ok(rustls::ClientConfig::builder()
.with_root_certificates(root_store.clone())
.with_no_client_auth())
}

async fn accept_one(
incoming: &mut IncomingStreams,
inner: &ArcSwap<SessionInner>,
Expand Down
22 changes: 4 additions & 18 deletions ngrok/src/tunnel_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ use futures_rustls::rustls::{
self,
pki_types,
ClientConfig,
RootCertStore,
};
#[cfg(feature = "hyper")]
use hyper::{
Expand Down Expand Up @@ -221,15 +220,6 @@ fn tls_config(
app_protocol: Option<String>,
verify_upstream_tls: bool,
) -> Result<Arc<ClientConfig>, &'static io::Error> {
// The root certificate store, lazily loaded once.
static ROOT_STORE: Lazy<Result<RootCertStore, io::Error>> = Lazy::new(|| {
let der_certs = rustls_native_certs::load_native_certs()?
.into_iter()
.collect::<Vec<_>>();
let mut root_store = RootCertStore::empty();
root_store.add_parsable_certificates(der_certs);
Ok(root_store)
});
// A hashmap of tls client configs for different configurations.
// There won't need to be a lot of variation among these, and we'll want to
// reuse them as much as we can, which is why we initialize them all once
Expand All @@ -239,18 +229,14 @@ fn tls_config(
#[allow(clippy::type_complexity)]
static CONFIGS: Lazy<Result<HashMap<u8, Arc<ClientConfig>>, &'static io::Error>> =
Lazy::new(|| {
let root_store = ROOT_STORE.as_ref()?;
Ok(std::ops::Range {
std::ops::Range {
start: 0,
end: TlsFlags::FLAG_MAX.bits() + 1,
}
.map(|p| {
let http2 = (p & TlsFlags::FLAG_HTTP2.bits()) != 0;
let verify_upstream_tls = (p & TlsFlags::FLAG_verify_upstream_tls.bits()) != 0;

let mut config = ClientConfig::builder()
.with_root_certificates(root_store.clone())
.with_no_client_auth();
let mut config = crate::session::host_certs_tls_config()?;
if !verify_upstream_tls {
config.dangerous().set_certificate_verifier(Arc::new(
danger::NoCertificateVerification::new(provider::default_provider()),
Expand All @@ -262,9 +248,9 @@ fn tls_config(
.alpn_protocols
.extend(["h2", "http/1.1"].iter().map(|s| s.as_bytes().to_vec()));
}
(p, Arc::new(config))
Ok((p, Arc::new(config)))
})
.collect())
.collect()
});

let configs: &HashMap<u8, Arc<ClientConfig>> = CONFIGS.as_ref().map_err(|e| *e)?;
Expand Down

0 comments on commit ffc45ab

Please sign in to comment.