diff --git a/ngrok/Cargo.toml b/ngrok/Cargo.toml index d78de91..214dbae 100644 --- a/ngrok/Cargo.toml +++ b/ngrok/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ngrok" -version = "0.14.0-pre.2" +version = "0.14.0-pre.3" edition = "2021" license = "MIT OR Apache-2.0" description = "The ngrok agent SDK" @@ -32,6 +32,7 @@ regex = "1.7.3" tokio-socks = "0.5.1" hyper-proxy = "0.9.1" url = "2.4.0" +rustls-native-certs = "0.6.3" [target.'cfg(windows)'.dependencies] windows-sys = { version = "0.45.0", features = ["Win32_Foundation"] } diff --git a/ngrok/examples/mingrok.rs b/ngrok/examples/mingrok.rs index da20d19..4ff4ba2 100644 --- a/ngrok/examples/mingrok.rs +++ b/ngrok/examples/mingrok.rs @@ -11,6 +11,7 @@ use futures::{ use ngrok::prelude::*; use tokio::sync::oneshot; use tracing::info; +use url::Url; #[tokio::main] async fn main() -> Result<(), Error> { @@ -21,7 +22,8 @@ async fn main() -> Result<(), Error> { let forwards_to = std::env::args() .nth(1) - .ok_or_else(|| anyhow::anyhow!("missing forwarding address"))?; + .ok_or_else(|| anyhow::anyhow!("missing forwarding address")) + .and_then(|s| Ok(Url::parse(&s)?))?; loop { let (stop_tx, stop_rx) = oneshot::channel(); @@ -55,18 +57,13 @@ async fn main() -> Result<(), Error> { .connect() .await? .http_endpoint() - .forwards_to(&forwards_to) + .forwards_to(forwards_to.as_str()) .listen() .await?; - info!(url = tun.url(), forwards_to, "started tunnel"); + info!(url = tun.url(), %forwards_to, "started tunnel"); - let mut fut = if forwards_to.contains('/') { - tun.forward_pipe(&forwards_to) - } else { - tun.forward_http(&forwards_to) - } - .fuse(); + let mut fut = TunnelExt::forward(&mut tun, forwards_to.clone()).fuse(); let mut stop_rx = stop_rx.fuse(); let mut restart_rx = restart_rx.fuse(); diff --git a/ngrok/src/tunnel.rs b/ngrok/src/tunnel.rs index a82cda7..17e1a60 100644 --- a/ngrok/src/tunnel.rs +++ b/ngrok/src/tunnel.rs @@ -102,6 +102,8 @@ macro_rules! tunnel_trait { fn forwards_to(&self) -> &str; /// Returns the arbitrary metadata string for this tunnel. fn metadata(&self) -> &str; + /// Returns the protocol for this tunnel. + fn proto(&self) -> &str; /// Close the tunnel. /// /// This is an RPC call that must be `.await`ed. @@ -286,6 +288,10 @@ macro_rules! make_tunnel_type { fn metadata(&self) -> &str { self.inner.metadata() } + + fn proto(&self) -> &str { + self.inner.proto() + } } impl $wrapper { diff --git a/ngrok/src/tunnel_ext.rs b/ngrok/src/tunnel_ext.rs index 05b389c..f20cc7f 100644 --- a/ngrok/src/tunnel_ext.rs +++ b/ngrok/src/tunnel_ext.rs @@ -1,3 +1,5 @@ +#[cfg(not(target_os = "windows"))] +use std::borrow::Cow; #[cfg(target_os = "windows")] use std::time::Duration; #[cfg(feature = "hyper")] @@ -7,10 +9,14 @@ use std::{ }; use std::{ io, - net::SocketAddr, - path::Path, + sync::Arc, }; +use async_rustls::rustls::{ + self, + ClientConfig, + RootCertStore, +}; use async_trait::async_trait; use futures::stream::TryStreamExt; #[cfg(feature = "hyper")] @@ -21,6 +27,7 @@ use hyper::{ Response, StatusCode, }; +use once_cell::sync::Lazy; #[cfg(target_os = "windows")] use tokio::net::windows::named_pipe::ClientOptions; #[cfg(not(target_os = "windows"))] @@ -33,98 +40,222 @@ use tokio::{ AsyncRead, AsyncWrite, }, - net::{ - TcpStream, - ToSocketAddrs, - }, + net::TcpStream, task::JoinHandle, }; +use tokio_util::compat::{ + FuturesAsyncReadCompatExt, + TokioAsyncReadCompatExt, +}; use tracing::{ debug, field, - instrument, - trace, + info_span, warn, Instrument, Span, }; +use url::Url; #[cfg(target_os = "windows")] use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY; use crate::{ prelude::*, + session::IoStream, Conn, }; -impl TunnelExt for T where T: Tunnel {} +impl TunnelExt for T where T: Tunnel + Send {} /// Extension methods auto-implemented for all tunnel types #[async_trait] -pub trait TunnelExt: Tunnel { - /// Forward incoming tunnel connections to the provided TCP address. - #[instrument(level = "debug", skip_all, fields(local_addrs))] - async fn forward_tcp(&mut self, addr: impl ToSocketAddrs + Send) -> Result<(), io::Error> { - forward_conns(self, addr, |_, _| {}).await - } - - /// Forward incoming tunnel connections to the provided TCP address. +pub trait TunnelExt: Tunnel + Send { + /// Forward incoming tunnel connections to the provided url based on its + /// scheme. + /// This currently supports http, https, tls, and tcp on all platforms, unix + /// sockets on unix platforms, and named pipes on Windows via the "pipe" + /// scheme. /// - /// Provides slightly nicer errors when the backend is unavailable. - #[cfg(feature = "hyper")] - #[instrument(level = "debug", skip_all, fields(local_addrs))] - async fn forward_http(&mut self, addr: impl ToSocketAddrs + Send) -> Result<(), io::Error> { - forward_conns(self, addr, |e, c| drop(serve_gateway_error(e, c))).await - } + /// Unix socket URLs can be formatted as `unix://path/to/socket` or + /// `unix:path/to/socket` for relative paths or as `unix:///path/to/socket` or + /// `unix:/path/to/socket` for absolute paths. + /// + /// Windows named pipe URLs can be formatted as `pipe:mypipename` or + /// `pipe://host/mypipename`. If no host is provided, as with + /// `pipe:///mypipename` or `pipe:/mypipename`, the leading slash will be + /// preserved. + #[tracing::instrument(skip_all, fields(tunnel_id = self.id(), url = %url))] + async fn forward(&mut self, url: Url) -> Result<(), io::Error> { + loop { + let tunnel_conn = if let Some(conn) = self + .try_next() + .await + .map_err(|err| io::Error::new(io::ErrorKind::NotConnected, err))? + { + conn + } else { + return Ok(()); + }; + + let span = info_span!( + "forward_one", + remote_addr = %tunnel_conn.remote_addr(), + forward_addr = field::Empty + ); + + debug!(parent: &span, "accepted tunnel connection"); + + let local_conn = match connect(self, &tunnel_conn, &url) + .instrument(span.clone()) + .await + { + Ok(conn) => conn, + Err(error) => { + warn!(%error, "error establishing local connection"); + + span.in_scope(|| on_err(self, error, tunnel_conn)); + + continue; + } + }; - /// Forward incoming tunnel connections to the provided file socket path. - /// On Linux/Darwin addr can be a unix domain socket path, e.g. "/tmp/ngrok.sock". - /// On Windows addr can be a named pipe, e.g. "\\.\pipe\ngrok_pipe". - #[instrument(level = "debug", skip_all, fields(path))] - async fn forward_pipe(&mut self, addr: impl AsRef + Send) -> Result<(), io::Error> { - forward_pipe_conns(self, addr, |_, _| {}).await + debug!(parent: &span, "established local connection, joining streams"); + + span.in_scope(|| join_streams(tunnel_conn, local_conn)); + } } } -async fn forward_conns(this: &mut T, addr: A, mut on_err: F) -> Result<(), io::Error> -where - T: Tunnel + ?Sized, - A: ToSocketAddrs, - F: FnMut(io::Error, Conn), -{ - let span = Span::current(); - let addrs = tokio::net::lookup_host(addr).await?.collect::>(); - span.record("local_addrs", field::debug(&addrs)); - trace!("looked up local addrs"); - loop { - trace!("waiting for new tunnel connection"); - if !handle_one(this, addrs.as_slice(), &mut on_err).await? { - debug!("listener closed, exiting"); - break; - } +fn on_err(tunnel: &T, err: io::Error, conn: Conn) { + match tunnel.proto() { + #[cfg(feature = "hyper")] + "http" | "https" => drop(serve_gateway_error(err, conn)), + _ => {} } - Ok(()) } -async fn forward_pipe_conns( - this: &mut T, - addr: impl AsRef, - mut on_err: F, -) -> Result<(), io::Error> -where - T: Tunnel + ?Sized, - F: FnMut(io::Error, Conn), -{ - let span = Span::current(); - let path = addr.as_ref(); - span.record("path", field::debug(&path)); - loop { - trace!("waiting for new tunnel connection"); - if !handle_one_pipe(this, path, &mut on_err).await? { - debug!("listener closed, exiting"); - break; +fn tls_config() -> Result, &'static io::Error> { + static CONFIG: Lazy, io::Error>> = Lazy::new(|| { + let der_certs = rustls_native_certs::load_native_certs()? + .into_iter() + .map(|c| c.0) + .collect::>(); + let der_certs = der_certs.as_slice(); + let mut root_store = RootCertStore::empty(); + root_store.add_parsable_certificates(der_certs); + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + Ok(Arc::new(config)) + }); + + Ok(CONFIG.as_ref()?.clone()) +} + +// Establish the connection to forward the tunnel stream to. +// Takes the tunnel and connection to make additional decisions on how to wrap +// the forwarded connection, i.e. reordering tls termination and proxyproto. +// Note: this additional wrapping logic currently unimplemented. +async fn connect( + _tunnel: &mut T, + _conn: &Conn, + url: &Url, +) -> Result, io::Error> { + let host = url.host_str().unwrap_or("localhost"); + Ok(match url.scheme() { + "tcp" => { + let port = url.port().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("missing port for tcp forwarding url {url}"), + ) + })?; + let conn = connect_tcp(host, port).in_current_span().await?; + Box::new(conn) + } + + "http" => { + let port = url.port().unwrap_or(80); + let conn = connect_tcp(host, port).in_current_span().await?; + Box::new(conn) + } + + "https" | "tls" => { + let port = url.port().unwrap_or(443); + let conn = connect_tcp(host, port).in_current_span().await?; + + // TODO: if the tunnel uses proxyproto, wrap conn here before terminating tls + + let domain = rustls::ServerName::try_from(host) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + Box::new( + async_rustls::TlsConnector::from(tls_config().map_err(|e| e.kind())?) + .connect(domain, conn.compat()) + .await? + .compat(), + ) + } + + #[cfg(not(target_os = "windows"))] + "unix" => { + // + let mut addr = Cow::Borrowed(url.path()); + if let Some(host) = url.host_str() { + // note: if host exists, there should always be a leading / in + // the path, but we should consider it a relative path. + addr = Cow::Owned(format!("{host}{addr}")); + } + Box::new(UnixStream::connect(&*addr).await?) + } + + #[cfg(target_os = "windows")] + "pipe" => { + let mut pipe_name = url.path(); + if url.host_str().is_some() { + pipe_name = pipe_name.strip_prefix('/').unwrap_or(pipe_name); + } + if pipe_name.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("missing pipe name in forwarding url {url}"), + )); + } + let host = url + .host_str() + // Consider localhost to mean "." for the pipe name + .map(|h| if h == "localhost" { "." } else { h }) + .unwrap_or("."); + // Finally, assemble the full name. + let addr = format!("\\\\{host}\\pipe\\{pipe_name}"); + // loop behavior copied from docs + // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html + let local_conn = loop { + match ClientOptions::new().open(&addr) { + Ok(client) => break client, + Err(error) if error.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), + Err(error) => return Err(error), + } + + time::sleep(Duration::from_millis(50)).await; + }; + Box::new(local_conn) } + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unrecognized scheme in forwarding url: {url}"), + )) + } + }) +} + +async fn connect_tcp(host: &str, port: u16) -> Result { + let conn = TcpStream::connect(&format!("{}:{}", host, port)).await?; + if let Ok(addr) = conn.peer_addr() { + Span::current().record("forward_addr", field::display(addr)); } - Ok(()) + Ok(conn) } fn join_streams( @@ -142,114 +273,11 @@ fn join_streams( ) } -#[instrument(level = "debug", skip_all, fields(remote_addr, local_addr))] -async fn handle_one( - this: &mut T, - addrs: &[SocketAddr], - on_error: F, -) -> Result -where - T: Tunnel + ?Sized, - F: FnOnce(io::Error, Conn), -{ - let span = Span::current(); - let tunnel_conn = if let Some(conn) = this - .try_next() - .await - .map_err(|err| io::Error::new(io::ErrorKind::NotConnected, err))? - { - conn - } else { - return Ok(false); - }; - - span.record("remote_addr", field::debug(tunnel_conn.remote_addr())); - - trace!("accepted tunnel connection"); - - let local_conn = match TcpStream::connect(addrs).await { - Ok(conn) => conn, - Err(error) => { - warn!(%error, "error establishing local connection"); - - on_error(error, tunnel_conn); - - return Ok(true); - } - }; - span.record("local_addr", field::debug(local_conn.peer_addr().unwrap())); - - debug!("established local connection, joining streams"); - - join_streams(tunnel_conn, local_conn); - Ok(true) -} - -#[instrument(level = "debug", skip_all, fields(remote_addr, local_addr))] -async fn handle_one_pipe(this: &mut T, addr: &Path, on_error: F) -> Result -where - T: Tunnel + ?Sized, - F: FnOnce(io::Error, Conn), -{ - let span = Span::current(); - let tunnel_conn = if let Some(conn) = this - .try_next() - .await - .map_err(|err| io::Error::new(io::ErrorKind::NotConnected, err))? - { - conn - } else { - return Ok(false); - }; - - span.record("remote_addr", field::debug(tunnel_conn.remote_addr())); - - trace!("accepted tunnel connection"); - - #[cfg(not(target_os = "windows"))] - { - let local_conn = match UnixStream::connect(addr).await { - Ok(conn) => conn, - Err(error) => { - warn!(%error, "error establishing local unix connection"); - on_error(error, tunnel_conn); - return Ok(true); - } - }; - span.record("local_addr", field::debug(local_conn.peer_addr().unwrap())); - debug!("established local connection, joining streams"); - join_streams(tunnel_conn, local_conn); - } - #[cfg(target_os = "windows")] - { - // loop behavior copied from docs - // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html - let local_conn = loop { - match ClientOptions::new().open(addr) { - Ok(client) => break client, - Err(error) if error.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), - Err(error) => { - warn!(%error, "error establishing local named pipe connection"); - on_error(error, tunnel_conn); - return Ok(true); - } - } - - time::sleep(Duration::from_millis(50)).await; - }; - span.record("local_addr", field::debug(addr)); - debug!("established local connection, joining streams"); - join_streams(tunnel_conn, local_conn); - } - - Ok(true) -} - #[cfg(feature = "hyper")] #[allow(dead_code)] fn serve_gateway_error( err: impl fmt::Display + Send + 'static, - stream: impl AsyncRead + AsyncWrite + Unpin + Send + 'static, + conn: impl AsyncRead + AsyncWrite + Unpin + Send + 'static, ) -> JoinHandle<()> { tokio::spawn( async move { @@ -257,7 +285,7 @@ fn serve_gateway_error( .http1_only(true) .http1_keep_alive(false) .serve_connection( - stream, + conn, service_fn(move |_req| { debug!("serving bad gateway error"); let mut resp =