From 1e006a29163a8c39df6c07896f78fef4cda0ac14 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Mon, 8 Apr 2024 16:46:14 -0700 Subject: [PATCH] feat!: introduce ServerConfig --- api/tests/disconnect-body.rs | 14 +- async-std/src/runtime.rs | 10 +- async-std/src/server/tcp.rs | 17 +- async-std/src/server/unix.rs | 49 +- aws-lambda/src/lib.rs | 20 +- client/src/client.rs | 7 +- client/src/conn.rs | 582 +--------------- client/src/conn/implementation.rs | 516 ++++++++++++++ client/src/conn/unexpected_status_error.rs | 48 ++ client/tests/timeout.rs | 7 +- forwarding/src/lib.rs | 4 +- http/examples/conn-example.rs | 55 +- http/examples/http.rs | 12 +- http/examples/tokio-http.rs | 10 +- http/examples/unsend.rs | 17 +- http/src/body.rs | 2 +- http/src/conn.rs | 637 +----------------- http/src/conn/implementation.rs | 404 +++++++++++ http/src/lib.rs | 90 +-- http/src/received_body.rs | 2 +- http/src/received_body/chunked.rs | 184 +++-- http/src/server_config.rs | 117 ++++ http/src/state_set/entry.rs | 93 +++ http/src/synthetic.rs | 19 +- http/src/upgrade.rs | 21 +- http/tests/corpus.rs | 10 +- http/tests/one_hundred_continue.rs | 10 +- http/tests/unsafe_headers.rs | 8 +- http/tests/use_cases.rs | 12 +- logger/Cargo.toml | 1 + logger/src/lib.rs | 36 +- macros/tests/derive.rs | 4 +- native-tls/src/client.rs | 2 +- rustls/src/client.rs | 2 +- server-common/Cargo.toml | 1 + server-common/src/acceptor.rs | 9 + server-common/src/config.rs | 159 +++-- server-common/src/config_ext.rs | 232 ------- server-common/src/lib.rs | 6 +- server-common/src/running_config.rs | 135 ++++ server-common/src/runtime.rs | 19 +- .../src/runtime/object_safe_runtime.rs | 6 + server-common/src/runtime/runtime_trait.rs | 9 + server-common/src/server.rs | 158 ++--- server-common/src/server_handle.rs | 47 +- smol/examples/smol.rs | 14 +- smol/src/runtime.rs | 10 +- smol/src/server/tcp.rs | 23 +- smol/src/server/unix.rs | 48 +- static/src/handler.rs | 4 +- static/src/lib.rs | 2 +- tera/src/tera_handler.rs | 4 +- testing/Cargo.toml | 1 + testing/src/lib.rs | 27 +- testing/src/runtimeless/runtime.rs | 2 +- testing/src/runtimeless/server.rs | 24 +- testing/src/server_connector.rs | 34 +- testing/src/test_conn.rs | 14 +- testing/src/with_server.rs | 2 +- tokio/src/runtime.rs | 10 +- tokio/src/server/tcp.rs | 15 +- tokio/src/server/unix.rs | 58 +- trillium/Cargo.toml | 3 +- trillium/examples/state.rs | 24 +- trillium/src/conn.rs | 19 +- trillium/src/handler.rs | 123 ++-- trillium/src/info.rs | 154 ++--- trillium/src/init.rs | 86 +++ trillium/src/lib.rs | 3 + trillium/src/shared_state.rs | 36 + trillium/src/state.rs | 31 +- trillium/tests/init.rs | 56 ++ trillium/tests/liveness.rs | 16 +- websockets/src/bidirectional_stream.rs | 2 +- websockets/src/websocket_connection.rs | 25 +- 75 files changed, 2413 insertions(+), 2260 deletions(-) create mode 100644 client/src/conn/implementation.rs create mode 100644 client/src/conn/unexpected_status_error.rs create mode 100644 http/src/conn/implementation.rs create mode 100644 http/src/server_config.rs create mode 100644 http/src/state_set/entry.rs delete mode 100644 server-common/src/config_ext.rs create mode 100644 server-common/src/running_config.rs create mode 100644 trillium/src/init.rs create mode 100644 trillium/src/shared_state.rs create mode 100644 trillium/tests/init.rs diff --git a/api/tests/disconnect-body.rs b/api/tests/disconnect-body.rs index 4275963018..c1c994241f 100644 --- a/api/tests/disconnect-body.rs +++ b/api/tests/disconnect-body.rs @@ -100,20 +100,10 @@ async fn establish_server(handler: impl Handler) -> (ServerHandle, impl AsyncWri let handle = trillium_testing::config().with_port(0).spawn(handler); let info = handle.info().await; - let port = info.tcp_socket_addr().map_or_else( - || { - info.listener_description() - .split(":") - .nth(1) - .unwrap() - .parse() - .unwrap() - }, - |x| x.port(), - ); + let url = info.state::().unwrap(); let client = ArcedConnector::new(client_config()) - .connect(&format!("http://localhost:{port}").parse().unwrap()) + .connect(url) .await .unwrap(); (handle, client) diff --git a/async-std/src/runtime.rs b/async-std/src/runtime.rs index a4f1f19207..09fb37e15e 100644 --- a/async-std/src/runtime.rs +++ b/async-std/src/runtime.rs @@ -30,6 +30,14 @@ impl RuntimeTrait for AsyncStdRuntime { fn block_on(&self, fut: Fut) -> Fut::Output { async_std::task::block_on(fut) } + + #[cfg(unix)] + fn hook_signals( + &self, + signals: impl IntoIterator, + ) -> impl Stream + Send + 'static { + signal_hook_async_std::Signals::new(signals).unwrap() + } } impl AsyncStdRuntime { @@ -81,6 +89,6 @@ impl AsyncStdRuntime { impl From for Runtime { fn from(value: AsyncStdRuntime) -> Self { - Runtime::new(value) + Runtime::from_trait_impl(value) } } diff --git a/async-std/src/server/tcp.rs b/async-std/src/server/tcp.rs index aff2e294d5..3acb3530d0 100644 --- a/async-std/src/server/tcp.rs +++ b/async-std/src/server/tcp.rs @@ -1,6 +1,6 @@ use crate::{AsyncStdRuntime, AsyncStdTransport}; use async_std::net::{TcpListener, TcpStream}; -use std::{env, io::Result}; +use std::io::Result; use trillium::Info; use trillium_server_common::Server; @@ -21,24 +21,19 @@ impl From for AsyncStdServer { impl Server for AsyncStdServer { type Runtime = AsyncStdRuntime; type Transport = AsyncStdTransport; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); async fn accept(&mut self) -> Result { self.0.accept().await.map(|(t, _)| t.into()) } - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { + fn from_tcp(tcp: std::net::TcpListener) -> Self { Self(tcp.into()) } - fn info(&self) -> Info { - self.0.local_addr().unwrap().into() + fn init(&self, info: &mut Info) { + if let Ok(socket_addr) = self.0.local_addr() { + info.insert_state(socket_addr); + } } fn runtime() -> Self::Runtime { diff --git a/async-std/src/server/unix.rs b/async-std/src/server/unix.rs index 0fad8fab9d..08aefc8c65 100644 --- a/async-std/src/server/unix.rs +++ b/async-std/src/server/unix.rs @@ -2,13 +2,12 @@ use crate::{AsyncStdRuntime, AsyncStdTransport}; use async_std::{ net::{TcpListener, TcpStream}, os::unix::net::{UnixListener, UnixStream}, - stream::StreamExt, }; -use std::{env, io::Result}; +use std::io::Result; use trillium::{log_error, Info}; use trillium_server_common::{ Binding::{self, *}, - Server, Swansong, + Server, }; /// Tcp/Unix Trillium server adapter for Async-Std @@ -41,31 +40,6 @@ impl Server for AsyncStdServer { type Runtime = AsyncStdRuntime; type Transport = Binding, AsyncStdTransport>; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); - - async fn handle_signals(swansong: Swansong) { - use signal_hook::consts::signal::*; - use signal_hook_async_std::Signals; - - let signals = Signals::new([SIGINT, SIGTERM, SIGQUIT]).unwrap(); - let mut signals = signals.fuse(); - while signals.next().await.is_some() { - if swansong.state().is_shutting_down() { - eprintln!("\nSecond interrupt, shutting down harshly"); - std::process::exit(1); - } else { - println!("\nShutting down gracefully.\nControl-C again to force."); - swansong.shut_down(); - } - } - } - async fn accept(&mut self) -> Result { match &self.0 { Tcp(t) => t @@ -80,18 +54,27 @@ impl Server for AsyncStdServer { } } - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { + fn from_tcp(tcp: std::net::TcpListener) -> Self { Self(Tcp(tcp.into())) } - fn listener_from_unix(tcp: std::os::unix::net::UnixListener) -> Self { + fn from_unix(tcp: std::os::unix::net::UnixListener) -> Self { Self(Unix(tcp.into())) } - fn info(&self) -> Info { + fn init(&self, info: &mut Info) { match &self.0 { - Tcp(t) => t.local_addr().unwrap().into(), - Unix(u) => u.local_addr().unwrap().into(), + Tcp(t) => { + if let Ok(socket_addr) = t.local_addr() { + info.insert_state(socket_addr); + } + } + + Unix(u) => { + if let Ok(socket_addr) = u.local_addr() { + info.insert_state(socket_addr); + } + } } } diff --git a/aws-lambda/src/lib.rs b/aws-lambda/src/lib.rs index 6af50086de..438bf4ba53 100644 --- a/aws-lambda/src/lib.rs +++ b/aws-lambda/src/lib.rs @@ -23,7 +23,7 @@ use lamedh_runtime::{Context, Handler as AwsHandler}; use std::{future::Future, pin::Pin, sync::Arc}; use tokio::runtime; use trillium::{Conn, Handler}; -use trillium_http::{Conn as HttpConn, Synthetic}; +use trillium_http::{Conn as HttpConn, ServerConfig, Synthetic}; mod context; pub use context::LambdaConnExt; @@ -36,14 +36,19 @@ mod response; use response::{AlbMultiHeadersResponse, AlbResponse, LambdaResponse}; #[derive(Debug)] -struct HandlerWrapper(Arc); +struct HandlerWrapper(Arc, Arc); impl AwsHandler for HandlerWrapper { type Error = std::io::Error; type Fut = Pin> + Send + 'static>>; fn call(&mut self, request: LambdaRequest, context: Context) -> Self::Fut { - Box::pin(handler_fn(request, context, Arc::clone(&self.0))) + Box::pin(handler_fn( + request, + context, + Arc::clone(&self.0), + Arc::clone(&self.1), + )) } } @@ -56,17 +61,18 @@ async fn handler_fn( request: LambdaRequest, context: Context, handler: Arc, + server_config: Arc, ) -> std::io::Result { match request { LambdaRequest::Alb(request) => { - let mut conn = request.into_conn().await; + let mut conn = request.into_conn().await.with_server_config(server_config); conn.state_mut().insert(LambdaContext::new(context)); let conn = run_handler(conn, handler).await; Ok(LambdaResponse::Alb(AlbResponse::from_conn(conn).await)) } LambdaRequest::AlbMultiHeaders(request) => { - let mut conn = request.into_conn().await; + let mut conn = request.into_conn().await.with_server_config(server_config); conn.state_mut().insert(LambdaContext::new(context)); let conn = run_handler(conn, handler).await; Ok(LambdaResponse::AlbMultiHeaders( @@ -81,9 +87,9 @@ async fn handler_fn( This function will poll pending until the server shuts down. */ pub async fn run_async(mut handler: impl Handler) { - let mut info = "aws lambda".into(); + let mut info = ServerConfig::default().into(); handler.init(&mut info).await; - lamedh_runtime::run(HandlerWrapper(Arc::new(handler))) + lamedh_runtime::run(HandlerWrapper(Arc::new(handler), Arc::new(info.into()))) .await .unwrap() } diff --git a/client/src/client.rs b/client/src/client.rs index 23bb2a84df..f86a02ed0f 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -2,7 +2,7 @@ use crate::{Conn, IntoUrl, Pool, USER_AGENT}; use std::{fmt::Debug, sync::Arc, time::Duration}; use trillium_http::{ transport::BoxedTransport, HeaderName, HeaderValues, Headers, KnownHeaderName, Method, - ReceivedBodyState, Version::Http1_1, + ReceivedBodyState, TypeSet, Version::Http1_1, }; use trillium_server_common::{ url::{Origin, Url}, @@ -69,9 +69,9 @@ pub(crate) fn default_request_headers() -> Headers { impl Client { /// builds a new client from this `Connector` - pub fn new(config: impl Connector) -> Self { + pub fn new(connector: impl Connector) -> Self { Self { - config: ArcedConnector::new(config), + config: ArcedConnector::new(connector), pool: None, base: None, default_headers: Arc::new(default_request_headers()), @@ -171,6 +171,7 @@ impl Client { timeout: self.timeout, http_version: Http1_1, max_head_length: 8 * 1024, + state: TypeSet::new(), } } diff --git a/client/src/conn.rs b/client/src/conn.rs index 0f7127b00b..e5218c15f3 100644 --- a/client/src/conn.rs +++ b/client/src/conn.rs @@ -1,27 +1,20 @@ -use crate::{pool::PoolEntry, util::encoding, Pool}; +use crate::{util::encoding, Pool}; use encoding_rs::Encoding; -use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt}; -use memchr::memmem::Finder; -use size::{Base, Size}; -use std::{ - fmt::{self, Debug, Display, Formatter}, - future::{Future, IntoFuture}, - io::{ErrorKind, Write}, - ops::{Deref, DerefMut}, - pin::Pin, - time::Duration, -}; +use std::{net::SocketAddr, time::Duration}; use trillium_http::{ - transport::{BoxedTransport, Transport}, - Body, Error, HeaderName, HeaderValues, Headers, - KnownHeaderName::{Connection, ContentLength, Expect, Host, TransferEncoding}, - Method, ReceivedBody, ReceivedBodyState, Result, Status, Upgrade, Version, + transport::BoxedTransport, Body, Buffer, HeaderName, HeaderValues, Headers, Method, + ReceivedBody, ReceivedBodyState, Status, TypeSet, Version, }; use trillium_server_common::{ url::{Origin, Url}, - ArcedConnector, Connector, + ArcedConnector, Transport, }; +mod implementation; +mod unexpected_status_error; + +pub use unexpected_status_error::UnexpectedStatusError; + /** A wrapper error for [`trillium_http::Error`] or [`serde_json::Error`]. Only available when the `json` crate feature is @@ -32,7 +25,7 @@ enabled. pub enum ClientSerdeError { /// A [`trillium_http::Error`] #[error(transparent)] - HttpError(#[from] Error), + HttpError(#[from] trillium_http::Error), /// A [`serde_json::Error`] #[error(transparent)] @@ -53,37 +46,19 @@ pub struct Conn { pub(crate) status: Option, pub(crate) request_body: Option, pub(crate) pool: Option>, - pub(crate) buffer: trillium_http::Buffer, + pub(crate) buffer: Buffer, pub(crate) response_body_state: ReceivedBodyState, pub(crate) config: ArcedConnector, pub(crate) headers_finalized: bool, pub(crate) timeout: Option, pub(crate) http_version: Version, pub(crate) max_head_length: usize, + pub(crate) state: TypeSet, } /// default http user-agent header pub const USER_AGENT: &str = concat!("trillium-client/", env!("CARGO_PKG_VERSION")); -impl Debug for Conn { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("Conn") - .field("url", &self.url) - .field("method", &self.method) - .field("request_headers", &self.request_headers) - .field("response_headers", &self.response_headers) - .field("status", &self.status) - .field("request_body", &self.request_body) - .field("pool", &self.pool) - .field("buffer", &String::from_utf8_lossy(&self.buffer)) - .field("response_body_state", &self.response_body_state) - .field("config", &self.config) - .field("http_version", &self.http_version) - .field("max_head_length", &self.max_head_length) - .finish() - } -} - impl Conn { /// borrow the request headers pub fn request_headers(&self) -> &Headers { @@ -322,9 +297,9 @@ impl Conn { /** retrieves the url for this conn. ``` - use trillium_testing::client_config; use trillium_client::Client; - let client = Client::from(client_config()); + let client = Client::from(trillium_testing::client_config()); + let conn = client.get("http://localhost:9080"); let url = conn.url(); //<- @@ -399,7 +374,7 @@ impl Conn { Attempt to deserialize the response body. Note that this consumes the body content. */ #[cfg(feature = "json")] - pub async fn response_json(&mut self) -> std::result::Result + pub async fn response_json(&mut self) -> Result where T: serde::de::DeserializeOwned, { @@ -407,19 +382,6 @@ impl Conn { Ok(serde_json::from_str(&body)?) } - pub(crate) fn response_content_length(&self) -> Option { - if self.status == Some(Status::NoContent) - || self.status == Some(Status::NotModified) - || self.method == Method::Head - { - Some(0) - } else { - self.response_headers - .get_str(ContentLength) - .and_then(|c| c.parse().ok()) - } - } - /** returns the status code for this conn. if the conn has not yet been sent, this will be None. @@ -467,7 +429,7 @@ impl Conn { }); ``` */ - pub fn success(self) -> std::result::Result { + pub fn success(self) -> Result { match self.status() { Some(status) if status.is_success() => Ok(self), _ => Err(self.into()), @@ -487,7 +449,7 @@ impl Conn { } /// attempts to retrieve the connected peer address - pub fn peer_addr(&self) -> Option { + pub fn peer_addr(&self) -> Option { self.transport .as_ref() .and_then(|t| t.peer_addr().ok().flatten()) @@ -516,505 +478,29 @@ impl Conn { self.http_version } - // --- everything below here is private --- - - fn finalize_headers(&mut self) -> Result<()> { - if self.headers_finalized { - return Ok(()); - } - - let host = self.url.host_str().ok_or(Error::UnexpectedUriFormat)?; - - self.request_headers.try_insert_with(Host, || { - self.url - .port() - .map_or_else(|| host.to_string(), |port| format!("{host}:{port}")) - }); - - if self.pool.is_none() { - self.request_headers.try_insert(Connection, "close"); - } - - match self.body_len() { - Some(0) => {} - Some(len) => { - self.request_headers.insert(Expect, "100-continue"); - self.request_headers.insert(ContentLength, len.to_string()); - } - None => { - self.request_headers.insert(Expect, "100-continue"); - self.request_headers.insert(TransferEncoding, "chunked"); - } - } - - self.headers_finalized = true; - Ok(()) - } - - fn body_len(&self) -> Option { - if let Some(ref body) = self.request_body { - body.len() - } else { - Some(0) - } - } - - async fn find_pool_candidate(&self, head: &[u8]) -> Result> { - let mut byte = [0]; - if let Some(pool) = &self.pool { - for mut candidate in pool.candidates(&self.url.origin()) { - if poll_once(candidate.read(&mut byte)).await.is_none() - && candidate.write_all(head).await.is_ok() - { - return Ok(Some(candidate)); - } - } - } - Ok(None) - } - - async fn connect_and_send_head(&mut self) -> Result<()> { - if self.transport.is_some() { - return Err(Error::Io(std::io::Error::new( - ErrorKind::AlreadyExists, - "conn already connected", - ))); - } - - let head = self.build_head().await?; - - let transport = match self.find_pool_candidate(&head).await? { - Some(transport) => { - log::debug!("reusing connection to {:?}", transport.peer_addr()?); - transport - } - - None => { - let mut transport = self.config.connect(&self.url).await?; - log::debug!("opened new connection to {:?}", transport.peer_addr()?); - transport.write_all(&head).await?; - transport - } - }; - - self.transport = Some(transport); - Ok(()) - } - - async fn build_head(&mut self) -> Result> { - let mut buf = Vec::with_capacity(128); - let url = &self.url; - let method = self.method; - write!(buf, "{method} ")?; - - if method == Method::Connect { - let host = url.host_str().ok_or(Error::UnexpectedUriFormat)?; - - let port = url - .port_or_known_default() - .ok_or(Error::UnexpectedUriFormat)?; - - write!(buf, "{host}:{port}")?; - } else { - write!(buf, "{}", url.path())?; - if let Some(query) = url.query() { - write!(buf, "?{query}")?; - } - } - - write!(buf, " {}\r\n", self.http_version)?; - - for (name, values) in &self.request_headers { - if !name.is_valid() { - return Err(Error::InvalidHeaderName); - } - - for value in values { - if !value.is_valid() { - return Err(Error::InvalidHeaderValue(name.to_owned())); - } - write!(buf, "{name}: ")?; - buf.extend_from_slice(value.as_ref()); - write!(buf, "\r\n")?; - } - } - - write!(buf, "\r\n")?; - log::trace!( - "{}", - std::str::from_utf8(&buf).unwrap().replace("\r\n", "\r\n> ") - ); - - Ok(buf) - } - - fn transport(&mut self) -> &mut BoxedTransport { - self.transport.as_mut().unwrap() - } - - async fn read_head(&mut self) -> Result { - let Self { - buffer, - transport: Some(transport), - .. - } = self - else { - return Err(Error::Closed); - }; - - let mut len = buffer.len(); - let mut search_start = 0; - let finder = Finder::new(b"\r\n\r\n"); - - if len > 0 { - if let Some(index) = finder.find(buffer) { - return Ok(index + 4); - } - search_start = len.saturating_sub(3); - } - - loop { - buffer.expand(); - let bytes = transport.read(&mut buffer[len..]).await?; - len += bytes; - - let search = finder.find(&buffer[search_start..len]); - - if let Some(index) = search { - buffer.truncate(len); - return Ok(search_start + index + 4); - } - - search_start = len.saturating_sub(3); - - if bytes == 0 { - if len == 0 { - return Err(Error::Closed); - } else { - return Err(Error::InvalidHead); - } - } - - if len >= self.max_head_length { - return Err(Error::HeadersTooLong); - } - } - } - - #[cfg(not(feature = "parse"))] - async fn parse_head(&mut self) -> Result<()> { - const MAX_HEADERS: usize = 128; - use crate::HeaderValue; - use std::str::FromStr; - - let head_offset = self.read_head().await?; - let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; - let mut httparse_res = httparse::Response::new(&mut headers); - let parse_result = - httparse_res - .parse(&self.buffer[..head_offset]) - .map_err(|e| match e { - httparse::Error::HeaderName => Error::InvalidHeaderName, - httparse::Error::HeaderValue => Error::InvalidHeaderValue("unknown".into()), - httparse::Error::Status => Error::InvalidStatus, - httparse::Error::TooManyHeaders => Error::HeadersTooLong, - httparse::Error::Version => Error::InvalidVersion, - _ => Error::InvalidHead, - })?; - - match parse_result { - httparse::Status::Complete(n) if n == head_offset => {} - _ => return Err(Error::InvalidHead), - } - - self.status = httparse_res.code.map(|code| code.try_into().unwrap()); - - for header in httparse_res.headers { - let header_name = HeaderName::from_str(header.name)?; - let header_value = HeaderValue::from(header.value.to_owned()); - self.response_headers.append(header_name, header_value); - } - - self.buffer.ignore_front(head_offset); - - self.validate_response_headers()?; - Ok(()) - } - - #[cfg(feature = "parse")] - async fn parse_head(&mut self) -> Result<()> { - use std::str; - - let head_offset = self.read_head().await?; - - let space = memchr::memchr(b' ', &self.buffer[..head_offset]).ok_or(Error::InvalidHead)?; - self.http_version = str::from_utf8(&self.buffer[..space]) - .map_err(|_| Error::InvalidHead)? - .parse() - .map_err(|_| Error::InvalidHead)?; - self.status = Some(str::from_utf8(&self.buffer[space + 1..space + 4])?.parse()?); - let end_of_first_line = 2 + Finder::new("\r\n") - .find(&self.buffer[..head_offset]) - .ok_or(Error::InvalidHead)?; - - self.response_headers - .extend_parse(&self.buffer[end_of_first_line..head_offset]) - .map_err(|_| Error::InvalidHead)?; - - self.buffer.ignore_front(head_offset); - - self.validate_response_headers()?; - Ok(()) - } - - async fn send_body_and_parse_head(&mut self) -> Result<()> { - if self - .request_headers - .eq_ignore_ascii_case(Expect, "100-continue") - { - log::trace!("Expecting 100-continue"); - self.parse_head().await?; - if self.status == Some(Status::Continue) { - self.status = None; - log::trace!("Received 100-continue, sending request body"); - } else { - self.request_body.take(); - log::trace!( - "Received a status code other than 100-continue, not sending request body" - ); - return Ok(()); - } - } - - self.send_body().await?; - loop { - self.parse_head().await?; - if self.status == Some(Status::Continue) { - self.status = None; - } else { - break; - } - } - - Ok(()) - } - - async fn send_body(&mut self) -> Result<()> { - if let Some(mut body) = self.request_body.take() { - io::copy(&mut body, self.transport()).await?; - } - Ok(()) - } - - fn validate_response_headers(&self) -> Result<()> { - let content_length = self.response_headers.has_header(ContentLength); - - let transfer_encoding_chunked = self - .response_headers - .eq_ignore_ascii_case(TransferEncoding, "chunked"); - - if content_length && transfer_encoding_chunked { - Err(Error::UnexpectedHeader(ContentLength.into())) - } else { - Ok(()) - } - } - - fn is_keep_alive(&self) -> bool { - self.response_headers - .eq_ignore_ascii_case(Connection, "keep-alive") - } - - async fn finish_reading_body(&mut self) { - if self.response_body_state != ReceivedBodyState::End { - let body = self.response_body(); - match body.drain().await { - Ok(drain) => log::debug!( - "drained {}", - Size::from_bytes(drain).format().with_base(Base::Base10) - ), - Err(e) => log::warn!("failed to drain body, {:?}", e), - } - } - } - - async fn exec(&mut self) -> Result<()> { - self.finalize_headers()?; - self.connect_and_send_head().await?; - self.send_body_and_parse_head().await?; - Ok(()) - } -} - -impl Drop for Conn { - fn drop(&mut self) { - if !self.is_keep_alive() { - return; - } - - let Some(transport) = self.transport.take() else { - return; - }; - let Ok(Some(peer_addr)) = transport.peer_addr() else { - return; - }; - let Some(pool) = self.pool.take() else { return }; - - let origin = self.url.origin(); - - if self.response_body_state == ReceivedBodyState::End { - log::trace!("response body has been read to completion, checking transport back into pool for {}", &peer_addr); - pool.insert(origin, PoolEntry::new(transport, None)); - } else { - let content_length = self.response_content_length(); - let buffer = std::mem::take(&mut self.buffer); - let response_body_state = self.response_body_state; - let encoding = encoding(&self.response_headers); - self.config.runtime().spawn(async move { - let mut response_body = ReceivedBody::new( - content_length, - buffer, - transport, - response_body_state, - None, - encoding, - ); - - match io::copy(&mut response_body, io::sink()).await { - Ok(bytes) => { - let transport = response_body.take_transport().unwrap(); - log::trace!( - "read {} bytes in order to recycle conn for {}", - bytes, - &peer_addr - ); - pool.insert(origin, PoolEntry::new(transport, None)); - } - - Err(ioerror) => log::error!("unable to recycle conn due to {}", ioerror), - }; - }); - } - } -} - -impl From for Body { - fn from(conn: Conn) -> Body { - let received_body: ReceivedBody<'static, _> = conn.into(); - received_body.into() - } -} - -impl From for ReceivedBody<'static, BoxedTransport> { - fn from(mut conn: Conn) -> Self { - let _ = conn.finalize_headers(); - let origin = conn.url.origin(); - - let on_completion = - conn.pool - .take() - .map(|pool| -> Box { - Box::new(move |transport| { - pool.insert(origin.clone(), PoolEntry::new(transport, None)); - }) - }); - - ReceivedBody::new( - conn.response_content_length(), - std::mem::take(&mut conn.buffer), - conn.transport.take().unwrap(), - conn.response_body_state, - on_completion, - conn.response_encoding(), - ) - } -} - -impl From for Upgrade { - fn from(mut conn: Conn) -> Self { - Upgrade::new( - std::mem::take(&mut conn.request_headers), - conn.url.path().to_string(), - conn.method, - conn.transport.take().unwrap(), - std::mem::take(&mut conn.buffer), - ) - } -} - -impl IntoFuture for Conn { - type Output = Result; - - type IntoFuture = Pin + Send + 'static>>; - - fn into_future(mut self) -> Self::IntoFuture { - Box::pin(async move { - if let Some(duration) = self.timeout { - self.config - .runtime() - .timeout(duration, self.exec()) - .await - .ok_or(Error::TimedOut("Conn", duration))??; - } else { - self.exec().await?; - } - Ok(self) - }) - } -} - -impl<'conn> IntoFuture for &'conn mut Conn { - type Output = Result<()>; - - type IntoFuture = Pin + Send + 'conn>>; - - fn into_future(self) -> Self::IntoFuture { - Box::pin(async move { - self.exec().await?; - Ok(()) - }) + /// add state to the client conn and return self + pub fn with_state(mut self, state: T) -> Self { + self.insert_state(state); + self } -} -/// An unexpected http status code was received. Transform this back -/// into the conn with [`From::from`]/[`Into::into`]. -/// -/// Currently only returned by [`Conn::success`] -#[derive(Debug)] -pub struct UnexpectedStatusError(Box); -impl From for UnexpectedStatusError { - fn from(value: Conn) -> Self { - Self(Box::new(value)) + /// add state to the client conn, returning any previously set state of this type + pub fn insert_state(&mut self, state: T) -> Option { + self.state.insert(state) } -} -impl From for Conn { - fn from(value: UnexpectedStatusError) -> Self { - *value.0 + /// borrow state + pub fn state(&self) -> Option<&T> { + self.state.get() } -} -impl Deref for UnexpectedStatusError { - type Target = Conn; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} -impl DerefMut for UnexpectedStatusError { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 + /// borrow state mutably + pub fn state_mut(&mut self) -> Option<&mut T> { + self.state.get_mut() } -} -impl std::error::Error for UnexpectedStatusError {} -impl Display for UnexpectedStatusError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self.status() { - Some(status) => f.write_fmt(format_args!( - "expected a success (2xx) status code, but got {status}" - )), - None => f.write_str("expected a status code to be set, but none was"), - } + /// take state + pub fn take_state(&mut self) -> Option { + self.state.take() } } diff --git a/client/src/conn/implementation.rs b/client/src/conn/implementation.rs new file mode 100644 index 0000000000..e31d8c6607 --- /dev/null +++ b/client/src/conn/implementation.rs @@ -0,0 +1,516 @@ +use super::Conn; +use crate::{pool::PoolEntry, util::encoding}; +use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt}; +use memchr::memmem::Finder; +use size::{Base, Size}; +use std::{ + fmt::{self, Debug, Formatter}, + future::{Future, IntoFuture}, + io::{ErrorKind, Write}, + pin::Pin, +}; +use trillium_http::{ + transport::BoxedTransport, + Body, Error, + KnownHeaderName::{Connection, ContentLength, Expect, Host, TransferEncoding}, + Method, ReceivedBody, ReceivedBodyState, Result, Status, TypeSet, Upgrade, +}; +use trillium_server_common::{Connector, Transport}; + +impl Conn { + fn finalize_headers(&mut self) -> Result<()> { + if self.headers_finalized { + return Ok(()); + } + + let host = self.url.host_str().ok_or(Error::UnexpectedUriFormat)?; + + self.request_headers.try_insert_with(Host, || { + self.url + .port() + .map_or_else(|| host.to_string(), |port| format!("{host}:{port}")) + }); + + if self.pool.is_none() { + self.request_headers.try_insert(Connection, "close"); + } + + match self.body_len() { + Some(0) => {} + Some(len) => { + self.request_headers.insert(Expect, "100-continue"); + self.request_headers.insert(ContentLength, len); + } + None => { + self.request_headers.insert(Expect, "100-continue"); + self.request_headers.insert(TransferEncoding, "chunked"); + } + } + + self.headers_finalized = true; + Ok(()) + } + + fn body_len(&self) -> Option { + if let Some(ref body) = self.request_body { + body.len() + } else { + Some(0) + } + } + + async fn find_pool_candidate(&self, head: &[u8]) -> Result> { + let mut byte = [0]; + if let Some(pool) = &self.pool { + for mut candidate in pool.candidates(&self.url.origin()) { + if poll_once(candidate.read(&mut byte)).await.is_none() + && candidate.write_all(head).await.is_ok() + { + return Ok(Some(candidate)); + } + } + } + Ok(None) + } + + async fn connect_and_send_head(&mut self) -> Result<()> { + if self.transport.is_some() { + return Err(Error::Io(std::io::Error::new( + ErrorKind::AlreadyExists, + "conn already connected", + ))); + } + + let head = self.build_head().await?; + + let transport = match self.find_pool_candidate(&head).await? { + Some(transport) => { + log::debug!("reusing connection to {:?}", transport.peer_addr()?); + transport + } + + None => { + let mut transport = self.config.connect(&self.url).await?; + log::debug!("opened new connection to {:?}", transport.peer_addr()?); + transport.write_all(&head).await?; + transport + } + }; + + self.transport = Some(transport); + Ok(()) + } + + async fn build_head(&mut self) -> Result> { + let mut buf = Vec::with_capacity(128); + let url = &self.url; + let method = self.method; + write!(buf, "{method} ")?; + + if method == Method::Connect { + let host = url.host_str().ok_or(Error::UnexpectedUriFormat)?; + + let port = url + .port_or_known_default() + .ok_or(Error::UnexpectedUriFormat)?; + + write!(buf, "{host}:{port}")?; + } else { + write!(buf, "{}", url.path())?; + if let Some(query) = url.query() { + write!(buf, "?{query}")?; + } + } + + write!(buf, " HTTP/1.1\r\n")?; + + for (name, values) in &self.request_headers { + if !name.is_valid() { + return Err(Error::InvalidHeaderName); + } + + for value in values { + if !value.is_valid() { + return Err(Error::InvalidHeaderValue(name.to_owned())); + } + write!(buf, "{name}: ")?; + buf.extend_from_slice(value.as_ref()); + write!(buf, "\r\n")?; + } + } + + write!(buf, "\r\n")?; + log::trace!( + "{}", + std::str::from_utf8(&buf).unwrap().replace("\r\n", "\r\n> ") + ); + + Ok(buf) + } + + fn transport(&mut self) -> &mut BoxedTransport { + self.transport.as_mut().unwrap() + } + + async fn read_head(&mut self) -> Result { + let Self { + buffer, + transport: Some(transport), + .. + } = self + else { + return Err(Error::Closed); + }; + + let mut len = buffer.len(); + let mut search_start = 0; + let finder = Finder::new(b"\r\n\r\n"); + + if len > 0 { + if let Some(index) = finder.find(buffer) { + return Ok(index + 4); + } + search_start = len.saturating_sub(3); + } + + loop { + buffer.expand(); + let bytes = transport.read(&mut buffer[len..]).await?; + len += bytes; + + let search = finder.find(&buffer[search_start..len]); + + if let Some(index) = search { + buffer.truncate(len); + return Ok(search_start + index + 4); + } + + search_start = len.saturating_sub(3); + + if bytes == 0 { + if len == 0 { + return Err(Error::Closed); + } else { + return Err(Error::InvalidHead); + } + } + + if len >= self.max_head_length { + return Err(Error::HeadersTooLong); + } + } + } + + #[cfg(not(feature = "parse"))] + async fn parse_head(&mut self) -> Result<()> { + const MAX_HEADERS: usize = 128; + use crate::{HeaderName, HeaderValue}; + use std::str::FromStr; + + let head_offset = self.read_head().await?; + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut httparse_res = httparse::Response::new(&mut headers); + let parse_result = + httparse_res + .parse(&self.buffer[..head_offset]) + .map_err(|e| match e { + httparse::Error::HeaderName => Error::InvalidHeaderName, + httparse::Error::HeaderValue => Error::InvalidHeaderValue("unknown".into()), + httparse::Error::Status => Error::InvalidStatus, + httparse::Error::TooManyHeaders => Error::HeadersTooLong, + httparse::Error::Version => Error::InvalidVersion, + _ => Error::InvalidHead, + })?; + + match parse_result { + httparse::Status::Complete(n) if n == head_offset => {} + _ => return Err(Error::InvalidHead), + } + + self.status = httparse_res.code.map(|code| code.try_into().unwrap()); + + for header in httparse_res.headers { + let header_name = HeaderName::from_str(header.name)?; + let header_value = HeaderValue::from(header.value.to_owned()); + self.response_headers.append(header_name, header_value); + } + + self.buffer.ignore_front(head_offset); + + self.validate_response_headers()?; + Ok(()) + } + + #[cfg(feature = "parse")] + async fn parse_head(&mut self) -> Result<()> { + use std::str; + + let head_offset = self.read_head().await?; + + let space = memchr::memchr(b' ', &self.buffer[..head_offset]).ok_or(Error::InvalidHead)?; + self.http_version = str::from_utf8(&self.buffer[..space]) + .map_err(|_| Error::InvalidHead)? + .parse() + .map_err(|_| Error::InvalidHead)?; + self.status = Some(str::from_utf8(&self.buffer[space + 1..space + 4])?.parse()?); + let end_of_first_line = 2 + Finder::new("\r\n") + .find(&self.buffer[..head_offset]) + .ok_or(Error::InvalidHead)?; + + self.response_headers + .extend_parse(&self.buffer[end_of_first_line..head_offset]) + .map_err(|_| Error::InvalidHead)?; + + self.buffer.ignore_front(head_offset); + + self.validate_response_headers()?; + Ok(()) + } + + async fn send_body_and_parse_head(&mut self) -> Result<()> { + if self + .request_headers + .eq_ignore_ascii_case(Expect, "100-continue") + { + log::trace!("Expecting 100-continue"); + self.parse_head().await?; + if self.status == Some(Status::Continue) { + self.status = None; + log::trace!("Received 100-continue, sending request body"); + } else { + self.request_body.take(); + log::trace!( + "Received a status code other than 100-continue, not sending request body" + ); + return Ok(()); + } + } + + self.send_body().await?; + loop { + self.parse_head().await?; + if self.status == Some(Status::Continue) { + self.status = None; + } else { + break; + } + } + + Ok(()) + } + + async fn send_body(&mut self) -> Result<()> { + if let Some(mut body) = self.request_body.take() { + io::copy(&mut body, self.transport()).await?; + } + Ok(()) + } + + fn validate_response_headers(&self) -> Result<()> { + let content_length = self.response_headers.has_header(ContentLength); + + let transfer_encoding_chunked = self + .response_headers + .eq_ignore_ascii_case(TransferEncoding, "chunked"); + + if content_length && transfer_encoding_chunked { + Err(Error::UnexpectedHeader(ContentLength.into())) + } else { + Ok(()) + } + } + + pub(super) fn is_keep_alive(&self) -> bool { + self.response_headers + .eq_ignore_ascii_case(Connection, "keep-alive") + } + + pub(super) async fn finish_reading_body(&mut self) { + if self.response_body_state != ReceivedBodyState::End { + let body = self.response_body(); + match body.drain().await { + Ok(drain) => log::debug!( + "drained {}", + Size::from_bytes(drain).format().with_base(Base::Base10) + ), + Err(e) => log::warn!("failed to drain body, {:?}", e), + } + } + } + + async fn exec(&mut self) -> Result<()> { + self.finalize_headers()?; + self.connect_and_send_head().await?; + self.send_body_and_parse_head().await?; + Ok(()) + } + + pub(super) fn response_content_length(&self) -> Option { + if self.status == Some(Status::NoContent) + || self.status == Some(Status::NotModified) + || self.method == Method::Head + { + Some(0) + } else { + self.response_headers + .get_str(ContentLength) + .and_then(|c| c.parse().ok()) + } + } +} + +impl Drop for Conn { + fn drop(&mut self) { + if !self.is_keep_alive() { + return; + } + + let Some(transport) = self.transport.take() else { + return; + }; + let Ok(Some(peer_addr)) = transport.peer_addr() else { + return; + }; + let Some(pool) = self.pool.take() else { return }; + + let origin = self.url.origin(); + + if self.response_body_state == ReceivedBodyState::End { + log::trace!("response body has been read to completion, checking transport back into pool for {}", &peer_addr); + pool.insert(origin, PoolEntry::new(transport, None)); + } else { + let content_length = self.response_content_length(); + let buffer = std::mem::take(&mut self.buffer); + let response_body_state = self.response_body_state; + let encoding = encoding(&self.response_headers); + self.config.runtime().spawn(async move { + let mut response_body = ReceivedBody::new( + content_length, + buffer, + transport, + response_body_state, + None, + encoding, + ); + + match io::copy(&mut response_body, io::sink()).await { + Ok(bytes) => { + let transport = response_body.take_transport().unwrap(); + log::trace!( + "read {} bytes in order to recycle conn for {}", + bytes, + &peer_addr + ); + pool.insert(origin, PoolEntry::new(transport, None)); + } + + Err(ioerror) => log::error!("unable to recycle conn due to {}", ioerror), + }; + }); + } + } +} + +impl From for Body { + fn from(conn: Conn) -> Body { + let received_body: ReceivedBody<'static, _> = conn.into(); + received_body.into() + } +} + +impl From for ReceivedBody<'static, BoxedTransport> { + fn from(mut conn: Conn) -> Self { + let _ = conn.finalize_headers(); + let origin = conn.url.origin(); + + let on_completion = + conn.pool + .take() + .map(|pool| -> Box { + Box::new(move |transport| { + pool.insert(origin.clone(), PoolEntry::new(transport, None)); + }) + }); + + ReceivedBody::new( + conn.response_content_length(), + std::mem::take(&mut conn.buffer), + conn.transport.take().unwrap(), + conn.response_body_state, + on_completion, + conn.response_encoding(), + ) + } +} + +impl From for Upgrade { + fn from(mut conn: Conn) -> Self { + Upgrade::new( + std::mem::take(&mut conn.request_headers), + conn.url.path().to_string(), + conn.method, + conn.transport.take().unwrap(), + std::mem::take(&mut conn.buffer), + ) + } +} + +impl IntoFuture for Conn { + type Output = Result; + + type IntoFuture = Pin + Send + 'static>>; + + fn into_future(mut self) -> Self::IntoFuture { + Box::pin(async move { (&mut self).await.map(|()| self) }) + } +} + +impl<'conn> IntoFuture for &'conn mut Conn { + type Output = Result<()>; + + type IntoFuture = Pin + Send + 'conn>>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(async move { + if let Some(duration) = self.timeout { + self.config + .runtime() + .timeout(duration, self.exec()) + .await + .unwrap_or(Err(Error::TimedOut("Conn", duration)))?; + } else { + self.exec().await?; + } + Ok(()) + }) + } +} + +impl Debug for Conn { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Conn") + .field("url", &self.url) + .field("method", &self.method) + .field("request_headers", &self.request_headers) + .field("response_headers", &self.response_headers) + .field("status", &self.status) + .field("request_body", &self.request_body) + .field("pool", &self.pool) + .field("buffer", &String::from_utf8_lossy(&self.buffer)) + .field("response_body_state", &self.response_body_state) + .field("config", &self.config) + .field("state", &self.state) + .finish() + } +} + +impl AsRef for Conn { + fn as_ref(&self) -> &TypeSet { + &self.state + } +} +impl AsMut for Conn { + fn as_mut(&mut self) -> &mut TypeSet { + &mut self.state + } +} diff --git a/client/src/conn/unexpected_status_error.rs b/client/src/conn/unexpected_status_error.rs new file mode 100644 index 0000000000..07bcc8e57b --- /dev/null +++ b/client/src/conn/unexpected_status_error.rs @@ -0,0 +1,48 @@ +use super::Conn; +use std::{ + error::Error, + fmt::{self, Debug, Display, Formatter}, + ops::{Deref, DerefMut}, +}; +/// An unexpected http status code was received. Transform this back +/// into the conn with [`From::from`]/[`Into::into`]. +/// +/// Currently only returned by [`Conn::success`] +#[derive(Debug)] +pub struct UnexpectedStatusError(Box); +impl From for UnexpectedStatusError { + fn from(value: Conn) -> Self { + Self(Box::new(value)) + } +} + +impl From for Conn { + fn from(value: UnexpectedStatusError) -> Self { + *value.0 + } +} + +impl Deref for UnexpectedStatusError { + type Target = Conn; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for UnexpectedStatusError { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Error for UnexpectedStatusError {} +impl Display for UnexpectedStatusError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self.status() { + Some(status) => f.write_fmt(format_args!( + "expected a success (2xx) status code, but got {status}" + )), + None => f.write_str("expected a status code to be set, but none was"), + } + } +} diff --git a/client/tests/timeout.rs b/client/tests/timeout.rs index a46ce108ce..9bc7e14f2c 100644 --- a/client/tests/timeout.rs +++ b/client/tests/timeout.rs @@ -1,10 +1,13 @@ use std::time::Duration; use trillium_client::Client; -use trillium_testing::{client_config, runtime, RuntimeTrait}; +use trillium_testing::{client_config, Runtime}; async fn handler(conn: trillium::Conn) -> trillium::Conn { if conn.path() == "/slow" { - runtime().delay(Duration::from_secs(5)).await; + conn.shared_state::() + .unwrap() + .delay(Duration::from_secs(5)) + .await; } conn.ok("ok") } diff --git a/forwarding/src/lib.rs b/forwarding/src/lib.rs index ba4e00b36d..e39eca2d23 100644 --- a/forwarding/src/lib.rs +++ b/forwarding/src/lib.rs @@ -54,7 +54,9 @@ where } impl Debug for TrustFn { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("TrustPredicate").field(&"..").finish() + f.debug_tuple("TrustPredicate") + .field(&format_args!("..")) + .finish() } } diff --git a/http/examples/conn-example.rs b/http/examples/conn-example.rs index eb7e541d7d..69c093ffe5 100644 --- a/http/examples/conn-example.rs +++ b/http/examples/conn-example.rs @@ -1,50 +1,43 @@ -fn main() -> trillium_http::Result<()> { - use async_net::{TcpListener, TcpStream}; - use futures_lite::StreamExt; - use swansong::Swansong; - use trillium_http::{Conn, Result}; +fn main() { + use smol::{net::TcpListener, stream::StreamExt}; + use std::sync::Arc; + use trillium_http::ServerConfig; smol::block_on(async { - let swansong = Swansong::new(); - - let server_swansong = swansong.clone(); - let server = smol::spawn(async move { - let listener = TcpListener::bind("localhost:8001").await?; - let mut incoming = server_swansong.interrupt(listener.incoming()); - - while let Some(Ok(stream)) = incoming.next().await { - let swansong = server_swansong.clone(); - smol::spawn(async move { - Conn::map(stream, swansong, |mut conn: Conn| async move { + let server_config = Arc::new(ServerConfig::default()); + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + println!("listening on http://{local_addr}"); + + let server = smol::spawn({ + let server_config = server_config.clone(); + async move { + let mut incoming = server_config.swansong().interrupt(listener.incoming()); + + while let Some(Ok(stream)) = incoming.next().await { + smol::spawn(server_config.clone().run(stream, |mut conn| async move { conn.set_response_body("hello world"); conn.set_status(200); conn - }) - .await - }) - .detach() + })) + .detach() + } } - - Result::Ok(()) }); - // this example uses the trillium client - // please note that this api is still especially unstable. - // any other http client would work here too use trillium_client::Client; use trillium_smol::ClientConfig; - let client = Client::new(ClientConfig::default()); - let mut client_conn = client.get("http://localhost:8001").await?; + let client = Client::new(ClientConfig::default()).with_base(local_addr); + let mut client_conn = client.get("/").await.unwrap(); assert_eq!(client_conn.status().unwrap(), 200); assert_eq!( - client_conn.response_body().read_string().await?, + client_conn.response_body().read_string().await.unwrap(), "hello world" ); - swansong.shut_down(); // stop the server after one request - server.await?; // wait for the server to shut down + server.await; - Result::Ok(()) + // server_config.shut_down().await; // stop the server after one request }) } diff --git a/http/examples/http.rs b/http/examples/http.rs index 8ceff5cc81..d869085235 100644 --- a/http/examples/http.rs +++ b/http/examples/http.rs @@ -1,6 +1,8 @@ +use std::sync::Arc; + use async_net::{TcpListener, TcpStream}; use futures_lite::prelude::*; -use trillium_http::{Conn, Swansong}; +use trillium_http::{Conn, ServerConfig}; async fn handler(mut conn: Conn) -> Conn { conn.set_status(200); @@ -12,18 +14,18 @@ pub fn main() { env_logger::init(); smol::block_on(async move { - let swansong = Swansong::new(); + let server_config = Arc::new(ServerConfig::new()); let port = std::env::var("PORT") .unwrap_or("8080".into()) .parse::() .unwrap(); let listener = TcpListener::bind(("0.0.0.0", port)).await.unwrap(); - let mut incoming = swansong.interrupt(listener.incoming()); + let mut incoming = server_config.swansong().interrupt(listener.incoming()); while let Some(Ok(stream)) = incoming.next().await { - let swansong = swansong.clone(); + let server_config = Arc::clone(&server_config); smol::spawn(async move { - match Conn::map(stream, swansong, handler).await { + match server_config.run(stream, handler).await { Ok(Some(_)) => log::info!("upgrade"), Ok(None) => log::info!("closing connection"), Err(e) => log::error!("{:?}", e), diff --git a/http/examples/tokio-http.rs b/http/examples/tokio-http.rs index 7400416025..32c83112e2 100644 --- a/http/examples/tokio-http.rs +++ b/http/examples/tokio-http.rs @@ -1,6 +1,7 @@ use async_compat::Compat; +use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; -use trillium_http::{Conn, Swansong}; +use trillium_http::{Conn, ServerConfig}; async fn handler(mut conn: Conn>) -> Conn> { let body = conn.request_body().await.read_string().await.unwrap(); @@ -15,14 +16,15 @@ async fn handler(mut conn: Conn>) -> Conn> { #[tokio::main] pub async fn main() { env_logger::init(); - let swansong = Swansong::new(); + let server_config = Arc::new(ServerConfig::new()); + let listener = TcpListener::bind("127.0.0.1:8081").await.unwrap(); loop { match listener.accept().await { Ok((stream, _)) => { - let swansong = swansong.clone(); + let server_config = server_config.clone(); tokio::spawn(async move { - match Conn::map(Compat::new(stream), swansong, handler).await { + match server_config.run(Compat::new(stream), handler).await { Ok(Some(_)) => log::info!("upgrade"), Ok(None) => log::info!("closing connection"), Err(e) => log::error!("{:?}", e), diff --git a/http/examples/unsend.rs b/http/examples/unsend.rs index a753fe4d69..68e503dd46 100644 --- a/http/examples/unsend.rs +++ b/http/examples/unsend.rs @@ -1,7 +1,7 @@ use async_net::{TcpListener, TcpStream}; use futures_lite::prelude::*; -use std::thread; -use trillium_http::{Conn, Swansong}; +use std::{sync::Arc, thread}; +use trillium_http::{Conn, ServerConfig, Swansong}; async fn handler(mut conn: Conn) -> Conn { let rc = std::rc::Rc::new(()); @@ -14,13 +14,15 @@ async fn handler(mut conn: Conn) -> Conn { pub fn main() { env_logger::init(); - let swansong = Swansong::new(); + let server_config = Arc::new(ServerConfig::new()); let (send, receive) = async_channel::unbounded(); let core_ids = core_affinity::get_core_ids().unwrap(); + + let swansong = Swansong::new(); let handles = core_ids .into_iter() .map(|id| { - let swansong = swansong.clone(); + let server_config = server_config.clone(); let receive = receive.clone(); thread::spawn(move || { if !core_affinity::set_for_current(id) { @@ -28,12 +30,11 @@ pub fn main() { } let executor = async_executor::LocalExecutor::new(); - futures_lite::future::block_on(executor.run(async { + async_io::block_on(executor.run(async { while let Ok(transport) = receive.recv().await { - let swansong = swansong.clone(); - + let server_config = server_config.clone(); let future = async move { - match Conn::map(transport, swansong, handler).await { + match server_config.run(transport, handler).await { Ok(_) => {} Err(e) => log::error!("{e}"), } diff --git a/http/src/body.rs b/http/src/body.rs index 948936ec78..a8a30322b4 100644 --- a/http/src/body.rs +++ b/http/src/body.rs @@ -282,7 +282,7 @@ impl Debug for BodyType { .. } => f .debug_struct("BodyType::Streaming") - .field("async_read", &"..") + .field("async_read", &format_args!("..")) .field("len", &len) .field("done", &done) .field("progress", &progress) diff --git a/http/src/conn.rs b/http/src/conn.rs index 0c85032ee5..7f2e3f2fe9 100644 --- a/http/src/conn.rs +++ b/http/src/conn.rs @@ -1,20 +1,17 @@ use crate::{ after_send::{AfterSend, SendStatus}, - copy, - http_config::DEFAULT_CONFIG, liveness::{CancelOnDisconnect, LivenessFut}, received_body::ReceivedBodyState, util::encoding, - Body, BufWriter, Buffer, ConnectionStatus, Error, Headers, HttpConfig, - KnownHeaderName::{Connection, ContentLength, Date, Expect, Host, Server, TransferEncoding}, - Method, ReceivedBody, Result, Status, Swansong, TypeSet, Upgrade, Version, + Body, Buffer, Headers, + KnownHeaderName::{Connection, ContentLength, Date, Host, TransferEncoding}, + Method, ReceivedBody, ServerConfig, Status, Swansong, TypeSet, Version, }; use encoding_rs::Encoding; use futures_lite::{ future, - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + io::{AsyncRead, AsyncWrite}, }; -use memchr::memmem::Finder; use std::{ fmt::{self, Debug, Formatter}, future::Future, @@ -24,6 +21,7 @@ use std::{ sync::Arc, time::{Instant, SystemTime}, }; +mod implementation; /// Default Server header pub const SERVER: &str = concat!("trillium/", env!("CARGO_PKG_VERSION")); @@ -35,6 +33,7 @@ the request and the response, and holds the transport over which the response will be sent. */ pub struct Conn { + pub(crate) server_config: Arc, pub(crate) request_headers: Headers, pub(crate) response_headers: Headers, pub(crate) path: String, @@ -47,18 +46,15 @@ pub struct Conn { pub(crate) buffer: Buffer, pub(crate) request_body_state: ReceivedBodyState, pub(crate) secure: bool, - pub(crate) swansong: Swansong, pub(crate) after_send: AfterSend, pub(crate) start_time: Instant, pub(crate) peer_ip: Option, - pub(crate) http_config: HttpConfig, - pub(crate) shared_state: Option>, } impl Debug for Conn { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("Conn") - .field("http_config", &self.http_config) + .field("server_config", &self.server_config) .field("request_headers", &self.request_headers) .field("response_headers", &self.response_headers) .field("path", &self.path) @@ -66,14 +62,12 @@ impl Debug for Conn { .field("status", &self.status) .field("version", &self.version) .field("state", &self.state) - .field("shared_state", &self.shared_state) .field("response_body", &self.response_body) - .field("transport", &"..") - .field("buffer", &"..") + .field("transport", &format_args!("..")) + .field("buffer", &format_args!("..")) .field("request_body_state", &self.request_body_state) .field("secure", &self.secure) - .field("swansong", &self.swansong) - .field("after_send", &"..") + .field("after_send", &format_args!("..")) .field("start_time", &self.start_time) .field("peer_ip", &self.peer_ip) .finish() @@ -84,143 +78,6 @@ impl Conn where Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, { - /// read any number of new `Conn`s from the transport and call the - /// provided handler function until either the connection is closed or - /// an upgrade is requested. A return value of Ok(None) indicates a - /// closed connection, while a return value of Ok(Some(upgrade)) - /// represents an upgrade. - /// - /// Provides a default [`HttpConfig`] - /// - /// See the documentation for [`Conn`] for a full example. - /// - /// # Errors - /// - /// This will return an error variant if: - /// - /// * there is an io error when reading from the underlying transport - /// * headers are too long - /// * we are unable to parse some aspect of the request - /// * the request is an unsupported http version - /// * we cannot make sense of the headers, such as if there is a - /// `content-length` header as well as a `transfer-encoding: chunked` - /// header. - - pub async fn map( - transport: Transport, - swansong: Swansong, - handler: F, - ) -> Result>> - where - F: FnMut(Conn) -> Fut, - Fut: Future>, - { - Self::map_with_config(DEFAULT_CONFIG, transport, swansong, handler).await - } - - /// read any number of new `Conn`s from the transport and call the - /// provided handler function until either the connection is closed or - /// an upgrade is requested. A return value of Ok(None) indicates a - /// closed connection, while a return value of Ok(Some(upgrade)) - /// represents an upgrade. - /// - /// See the documentation for [`Conn`] for a full example. - /// - /// # Errors - /// - /// This will return an error variant if: - /// - /// * there is an io error when reading from the underlying transport - /// * headers are too long - /// * we are unable to parse some aspect of the request - /// * the request is an unsupported http version - /// * we cannot make sense of the headers, such as if there is a - /// `content-length` header as well as a `transfer-encoding: chunked` - /// header. - pub async fn map_with_config( - http_config: HttpConfig, - transport: Transport, - swansong: Swansong, - handler: F, - ) -> Result>> - where - F: FnMut(Conn) -> Fut, - Fut: Future>, - { - Self::map_with_config_and_shared_state(http_config, transport, swansong, None, handler) - .await - } - - /// read any number of new `Conn`s from the transport and call the - /// provided handler function until either the connection is closed or - /// an upgrade is requested. A return value of Ok(None) indicates a - /// closed connection, while a return value of Ok(Some(upgrade)) - /// represents an upgrade. - /// - /// The `shared_state` `Arc` is available provided to all Conns on this transport if - /// provided. - /// - /// See the documentation for [`Conn`] for a full example. - /// - /// # Errors - /// - /// This will return an error variant if: - /// - /// * there is an io error when reading from the underlying transport - /// * headers are too long - /// * we are unable to parse some aspect of the request - /// * the request is an unsupported http version - /// * we cannot make sense of the headers, such as if there is a - /// `content-length` header as well as a `transfer-encoding: chunked` - /// header. - pub async fn map_with_config_and_shared_state( - http_config: HttpConfig, - transport: Transport, - swansong: Swansong, - shared_state: Option>, - mut handler: F, - ) -> Result>> - where - F: FnMut(Conn) -> Fut, - Fut: Future>, - { - let mut conn = Conn::new_internal( - http_config, - transport, - Vec::with_capacity(http_config.request_buffer_initial_len).into(), - swansong, - shared_state, - ) - .await?; - - loop { - conn = match handler(conn).await.send().await? { - ConnectionStatus::Upgrade(upgrade) => return Ok(Some(upgrade)), - ConnectionStatus::Close => return Ok(None), - ConnectionStatus::Conn(next) => next, - } - } - } - - async fn send(mut self) -> Result> { - let mut output_buffer = Vec::with_capacity(self.http_config.response_buffer_len); - self.write_headers(&mut output_buffer)?; - - let mut bufwriter = BufWriter::new_with_buffer(output_buffer, &mut self.transport); - - if self.method != Method::Head - && !matches!(self.status, Some(Status::NotModified | Status::NoContent)) - { - if let Some(body) = self.response_body.take() { - copy(body, &mut bufwriter, self.http_config.copy_loops_per_yield).await?; - } - } - - bufwriter.flush().await?; - self.after_send.call(true.into()); - self.finish().await - } - /// returns a read-only reference to the [state /// typemap](TypeSet) for this conn /// @@ -233,21 +90,13 @@ where /// returns a mutable reference to the [state /// typemap](TypeSet) for this conn - /// - /// stability note: this is not unlikely to be removed at some - /// point, as this may end up being more of a trillium concern - /// than a `trillium_http` concern pub fn state_mut(&mut self) -> &mut TypeSet { &mut self.state } /// Returns the shared state on this conn, if set - /// - /// stability note: this is not unlikely to be removed at some - /// point, as this may end up being more of a trillium concern - /// than a `trillium_http` concern - pub fn shared_state(&self) -> Option<&TypeSet> { - self.shared_state.as_deref() + pub fn shared_state(&self) -> &TypeSet { + &self.server_config.shared_state } /// returns a reference to the request headers @@ -346,21 +195,6 @@ where self.request_headers.insert(Host, host); } - // pub fn url(&self) -> Result { - // let path = self.path(); - // let host = self.host().unwrap_or_else(|| String::from("_")); - // let method = self.method(); - // if path.starts_with("http://") || path.starts_with("https://") { - // Ok(Url::parse(path)?) - // } else if path.starts_with('/') { - // Ok(Url::parse(&format!("http://{}{}", host, path))?) - // } else if method == &Method::Connect { - // Ok(Url::parse(&format!("http://{}/", path))?) - // } else { - // Err(Error::UnexpectedUriFormat) - // } - // } - /** Sets the response body to anything that is [`impl Into`][Body]. @@ -494,27 +328,6 @@ where future::poll_once(LivenessFut::new(self)).await.is_some() } - fn needs_100_continue(&self) -> bool { - self.request_body_state == ReceivedBodyState::Start - && self.version != Version::Http1_0 - && self - .request_headers - .eq_ignore_ascii_case(Expect, "100-continue") - } - - #[allow(clippy::needless_borrow, clippy::needless_borrows_for_generic_args)] - fn build_request_body(&mut self) -> ReceivedBody<'_, Transport> { - ReceivedBody::new_with_config( - self.request_content_length().ok().flatten(), - &mut self.buffer, - &mut self.transport, - &mut self.request_body_state, - None, - encoding(&self.request_headers), - &self.http_config, - ) - } - /** returns the [`encoding_rs::Encoding`] for this request, as determined from the mime-type charset, if available @@ -573,190 +386,7 @@ where /// this to gracefully stop long-running futures and streams /// inside of handler functions pub fn swansong(&self) -> Swansong { - self.swansong.clone() - } - - fn validate_headers(request_headers: &Headers) -> Result<()> { - let content_length = request_headers.has_header(ContentLength); - let transfer_encoding_chunked = - request_headers.eq_ignore_ascii_case(TransferEncoding, "chunked"); - - if content_length && transfer_encoding_chunked { - Err(Error::UnexpectedHeader(ContentLength.into())) - } else { - Ok(()) - } - } - - /// # Create a new `Conn` - /// - /// This function creates a new conn from the provided - /// [`Transport`][crate::transport::Transport], as well as any - /// bytes that have already been read from the transport, and a - /// [`Swansong`] instance that will be used to signal graceful - /// shutdown. - /// - /// # Errors - /// - /// This will return an error variant if: - /// - /// * there is an io error when reading from the underlying transport - /// * headers are too long - /// * we are unable to parse some aspect of the request - /// * the request is an unsupported http version - /// * we cannot make sense of the headers, such as if there is a - /// `content-length` header as well as a `transfer-encoding: chunked` - /// header. - pub async fn new(transport: Transport, bytes: Vec, swansong: Swansong) -> Result { - Self::new_internal(DEFAULT_CONFIG, transport, bytes.into(), swansong, None).await - } - - #[cfg(not(feature = "parse"))] - async fn new_internal( - http_config: HttpConfig, - mut transport: Transport, - mut buffer: Buffer, - swansong: Swansong, - shared_state: Option>, - ) -> Result { - use crate::{HeaderName, HeaderValue}; - use httparse::{Request, EMPTY_HEADER}; - use std::str::FromStr; - - let (head_size, start_time) = - Self::head(&mut transport, &mut buffer, &swansong, &http_config).await?; - - let mut headers = vec![EMPTY_HEADER; http_config.max_headers]; - let mut httparse_req = Request::new(&mut headers); - - let status = httparse_req.parse(&buffer[..]).map_err(|e| match e { - httparse::Error::HeaderName => Error::InvalidHeaderName, - httparse::Error::HeaderValue => Error::InvalidHeaderValue("unknown".into()), - httparse::Error::Status => Error::InvalidStatus, - httparse::Error::TooManyHeaders => Error::HeadersTooLong, - httparse::Error::Version => Error::InvalidVersion, - _ => Error::InvalidHead, - })?; - - if status.is_partial() { - return Err(Error::InvalidHead); - } - - let method = match httparse_req.method { - Some(method) => match method.parse() { - Ok(method) => method, - Err(_) => return Err(Error::UnrecognizedMethod(method.to_string())), - }, - None => return Err(Error::MissingMethod), - }; - - let version = match httparse_req.version { - Some(0) => Version::Http1_0, - Some(1) => Version::Http1_1, - _ => return Err(Error::InvalidVersion), - }; - - let mut request_headers = Headers::new(); - for header in httparse_req.headers { - let header_name = HeaderName::from_str(header.name)?; - let header_value = HeaderValue::from(header.value.to_owned()); - request_headers.append(header_name, header_value); - } - - Self::validate_headers(&request_headers)?; - - let path = httparse_req - .path - .ok_or(Error::RequestPathMissing)? - .to_owned(); - log::trace!("received:\n{method} {path} {version}\n{request_headers}"); - - let mut response_headers = Headers::new(); - response_headers.insert(Server, SERVER); - - buffer.ignore_front(head_size); - - Ok(Self { - transport, - request_headers, - method, - version, - path, - buffer, - response_headers, - status: None, - state: TypeSet::new(), - response_body: None, - request_body_state: ReceivedBodyState::Start, - secure: false, - swansong, - after_send: AfterSend::default(), - start_time, - peer_ip: None, - http_config, - shared_state, - }) - } - - #[cfg(feature = "parse")] - async fn new_internal( - http_config: HttpConfig, - mut transport: Transport, - mut buffer: Buffer, - swansong: Swansong, - shared_state: Option>, - ) -> Result { - let (head_size, start_time) = - Self::head(&mut transport, &mut buffer, &swansong, &http_config).await?; - - let first_line_index = Finder::new(b"\r\n") - .find(&buffer) - .ok_or(Error::InvalidHead)?; - - let mut spaces = memchr::memchr_iter(b' ', &buffer[..first_line_index]); - let first_space = spaces.next().ok_or(Error::MissingMethod)?; - let method = Method::parse(&buffer[0..first_space])?; - let second_space = spaces.next().ok_or(Error::RequestPathMissing)?; - let path = str::from_utf8(&buffer[first_space + 1..second_space]) - .map_err(|_| Error::RequestPathMissing)? - .to_string(); - if path.is_empty() { - return Err(Error::InvalidHead); - } - let version = Version::parse(&buffer[second_space + 1..first_line_index])?; - if !matches!(version, Version::Http1_1 | Version::Http1_0) { - return Err(Error::UnsupportedVersion(version)); - } - - let request_headers = Headers::parse(&buffer[first_line_index + 2..head_size])?; - - Self::validate_headers(&request_headers)?; - - let mut response_headers = Headers::new(); - response_headers.insert(Server, SERVER); - - buffer.ignore_front(head_size); - - Ok(Self { - transport, - request_headers, - method, - version, - path, - buffer, - response_headers, - status: None, - state: TypeSet::new(), - response_body: None, - request_body_state: ReceivedBodyState::Start, - secure: false, - swansong, - after_send: AfterSend::default(), - start_time, - peer_ip: None, - http_config, - shared_state, - }) + self.server_config.swansong.clone() } /// predicate function to indicate whether the connection is @@ -804,7 +434,7 @@ where } } - if self.swansong.state().is_shutting_down() { + if self.server_config.swansong.state().is_shutting_down() { self.response_headers.insert(Connection, "close"); } } @@ -832,229 +462,34 @@ where self.start_time } - async fn send_100_continue(&mut self) -> Result<()> { - log::trace!("sending 100-continue"); - Ok(self - .transport - .write_all(b"HTTP/1.1 100 Continue\r\n\r\n") - .await?) - } - - async fn head( - transport: &mut Transport, - buf: &mut Buffer, - swansong: &Swansong, - http_config: &HttpConfig, - ) -> Result<(usize, Instant)> { - let mut len = 0; - let mut start_with_read = buf.is_empty(); - let mut instant = None; - let finder = Finder::new(b"\r\n\r\n"); - loop { - if len >= http_config.head_max_len { - return Err(Error::HeadersTooLong); - } - - let bytes = if start_with_read { - buf.expand(); - if len == 0 { - swansong - .interrupt(transport.read(buf)) - .await - .ok_or(Error::Closed)?? - } else { - transport.read(&mut buf[len..]).await? - } - } else { - start_with_read = true; - buf.len() - }; - - if instant.is_none() { - instant = Some(Instant::now()); - } - - let search_start = len.max(3) - 3; - let search = finder.find(&buf[search_start..]); - - if let Some(index) = search { - buf.truncate(len + bytes); - return Ok((search_start + index + 4, instant.unwrap())); - } - - len += bytes; - - if bytes == 0 { - return if len == 0 { - Err(Error::Closed) - } else { - Err(Error::InvalidHead) - }; - } - } - } - - async fn next(mut self) -> Result { - if !self.needs_100_continue() || self.request_body_state != ReceivedBodyState::Start { - self.build_request_body().drain().await?; - } - Conn::new_internal( - self.http_config, - self.transport, - self.buffer, - self.swansong, - self.shared_state, - ) - .await - } - - fn should_close(&self) -> bool { - let request_connection = self.request_headers.get_lower(Connection); - let response_connection = self.response_headers.get_lower(Connection); - - match ( - request_connection.as_deref(), - response_connection.as_deref(), - ) { - (Some("keep-alive"), Some("keep-alive")) => false, - (Some("close"), _) | (_, Some("close")) => true, - _ => self.version == Version::Http1_0, - } - } - - fn should_upgrade(&self) -> bool { - (self.method() == Method::Connect && self.status == Some(Status::Ok)) - || self.status == Some(Status::SwitchingProtocols) - } - - async fn finish(self) -> Result> { - if self.should_close() { - Ok(ConnectionStatus::Close) - } else if self.should_upgrade() { - Ok(ConnectionStatus::Upgrade(self.into())) - } else { - match self.next().await { - Err(Error::Closed) => { - log::trace!("connection closed by client"); - Ok(ConnectionStatus::Close) - } - Err(e) => Err(e), - Ok(conn) => Ok(ConnectionStatus::Conn(conn)), - } - } - } - - fn request_content_length(&self) -> Result> { - if self - .request_headers - .eq_ignore_ascii_case(TransferEncoding, "chunked") - { - Ok(None) - } else if let Some(cl) = self.request_headers.get_str(ContentLength) { - cl.parse() - .map(Some) - .map_err(|_| Error::InvalidHeaderValue(ContentLength.into())) - } else { - Ok(Some(0)) - } - } - - fn body_len(&self) -> Option { - match self.response_body { - Some(ref body) => body.len(), - None => Some(0), - } - } - - fn write_headers(&mut self, output_buffer: &mut Vec) -> Result<()> { - use std::io::Write; - let status = self.status().unwrap_or(Status::NotFound); - - write!( - output_buffer, - "{} {} {}\r\n", - self.version, - status as u16, - status.canonical_reason() - )?; - - self.finalize_headers(); - - log::trace!( - "sending:\n{} {}\n{}", - self.version, - status, - &self.response_headers - ); - - for (name, values) in &self.response_headers { - if name.is_valid() { - for value in values { - if value.is_valid() { - write!(output_buffer, "{name}: ")?; - output_buffer.extend_from_slice(value.as_ref()); - write!(output_buffer, "\r\n")?; - } else { - log::error!("skipping invalid header value {value:?} for header {name}"); - } - } - } else { - log::error!("skipping invalid header with name {name:?}"); - } - } - - write!(output_buffer, "\r\n")?; - Ok(()) - } - /// applies a mapping function from one transport to another. This /// is particularly useful for boxing the transport. unless you're /// sure this is what you're looking for, you probably don't want /// to be using this - pub fn map_transport( + pub fn map_transport( self, - f: impl Fn(Transport) -> T, - ) -> Conn { - let Conn { - request_headers, - response_headers, - path, - status, - version, - state, - transport, - buffer, - request_body_state, - secure, - method, - response_body, - swansong, - after_send, - start_time, - peer_ip, - http_config, - shared_state, - } = self; - + f: impl Fn(Transport) -> NewTransport, + ) -> Conn + where + NewTransport: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, + { Conn { - request_headers, - response_headers, - method, - response_body, - path, - status, - version, - state, - transport: f(transport), - buffer, - request_body_state, - secure, - swansong, - after_send, - start_time, - peer_ip, - http_config, - shared_state, + server_config: self.server_config, + request_headers: self.request_headers, + response_headers: self.response_headers, + method: self.method, + response_body: self.response_body, + path: self.path, + status: self.status, + version: self.version, + state: self.state, + transport: f(self.transport), + buffer: self.buffer, + request_body_state: self.request_body_state, + secure: self.secure, + after_send: self.after_send, + start_time: self.start_time, + peer_ip: self.peer_ip, } } diff --git a/http/src/conn/implementation.rs b/http/src/conn/implementation.rs new file mode 100644 index 0000000000..8ebe482264 --- /dev/null +++ b/http/src/conn/implementation.rs @@ -0,0 +1,404 @@ +use crate::{ + after_send::AfterSend, conn::ReceivedBodyState, copy, util::encoding, BufWriter, Buffer, Conn, + ConnectionStatus, Error, Headers, KnownHeaderName, Method, ReceivedBody, Result, ServerConfig, + Status, TypeSet, Version, SERVER, +}; +use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use memchr::memmem::Finder; +use std::{sync::Arc, time::Instant}; + +impl Conn +where + Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, +{ + pub(crate) async fn send(mut self) -> Result> { + let mut output_buffer = + Vec::with_capacity(self.server_config.http_config.response_buffer_len); + self.write_headers(&mut output_buffer)?; + + let mut bufwriter = BufWriter::new_with_buffer(output_buffer, &mut self.transport); + + if self.method != Method::Head + && !matches!(self.status, Some(Status::NotModified | Status::NoContent)) + { + if let Some(body) = self.response_body.take() { + copy( + body, + &mut bufwriter, + self.server_config.http_config.copy_loops_per_yield, + ) + .await?; + } + } + + bufwriter.flush().await?; + self.after_send.call(true.into()); + self.finish().await + } + + pub(super) fn needs_100_continue(&self) -> bool { + self.request_body_state == ReceivedBodyState::Start + && self.version != Version::Http1_0 + && self + .request_headers + .eq_ignore_ascii_case(KnownHeaderName::Expect, "100-continue") + } + + #[allow(clippy::needless_borrow, clippy::needless_borrows_for_generic_args)] + pub(super) fn build_request_body(&mut self) -> ReceivedBody<'_, Transport> { + ReceivedBody::new_with_config( + self.request_content_length().ok().flatten(), + &mut self.buffer, + &mut self.transport, + &mut self.request_body_state, + None, + encoding(&self.request_headers), + &self.server_config.http_config, + ) + } + + fn validate_headers(request_headers: &Headers) -> Result<()> { + let content_length = request_headers.has_header(KnownHeaderName::ContentLength); + let transfer_encoding_chunked = + request_headers.eq_ignore_ascii_case(KnownHeaderName::TransferEncoding, "chunked"); + + if content_length && transfer_encoding_chunked { + Err(Error::UnexpectedHeader( + KnownHeaderName::ContentLength.into(), + )) + } else { + Ok(()) + } + } + + // /// # Create a new `Conn` + // /// + // /// This function creates a new conn from the provided + // /// [`Transport`][crate::transport::Transport], as well as any + // /// bytes that have already been read from the transport, and a + // /// [`Swansong`] instance that will be used to signal graceful + // /// shutdown. + // /// + // /// # Errors + // /// + // /// This will return an error variant if: + // /// + // /// * there is an io error when reading from the underlying transport + // /// * headers are too long + // /// * we are unable to parse some aspect of the request + // /// * the request is an unsupported http version + // /// * we cannot make sense of the headers, such as if there is a + // /// `content-length` header as well as a `transfer-encoding: chunked` + // /// header. + // pub async fn new(transport: Transport, bytes: Vec, swansong: Swansong) -> Result { + // Self::new_internal(DEFAULT_CONFIG, transport, bytes.into(), swansong, None).await + // } + + #[cfg(not(feature = "parse"))] + pub(crate) async fn new_internal( + server_config: Arc, + mut transport: Transport, + mut buffer: Buffer, + ) -> Result { + use crate::{HeaderName, HeaderValue}; + use httparse::{Request, EMPTY_HEADER}; + use std::str::FromStr; + + let (head_size, start_time) = + Self::head(&mut transport, &mut buffer, &server_config).await?; + + let mut headers = vec![EMPTY_HEADER; server_config.http_config.max_headers]; + let mut httparse_req = Request::new(&mut headers); + + let status = httparse_req.parse(&buffer[..]).map_err(|e| match e { + httparse::Error::HeaderName => Error::InvalidHeaderName, + httparse::Error::HeaderValue => Error::InvalidHeaderValue("unknown".into()), + httparse::Error::Status => Error::InvalidStatus, + httparse::Error::TooManyHeaders => Error::HeadersTooLong, + httparse::Error::Version => Error::InvalidVersion, + _ => Error::InvalidHead, + })?; + + if status.is_partial() { + return Err(Error::InvalidHead); + } + + let method = match httparse_req.method { + Some(method) => match method.parse() { + Ok(method) => method, + Err(_) => return Err(Error::UnrecognizedMethod(method.to_string())), + }, + None => return Err(Error::MissingMethod), + }; + + let version = match httparse_req.version { + Some(0) => Version::Http1_0, + Some(1) => Version::Http1_1, + _ => return Err(Error::InvalidVersion), + }; + + let mut request_headers = Headers::new(); + for header in httparse_req.headers { + let header_name = HeaderName::from_str(header.name)?; + let header_value = HeaderValue::from(header.value.to_owned()); + request_headers.append(header_name, header_value); + } + + Self::validate_headers(&request_headers)?; + + let path = httparse_req + .path + .ok_or(Error::RequestPathMissing)? + .to_owned(); + log::trace!("received:\n{method} {path} {version}\n{request_headers}"); + + let mut response_headers = Headers::new(); + response_headers.insert(KnownHeaderName::Server, SERVER); + + buffer.ignore_front(head_size); + + Ok(Self { + transport, + request_headers, + method, + version, + path, + buffer, + response_headers, + status: None, + state: TypeSet::new(), + response_body: None, + request_body_state: ReceivedBodyState::Start, + secure: false, + after_send: AfterSend::default(), + start_time, + peer_ip: None, + server_config, + }) + } + + #[cfg(feature = "parse")] + pub(crate) async fn new_internal( + server_config: Arc, + mut transport: Transport, + mut buffer: Buffer, + ) -> Result { + let (head_size, start_time) = + Self::head(&mut transport, &mut buffer, &server_config).await?; + + let first_line_index = Finder::new(b"\r\n") + .find(&buffer) + .ok_or(Error::InvalidHead)?; + + let mut spaces = memchr::memchr_iter(b' ', &buffer[..first_line_index]); + let first_space = spaces.next().ok_or(Error::MissingMethod)?; + let method = Method::parse(&buffer[0..first_space])?; + let second_space = spaces.next().ok_or(Error::RequestPathMissing)?; + let path = std::str::from_utf8(&buffer[first_space + 1..second_space]) + .map_err(|_| Error::RequestPathMissing)? + .to_string(); + if path.is_empty() { + return Err(Error::InvalidHead); + } + let version = Version::parse(&buffer[second_space + 1..first_line_index])?; + if !matches!(version, Version::Http1_1 | Version::Http1_0) { + return Err(Error::UnsupportedVersion(version)); + } + + let request_headers = Headers::parse(&buffer[first_line_index + 2..head_size])?; + + Self::validate_headers(&request_headers)?; + + let mut response_headers = Headers::new(); + response_headers.insert(KnownHeaderName::Server, SERVER); + + buffer.ignore_front(head_size); + + Ok(Self { + server_config, + transport, + request_headers, + method, + version, + path, + buffer, + response_headers, + status: None, + state: TypeSet::new(), + response_body: None, + request_body_state: ReceivedBodyState::Start, + secure: false, + after_send: AfterSend::default(), + start_time, + peer_ip: None, + }) + } + + pub(super) async fn send_100_continue(&mut self) -> Result<()> { + log::trace!("sending 100-continue"); + Ok(self + .transport + .write_all(b"HTTP/1.1 100 Continue\r\n\r\n") + .await?) + } + + async fn head( + transport: &mut Transport, + buf: &mut Buffer, + server_config: &ServerConfig, + ) -> Result<(usize, Instant)> { + let mut len = 0; + let mut start_with_read = buf.is_empty(); + let mut instant = None; + let finder = Finder::new(b"\r\n\r\n"); + loop { + if len >= server_config.http_config.head_max_len { + return Err(Error::HeadersTooLong); + } + + let bytes = if start_with_read { + buf.expand(); + if len == 0 { + server_config + .swansong + .interrupt(transport.read(buf)) + .await + .ok_or(Error::Closed)?? + } else { + transport.read(&mut buf[len..]).await? + } + } else { + start_with_read = true; + buf.len() + }; + + if instant.is_none() { + instant = Some(Instant::now()); + } + + let search_start = len.max(3) - 3; + let search = finder.find(&buf[search_start..]); + + if let Some(index) = search { + buf.truncate(len + bytes); + return Ok((search_start + index + 4, instant.unwrap())); + } + + len += bytes; + + if bytes == 0 { + return if len == 0 { + Err(Error::Closed) + } else { + Err(Error::InvalidHead) + }; + } + } + } + + async fn next(mut self) -> Result { + if !self.needs_100_continue() || self.request_body_state != ReceivedBodyState::Start { + self.build_request_body().drain().await?; + } + Conn::new_internal(self.server_config, self.transport, self.buffer).await + } + + fn should_close(&self) -> bool { + let request_connection = self.request_headers.get_lower(KnownHeaderName::Connection); + let response_connection = self.response_headers.get_lower(KnownHeaderName::Connection); + + match ( + request_connection.as_deref(), + response_connection.as_deref(), + ) { + (Some("keep-alive"), Some("keep-alive")) => false, + (Some("close"), _) | (_, Some("close")) => true, + _ => self.version == Version::Http1_0, + } + } + + fn should_upgrade(&self) -> bool { + (self.method() == Method::Connect && self.status == Some(Status::Ok)) + || self.status == Some(Status::SwitchingProtocols) + } + + async fn finish(self) -> Result> { + if self.should_close() { + Ok(ConnectionStatus::Close) + } else if self.should_upgrade() { + Ok(ConnectionStatus::Upgrade(self.into())) + } else { + match self.next().await { + Err(Error::Closed) => { + log::trace!("connection closed by client"); + Ok(ConnectionStatus::Close) + } + Err(e) => Err(e), + Ok(conn) => Ok(ConnectionStatus::Conn(conn)), + } + } + } + + fn request_content_length(&self) -> Result> { + if self + .request_headers + .eq_ignore_ascii_case(KnownHeaderName::TransferEncoding, "chunked") + { + Ok(None) + } else if let Some(cl) = self.request_headers.get_str(KnownHeaderName::ContentLength) { + cl.parse() + .map(Some) + .map_err(|_| Error::InvalidHeaderValue(KnownHeaderName::ContentLength.into())) + } else { + Ok(Some(0)) + } + } + + pub(super) fn body_len(&self) -> Option { + match self.response_body { + Some(ref body) => body.len(), + None => Some(0), + } + } + + fn write_headers(&mut self, output_buffer: &mut Vec) -> Result<()> { + use std::io::Write; + let status = self.status().unwrap_or(Status::NotFound); + + write!( + output_buffer, + "{} {} {}\r\n", + self.version, + status as u16, + status.canonical_reason() + )?; + + self.finalize_headers(); + + log::trace!( + "sending:\n{} {}\n{}", + self.version, + status, + &self.response_headers + ); + + for (name, values) in &self.response_headers { + if name.is_valid() { + for value in values { + if value.is_valid() { + write!(output_buffer, "{name}: ")?; + output_buffer.extend_from_slice(value.as_ref()); + write!(output_buffer, "\r\n")?; + } else { + log::error!("skipping invalid header value {value:?} for header {name}"); + } + } + } else { + log::error!("skipping invalid header with name {name:?}"); + } + } + + write!(output_buffer, "\r\n")?; + Ok(()) + } +} diff --git a/http/src/lib.rs b/http/src/lib.rs index 0cd03dd534..44620061c8 100644 --- a/http/src/lib.rs +++ b/http/src/lib.rs @@ -25,47 +25,53 @@ capabilities. Please note that trillium itself provides a much more usable interface on top of `trillium_http`, at very little cost. ``` -# fn main() -> trillium_http::Result<()> { smol::block_on(async { -use async_net::{TcpListener, TcpStream}; -use futures_lite::StreamExt; -use trillium_http::{Conn, Result, Swansong}; - -let swansong = Swansong::new(); -let listener = TcpListener::bind(("localhost", 0)).await?; -let port = listener.local_addr()?.port(); - -let server_swansong = swansong.clone(); -let server_handle = smol::spawn(async move { - let mut incoming = server_swansong.interrupt(listener.incoming()); - - while let Some(Ok(stream)) = incoming.next().await { - let swansong = server_swansong.clone(); - smol::spawn(Conn::map(stream, swansong, |mut conn: Conn| async move { - conn.set_response_body("hello world"); - conn.set_status(200); - conn - })).detach() - } - - Result::Ok(()) -}); - -// this example uses the trillium client -// any other http client would work here too -let url = format!("http://localhost:{}/", port); -let client = trillium_client::Client::new(trillium_smol::ClientConfig::default()); -let mut client_conn = client.get(&*url).await?; - -assert_eq!(client_conn.status().unwrap(), 200); -assert_eq!(client_conn.response_headers().get_str("content-length"), Some("11")); -assert_eq!( - client_conn.response_body().read_string().await?, - "hello world" -); - -swansong.shut_down(); // stop the server after one request -server_handle.await?; // wait for the server to shut down -# Result::Ok(()) }) } +fn main() -> trillium_http::Result<()> { + smol::block_on(async { + use async_net::TcpListener; + use futures_lite::StreamExt; + use std::sync::Arc; + use trillium_http::ServerConfig; + + let server_config = Arc::new(ServerConfig::default()); + let listener = TcpListener::bind(("localhost", 0)).await?; + let local_addr = listener.local_addr().unwrap(); + let server_handle = smol::spawn({ + let server_config = server_config.clone(); + async move { + let mut incoming = server_config.swansong().interrupt(listener.incoming()); + + while let Some(Ok(stream)) = incoming.next().await { + smol::spawn(server_config.clone().run(stream, |mut conn| async move { + conn.set_response_body("hello world"); + conn.set_status(200); + conn + })) + .detach() + } + } + }); + + // this example uses the trillium client + // any other http client would work here too + let client = trillium_client::Client::new(trillium_smol::ClientConfig::default()) + .with_base(local_addr); + let mut client_conn = client.get("/").await?; + + assert_eq!(client_conn.status().unwrap(), 200); + assert_eq!( + client_conn.response_headers().get_str("content-length"), + Some("11") + ); + assert_eq!( + client_conn.response_body().read_string().await?, + "hello world" + ); + + server_config.shut_down().await; // stop the server after one request + server_handle.await; // wait for the server to shut down + Ok(()) + }) +} ``` */ @@ -156,3 +162,5 @@ pub use copy::copy; pub(crate) use copy::copy; mod liveness; +mod server_config; +pub use server_config::ServerConfig; diff --git a/http/src/received_body.rs b/http/src/received_body.rs index bf315d6aac..c9b00cec50 100644 --- a/http/src/received_body.rs +++ b/http/src/received_body.rs @@ -369,7 +369,7 @@ impl<'conn, Transport> Debug for ReceivedBody<'conn, Transport> { f.debug_struct("RequestBody") .field("state", &*self.state) .field("content_length", &self.content_length) - .field("buffer", &"..") + .field("buffer", &format_args!("..")) .field("on_completion", &self.on_completion.is_some()) .finish() } diff --git a/http/src/received_body/chunked.rs b/http/src/received_body/chunked.rs index 02c70eea83..545bd31770 100644 --- a/http/src/received_body/chunked.rs +++ b/http/src/received_body/chunked.rs @@ -172,7 +172,8 @@ mod tests { use crate::{http_config::DEFAULT_CONFIG, Buffer, HttpConfig}; use encoding_rs::UTF_8; use futures_lite::{io::Cursor, AsyncRead, AsyncReadExt}; - use trillium_testing::block_on; + use test_harness::test; + use trillium_testing::harness; #[track_caller] fn assert_decoded( @@ -245,22 +246,20 @@ mod tests { decode_with_config(input, poll_size, &DEFAULT_CONFIG).await } - #[test] - fn test_full_decode() { - block_on(async { - for size in 1..50 { - let input = "5\r\n12345\r\n1\r\na\r\n2\r\nbc\r\n3\r\ndef\r\n0\r\n"; - let output = decode(input.into(), size).await.unwrap(); - assert_eq!(output, "12345abcdef", "size: {size}"); - - let input = "7\r\nMozilla\r\n9\r\nDeveloper\r\n7\r\nNetwork\r\n0\r\n\r\n"; - let output = decode(input.into(), size).await.unwrap(); - assert_eq!(output, "MozillaDeveloperNetwork", "size: {size}"); - - assert!(decode(String::new(), size).await.is_err()); - assert!(decode("fffffffffffffff0\r\n".into(), size).await.is_err()); - } - }); + #[test(harness)] + async fn test_full_decode() { + for size in 1..50 { + let input = "5\r\n12345\r\n1\r\na\r\n2\r\nbc\r\n3\r\ndef\r\n0\r\n"; + let output = decode(input.into(), size).await.unwrap(); + assert_eq!(output, "12345abcdef", "size: {size}"); + + let input = "7\r\nMozilla\r\n9\r\nDeveloper\r\n7\r\nNetwork\r\n0\r\n\r\n"; + let output = decode(input.into(), size).await.unwrap(); + assert_eq!(output, "MozillaDeveloperNetwork", "size: {size}"); + + assert!(decode(String::new(), size).await.is_err()); + assert!(decode("fffffffffffffff0\r\n".into(), size).await.is_err()); + } } async fn build_chunked_body(input: String) -> String { @@ -277,47 +276,44 @@ mod tests { String::from_utf8(output).unwrap() } - #[test] - fn test_read_buffer_short() { - block_on(async { - let input = "test ".repeat(50); - let chunked = build_chunked_body(input.clone()).await; - - for size in 1..10 { - assert_eq!( - &decode(chunked.clone(), size).await.unwrap(), - &input, - "size: {size}" - ); - } - }); + #[test(harness)] + async fn test_read_buffer_short() { + let input = "test ".repeat(50); + let chunked = build_chunked_body(input.clone()).await; + + for size in 1..10 { + assert_eq!( + &decode(chunked.clone(), size).await.unwrap(), + &input, + "size: {size}" + ); + } } - #[test] - fn test_max_len() { - block_on(async { - let input = build_chunked_body("test ".repeat(10)).await; - - for size in 4..10 { - assert!(decode_with_config( - input.clone(), - size, - &HttpConfig::default().with_received_body_max_len(5) - ) - .await - .is_err()); + #[test(harness)] + async fn test_max_len() { + let input = build_chunked_body("test ".repeat(10)).await; - assert!( - decode_with_config(input.clone(), size, &HttpConfig::default()) - .await - .is_ok() - ); - } - }); + for size in 4..10 { + assert!(decode_with_config( + input.clone(), + size, + &HttpConfig::default().with_received_body_max_len(5) + ) + .await + .is_err()); + + assert!( + decode_with_config(input.clone(), size, &HttpConfig::default()) + .await + .is_ok() + ); + } } #[test] fn test_chunk_start() { + let _ = env_logger::builder().is_test(true).try_init(); assert_decoded((0, "5\r\n12345\r\n"), (Some(0), "12345", "")); assert_decoded((0, "F\r\n1"), (Some(14 + 2), "1", "")); assert_decoded((0, "5\r\n123"), (Some(2 + 2), "123", "")); @@ -342,6 +338,8 @@ mod tests { #[test] fn test_chunk_start_with_ext() { + let _ = env_logger::builder().is_test(true).try_init(); + assert_decoded((0, "5;abcdefg\r\n12345\r\n"), (Some(0), "12345", "")); assert_decoded((0, "F;aaa\taaaaa\taaa aaa\r\n1"), (Some(14 + 2), "1", "")); assert_decoded((0, "5;;;;;;;;;;;;;;;;\r\n123"), (Some(2 + 2), "123", "")); @@ -367,55 +365,53 @@ mod tests { assert_decoded((7, "hello\r\n0;\r\n\r\n"), (None, "hello", "")); } - #[test] - fn read_string_and_read_bytes() { - block_on(async { - let content = build_chunked_body("test ".repeat(100)).await; - assert_eq!( - new_with_config(content.clone(), &DEFAULT_CONFIG) - .read_string() - .await - .unwrap() - .len(), - 500 - ); + #[test(harness)] + async fn read_string_and_read_bytes() { + let content = build_chunked_body("test ".repeat(100)).await; + assert_eq!( + new_with_config(content.clone(), &DEFAULT_CONFIG) + .read_string() + .await + .unwrap() + .len(), + 500 + ); - assert_eq!( - new_with_config(content.clone(), &DEFAULT_CONFIG) - .read_bytes() - .await - .unwrap() - .len(), - 500 - ); + assert_eq!( + new_with_config(content.clone(), &DEFAULT_CONFIG) + .read_bytes() + .await + .unwrap() + .len(), + 500 + ); - assert!(new_with_config( - content.clone(), - &DEFAULT_CONFIG.with_received_body_max_len(400) - ) - .read_string() - .await - .is_err()); + assert!(new_with_config( + content.clone(), + &DEFAULT_CONFIG.with_received_body_max_len(400) + ) + .read_string() + .await + .is_err()); - assert!(new_with_config( - content.clone(), - &DEFAULT_CONFIG.with_received_body_max_len(400) - ) + assert!(new_with_config( + content.clone(), + &DEFAULT_CONFIG.with_received_body_max_len(400) + ) + .read_bytes() + .await + .is_err()); + + assert!(new_with_config(content.clone(), &DEFAULT_CONFIG) + .with_max_len(400) .read_bytes() .await .is_err()); - assert!(new_with_config(content.clone(), &DEFAULT_CONFIG) - .with_max_len(400) - .read_bytes() - .await - .is_err()); - - assert!(new_with_config(content.clone(), &DEFAULT_CONFIG) - .with_max_len(400) - .read_string() - .await - .is_err()); - }); + assert!(new_with_config(content.clone(), &DEFAULT_CONFIG) + .with_max_len(400) + .read_string() + .await + .is_err()); } } diff --git a/http/src/server_config.rs b/http/src/server_config.rs new file mode 100644 index 0000000000..744a4421f9 --- /dev/null +++ b/http/src/server_config.rs @@ -0,0 +1,117 @@ +use crate::{Conn, ConnectionStatus, HttpConfig, Result, TypeSet, Upgrade}; +use futures_lite::{AsyncRead, AsyncWrite}; +use std::{future::Future, sync::Arc}; +use swansong::{ShutdownCompletion, Swansong}; +/// This struct represents the shared configuration and context for a http server. +/// +/// This currently contains tunable parameters in a [`HttpConfig`], the [`Swansong`] graceful +/// shutdown control interface, and a shared [`TypeSet`] that contains application-specific +/// information about the running server +#[derive(Default, Debug)] +pub struct ServerConfig { + pub(crate) http_config: HttpConfig, + pub(crate) swansong: Swansong, + pub(crate) shared_state: TypeSet, +} +impl AsRef for ServerConfig { + fn as_ref(&self) -> &TypeSet { + &self.shared_state + } +} + +impl AsMut for ServerConfig { + fn as_mut(&mut self) -> &mut TypeSet { + &mut self.shared_state + } +} + +impl AsRef for ServerConfig { + fn as_ref(&self) -> &Swansong { + &self.swansong + } +} + +impl AsRef for ServerConfig { + fn as_ref(&self) -> &HttpConfig { + &self.http_config + } +} + +impl ServerConfig { + /// Modify the [`HttpConfig`] for this server. + pub fn http_config_mut(&mut self) -> &mut HttpConfig { + &mut self.http_config + } + + /// Replace the [`Swansong`] graceful shutdown control interface for this server. + pub fn set_swansong(&mut self, swansong: Swansong) { + self.swansong = swansong; + } + + /// Borrow the [`Swansong`] graceful shutdown control interface for this server. + pub fn swansong(&self) -> &Swansong { + &self.swansong + } + + /// Construct a new `ServerConfig` + pub fn new() -> Self { + Self::default() + } + + /// Borrow the shared state [`TypeSet`] for this server + pub fn shared_state(&self) -> &TypeSet { + &self.shared_state + } + + /// Mutate the shared state [`TypeSet`] for this server. + /// + /// Types added here will be immutably available on all [`Conn`]s handled by this server. + pub fn shared_state_mut(&mut self) -> &mut TypeSet { + &mut self.shared_state + } + + /// Perform HTTP on the provided transport, applying the provided `async Conn -> Conn` handler + /// function for every distinct http request-response. + /// + /// For any given invocation of `ServerConfig::run`, the handler function may run any number of times, + /// depending on whether the connection is reused by the client. + /// + /// This can only be called on an `Arc` because an arc clone is moved into the Conn. + /// + /// # Errors + /// + /// This function will return an [`Error`] if any of the http requests is irrecoverably + /// malformed or otherwise noncompliant. + pub async fn run( + self: Arc, + transport: Transport, + mut handler: Handler, + ) -> Result>> + where + Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + Handler: FnMut(Conn) -> Fut, + Fut: Future>, + { + let _guard = self.swansong.guard(); + let buffer = Vec::with_capacity(self.http_config.request_buffer_initial_len).into(); + + let mut conn = Conn::new_internal(self, transport, buffer).await?; + + loop { + conn = match handler(conn).await.send().await? { + ConnectionStatus::Upgrade(upgrade) => return Ok(Some(upgrade)), + ConnectionStatus::Close => return Ok(None), + ConnectionStatus::Conn(next) => next, + } + } + } + + /// Attempt graceful shutdown of this server. + /// + /// The returned [`ShutdownCompletion`] type can + /// either be awaited in an async context or blocked on with [`ShutdownCompletion::block`] in a + /// blocking context + pub fn shut_down(&self) -> ShutdownCompletion { + self.swansong.shut_down() + } +} diff --git a/http/src/state_set/entry.rs b/http/src/state_set/entry.rs new file mode 100644 index 0000000000..c1f8f17efe --- /dev/null +++ b/http/src/state_set/entry.rs @@ -0,0 +1,93 @@ +use super::Value; +use std::{any::TypeId, collections::btree_map, marker::PhantomData}; + +pub enum Entry<'a, T> { + Vacant(VacantEntry<'a, T>), + Occupied(OccupiedEntry<'a, T>), +} + +pub struct VacantEntry<'a, T>( + pub(super) btree_map::VacantEntry<'a, TypeId, Value>, + PhantomData, +); +pub struct OccupiedEntry<'a, T>( + pub(super) btree_map::OccupiedEntry<'a, TypeId, Value>, + PhantomData, +); + +impl<'a, T: Send + Sync + 'static> Entry<'a, T> { + pub fn or_insert(self, default: T) -> &'a mut T { + match self { + Entry::Vacant(vacant) => vacant.insert(default), + Entry::Occupied(occupied) => occupied.into_mut(), + } + } + + pub fn or_insert_with(self, default: impl FnOnce() -> T) -> &'a mut T { + match self { + Entry::Vacant(vacant) => vacant.insert(default()), + Entry::Occupied(occupied) => occupied.into_mut(), + } + } + + pub fn and_modify(self, f: impl FnOnce(&mut T)) -> Self { + match self { + Entry::Vacant(vacant) => Entry::Vacant(vacant), + Entry::Occupied(mut occupied) => { + f(occupied.get_mut()); + Entry::Occupied(occupied) + } + } + } + + pub fn take(self) -> Option { + match self { + Entry::Vacant(_) => None, + Entry::Occupied(occupied) => Some(occupied.remove()), + } + } + + pub(super) fn new(entry: btree_map::Entry<'a, TypeId, Value>) -> Self { + match entry { + btree_map::Entry::Vacant(vacant) => Self::Vacant(VacantEntry(vacant, PhantomData)), + btree_map::Entry::Occupied(occupied) => { + Self::Occupied(OccupiedEntry(occupied, PhantomData)) + } + } + } +} + +impl<'a, T: Default + Send + Sync + 'static> Entry<'a, T> { + pub fn or_default(self) -> &'a mut T { + #[allow(clippy::unwrap_or_default)] // this is the implementation of or_default + self.or_insert_with(T::default) + } +} + +impl<'a, T: Send + Sync + 'static> VacantEntry<'a, T> { + pub fn insert(self, value: T) -> &'a mut T { + self.0.insert(Box::new(value)).downcast_mut().unwrap() + } +} + +impl<'a, T: Send + Sync + 'static> OccupiedEntry<'a, T> { + pub fn get(&self) -> &T { + self.0.get().downcast_ref().unwrap() + } + + pub fn get_mut(&mut self) -> &mut T { + self.0.get_mut().downcast_mut().unwrap() + } + + pub fn insert(&mut self, value: T) -> T { + *self.0.insert(Box::new(value)).downcast().unwrap() + } + + pub fn remove(self) -> T { + *self.0.remove().downcast().unwrap() + } + + pub fn into_mut(self) -> &'a mut T { + self.0.into_mut().downcast_mut().unwrap() + } +} diff --git a/http/src/synthetic.rs b/http/src/synthetic.rs index 68d7432699..03cefee426 100644 --- a/http/src/synthetic.rs +++ b/http/src/synthetic.rs @@ -1,10 +1,11 @@ use crate::{ after_send::AfterSend, http_config::DEFAULT_CONFIG, received_body::ReceivedBodyState, - transport::Transport, Conn, Headers, KnownHeaderName, Method, Swansong, TypeSet, Version, + transport::Transport, Conn, Headers, KnownHeaderName, Method, ServerConfig, TypeSet, Version, }; use futures_lite::io::{AsyncRead, AsyncWrite, Cursor, Result}; use std::{ pin::Pin, + sync::Arc, task::{Context, Poll}, time::Instant, }; @@ -138,6 +139,7 @@ impl Conn { request_headers.insert(KnownHeaderName::ContentLength, transport.len().to_string()); Self { + server_config: Arc::default(), transport, request_headers, response_headers: Headers::new(), @@ -150,15 +152,24 @@ impl Conn { buffer: Vec::with_capacity(DEFAULT_CONFIG.request_buffer_initial_len).into(), request_body_state: ReceivedBodyState::Start, secure: false, - swansong: Swansong::new(), after_send: AfterSend::default(), start_time: Instant::now(), peer_ip: None, - http_config: DEFAULT_CONFIG, - shared_state: None, } } + /// use a particular shared server config for this synthetic conn + pub fn set_server_config(&mut self, server_config: Arc) { + self.server_config = server_config; + } + + /// chainable setter for server config + #[must_use] + pub fn with_server_config(mut self, server_config: Arc) -> Self { + self.set_server_config(server_config); + self + } + /// simulate closing the transport pub fn close(&mut self) { self.transport.close(); diff --git a/http/src/upgrade.rs b/http/src/upgrade.rs index e2fb21d3a6..dafb3b7f4b 100644 --- a/http/src/upgrade.rs +++ b/http/src/upgrade.rs @@ -1,4 +1,4 @@ -use crate::{received_body::read_buffered, Buffer, Conn, Headers, Method, Swansong, TypeSet}; +use crate::{received_body::read_buffered, Buffer, Conn, Headers, Method, ServerConfig, TypeSet}; use futures_lite::{AsyncRead, AsyncWrite}; use std::{ fmt::{self, Debug, Formatter}, @@ -6,6 +6,7 @@ use std::{ net::IpAddr, pin::Pin, str, + sync::Arc, task::{Context, Poll}, }; use trillium_macros::AsyncWrite; @@ -39,10 +40,8 @@ pub struct Upgrade { /// already. It is your responsibility to process these bytes /// before reading directly from the transport. pub buffer: Buffer, - /// A [`Swansong`] which can and should be used to gracefully shut - /// down any long running streams or futures associated with this - /// upgrade - pub swansong: Swansong, + /// The [`ServerConfig`] shared for this server + pub server_config: Arc, /// the ip address of the connection, if available pub peer_ip: Option, } @@ -63,7 +62,7 @@ impl Upgrade { transport, buffer, state: TypeSet::new(), - swansong: Swansong::new(), + server_config: Arc::default(), peer_ip: None, } } @@ -114,7 +113,7 @@ impl Upgrade { state: self.state, buffer: self.buffer, request_headers: self.request_headers, - swansong: self.swansong, + server_config: self.server_config, peer_ip: self.peer_ip, } } @@ -127,9 +126,9 @@ impl Debug for Upgrade { .field("path", &self.path) .field("method", &self.method) .field("buffer", &self.buffer) - .field("swansong", &self.swansong) + .field("server_config", &self.server_config) .field("state", &self.state) - .field("transport", &"..") + .field("transport", &format_args!("..")) .field("peer_ip", &self.peer_ip) .finish() } @@ -144,7 +143,7 @@ impl From> for Upgrade { state, transport, buffer, - swansong, + server_config, peer_ip, .. } = conn; @@ -156,7 +155,7 @@ impl From> for Upgrade { state, transport, buffer, - swansong, + server_config, peer_ip, } } diff --git a/http/tests/corpus.rs b/http/tests/corpus.rs index f657bc1469..e3bbe68e62 100644 --- a/http/tests/corpus.rs +++ b/http/tests/corpus.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use indoc::formatdoc; use pretty_assertions::assert_str_eq; use std::{env, net::Shutdown, path::PathBuf}; use test_harness::test; -use trillium_http::{Conn, KnownHeaderName, Swansong}; +use trillium_http::{Conn, KnownHeaderName, ServerConfig, Swansong}; use trillium_testing::{harness, RuntimeTrait, TestTransport}; const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT"; @@ -44,7 +46,6 @@ async fn handler(mut conn: Conn) -> Conn { #[test(harness)] async fn corpus_test() { - env_logger::init(); let runtime = trillium_testing::runtime(); let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/corpus"); let filter = env::var("CORPUS_TEST_FILTER").unwrap_or_default(); @@ -69,9 +70,10 @@ async fn corpus_test() { let (client, server) = TestTransport::new(); let swansong = Swansong::new(); + let server_config = Arc::new(ServerConfig::new()); let res = runtime.spawn({ - let swansong = swansong.clone(); - async move { Conn::map(server, swansong, handler).await } + let server_config = server_config.clone(); + async move { server_config.run(server, handler).await } }); client.write_all(request); diff --git a/http/tests/one_hundred_continue.rs b/http/tests/one_hundred_continue.rs index 87f8bb869f..b49baa90ef 100644 --- a/http/tests/one_hundred_continue.rs +++ b/http/tests/one_hundred_continue.rs @@ -1,7 +1,9 @@ +use std::sync::Arc; + use indoc::{formatdoc, indoc}; use pretty_assertions::assert_eq; use test_harness::test; -use trillium_http::{Conn, KnownHeaderName, Swansong, SERVER}; +use trillium_http::{Conn, KnownHeaderName, ServerConfig, SERVER}; use trillium_testing::{harness, RuntimeTrait, TestResult, TestTransport}; const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT"; @@ -21,7 +23,8 @@ async fn handler(mut conn: Conn) -> Conn { async fn one_hundred_continue() -> TestResult { let (client, server) = TestTransport::new(); let runtime = trillium_testing::runtime(); - let handle = runtime.spawn(async move { Conn::map(server, Swansong::new(), handler).await }); + let server_config = Arc::new(ServerConfig::default()); + let handle = runtime.spawn(server_config.run(server, handler)); client.write_all(indoc! {" POST / HTTP/1.1\r @@ -57,7 +60,8 @@ async fn one_hundred_continue() -> TestResult { async fn one_hundred_continue_http_one_dot_zero() -> TestResult { let (client, server) = TestTransport::new(); let runtime = trillium_testing::runtime(); - let handle = runtime.spawn(async move { Conn::map(server, Swansong::new(), handler).await }); + let server_config = Arc::new(ServerConfig::default()); + let handle = runtime.spawn(server_config.run(server, handler)); client.write_all(indoc! { " POST / HTTP/1.0\r diff --git a/http/tests/unsafe_headers.rs b/http/tests/unsafe_headers.rs index b4c3187e5f..272f8910c8 100644 --- a/http/tests/unsafe_headers.rs +++ b/http/tests/unsafe_headers.rs @@ -1,8 +1,9 @@ +use std::sync::Arc; + use indoc::{formatdoc, indoc}; use pretty_assertions::assert_eq; -use swansong::Swansong; use test_harness::test; -use trillium_http::{Conn, KnownHeaderName, SERVER}; +use trillium_http::{Conn, KnownHeaderName, ServerConfig, SERVER}; use trillium_testing::{harness, RuntimeTrait, TestResult, TestTransport}; const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT"; @@ -24,7 +25,8 @@ async fn handler(mut conn: Conn) -> Conn { async fn bad_headers() -> TestResult { let (client, server) = TestTransport::new(); let runtime = trillium_testing::runtime(); - let handle = runtime.spawn(async move { Conn::map(server, Swansong::new(), handler).await }); + let server_config = Arc::new(ServerConfig::default()); + let handle = runtime.spawn(server_config.run(server, handler)); client.write_all(indoc! {" GET / HTTP/1.1\r diff --git a/http/tests/use_cases.rs b/http/tests/use_cases.rs index 8ef194ed3d..c60174f23e 100644 --- a/http/tests/use_cases.rs +++ b/http/tests/use_cases.rs @@ -3,7 +3,7 @@ use std::{future::Future, marker::PhantomData, sync::Arc}; use test_harness::test; use trillium_client::{Client, Connector, Url}; -use trillium_http::{Conn, KnownHeaderName}; +use trillium_http::{Conn, KnownHeaderName, ServerConfig}; use trillium_testing::{harness, Runtime, TestResult, TestTransport}; #[test(harness)] @@ -22,6 +22,7 @@ pub struct ServerConnector { handler: Arc, fut: PhantomData, runtime: Runtime, + server_config: Arc, } impl ServerConnector @@ -33,6 +34,7 @@ where Self { handler: Arc::new(handler), fut: PhantomData, + server_config: ServerConfig::default().into(), runtime: trillium_testing::runtime().into(), } } @@ -50,12 +52,10 @@ where let (client_transport, server_transport) = TestTransport::new(); let handler = self.handler.clone(); + let server_config = self.server_config.clone(); - self.runtime.spawn(async move { - Conn::map(server_transport, Default::default(), &*handler) - .await - .unwrap(); - }); + self.runtime + .spawn(async move { server_config.run(server_transport, &*handler).await }); Ok(client_transport) } diff --git a/logger/Cargo.toml b/logger/Cargo.toml index 586f9eefb6..a11c98f999 100644 --- a/logger/Cargo.toml +++ b/logger/Cargo.toml @@ -16,6 +16,7 @@ log = "0.4.20" size = "0.4.1" time = { version = "0.3.31", features = ["local-offset", "formatting", "macros"] } trillium = { path = "../trillium", version = "0.2.20" } +url = "2.5.0" [dev-dependencies] access_log_parser = "0.8.0" diff --git a/logger/src/lib.rs b/logger/src/lib.rs index cac361555c..4be04508a9 100644 --- a/logger/src/lib.rs +++ b/logger/src/lib.rs @@ -10,7 +10,11 @@ Welcome to the trillium logger! */ pub use crate::formatters::{apache_combined, apache_common, dev_formatter}; -use std::{fmt::Display, io::IsTerminal, sync::Arc}; +use std::{ + fmt::{Display, Write}, + io::IsTerminal, + sync::Arc, +}; use trillium::{Conn, Handler, Info}; /** Components with which common log formats can be constructed @@ -277,19 +281,25 @@ where F: LogFormatter, { async fn init(&mut self, info: &mut Info) { - self.target.write(format!( - " -🌱🦀🌱 {} started -Listening at {}{} - -Control-C to quit", - info.server_description(), - info.listener_description(), - info.tcp_socket_addr() - .map(|s| format!(" (bound as tcp://{s})")) - .unwrap_or_default(), - )); + let mut string = "\n🌱🦀🌱 trillium started\n".to_string(); + + if let Some(url) = info.state::() { + writeln!(string, "Listening at {}", url.as_str()).unwrap(); + } + + if let Some(tcp) = info.tcp_socket_addr() { + writeln!(string, "Bound as tcp://{tcp}").unwrap(); + } + + if let Some(unix) = info.unix_socket_addr().and_then(|unix| unix.as_pathname()) { + writeln!(string, "Bound as unix://{}", unix.display()).unwrap(); + } + + writeln!(string, "Control-C to quit").unwrap(); + + self.target.write(string); } + async fn run(&self, conn: Conn) -> Conn { conn.with_state(LoggerWasRun) } diff --git a/macros/tests/derive.rs b/macros/tests/derive.rs index 7063cf844f..aa95fed3c0 100644 --- a/macros/tests/derive.rs +++ b/macros/tests/derive.rs @@ -20,7 +20,7 @@ fn full_lifecycle() { async fn init(&mut self, info: &mut Info) { self.init = true; - *info.server_description_mut() = "inner handler took over".into(); + info.insert_state("inner handler took over"); } async fn before_send(&self, conn: Conn) -> Conn { @@ -40,7 +40,7 @@ fn full_lifecycle() { let mut handler = OuterHandler(InnerHandler { init: false }); handler.init(&mut info).await; - assert_eq!(info.server_description(), "inner handler took over"); + assert_eq!(info.state::<&str>().unwrap(), &"inner handler took over"); assert!(handler.0.init); assert_ok!(get("/").run_async(&handler).await, "run", "before-send" => "before-send"); assert_eq!(handler.name(), "OuterHandler (inner handler)"); diff --git a/native-tls/src/client.rs b/native-tls/src/client.rs index 66ecfb03dc..e923792876 100644 --- a/native-tls/src/client.rs +++ b/native-tls/src/client.rs @@ -47,7 +47,7 @@ impl Debug for NativeTlsConfig { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("NativeTlsConfig") .field("tcp_config", &self.tcp_config) - .field("tls_connector", &"..") + .field("tls_connector", &format_args!("..")) .finish() } } diff --git a/rustls/src/client.rs b/rustls/src/client.rs index 788d038ad5..8a3228070d 100644 --- a/rustls/src/client.rs +++ b/rustls/src/client.rs @@ -99,7 +99,7 @@ impl RustlsConfig { impl Debug for RustlsConfig { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("RustlsConfig") - .field("rustls_config", &"..") + .field("rustls_config", &format_args!("..")) .field("tcp_config", &self.tcp_config) .finish() } diff --git a/server-common/Cargo.toml b/server-common/Cargo.toml index d97e0154de..515cf068a3 100644 --- a/server-common/Cargo.toml +++ b/server-common/Cargo.toml @@ -14,6 +14,7 @@ categories = ["web-programming::http-server", "web-programming"] async-channel = "2.2.0" async_cell = "0.2.2" futures-lite = "2.1.0" +listenfd = "1.0.1" log = "0.4.20" pin-project-lite = "0.2.13" swansong = "0.3.0" diff --git a/server-common/src/acceptor.rs b/server-common/src/acceptor.rs index b68a6034d1..34de415aa8 100644 --- a/server-common/src/acceptor.rs +++ b/server-common/src/acceptor.rs @@ -31,6 +31,11 @@ where &self, input: Input, ) -> impl Future> + Send; + + /// should conns be treated as secure? + fn is_secure(&self) -> bool { + true + } } impl Acceptor for () @@ -42,4 +47,8 @@ where async fn accept(&self, input: Input) -> Result { Ok(input) } + + fn is_secure(&self) -> bool { + false + } } diff --git a/server-common/src/config.rs b/server-common/src/config.rs index 4de893e268..4deec9f816 100644 --- a/server-common/src/config.rs +++ b/server-common/src/config.rs @@ -1,12 +1,11 @@ -use crate::{Acceptor, RuntimeTrait, Server, ServerHandle}; +use crate::{running_config::RunningConfig, Acceptor, RuntimeTrait, Server, ServerHandle}; use async_cell::sync::AsyncCell; -use std::{ - cell::OnceCell, - marker::PhantomData, - net::SocketAddr, - sync::{Arc, RwLock}, -}; -use trillium::{Handler, HttpConfig, Info, Swansong}; +use futures_lite::StreamExt; +use url::Url; + +use std::{cell::OnceCell, net::SocketAddr, pin::pin, sync::Arc}; +use trillium::{Handler, HttpConfig, Info, Swansong, TypeSet}; +use trillium_http::ServerConfig; /** # Primary entrypoint for configuring and running a trillium server @@ -60,17 +59,15 @@ In order to use this to _implement_ a trillium server, see #[derive(Debug)] pub struct Config { pub(crate) acceptor: AcceptorType, - pub(crate) port: Option, + pub(crate) binding: Option, pub(crate) host: Option, + pub(crate) server_config_cell: Arc>>, + pub(crate) max_connections: Option, pub(crate) nodelay: bool, - pub(crate) swansong: Swansong, + pub(crate) port: Option, pub(crate) register_signals: bool, - pub(crate) max_connections: Option, - pub(crate) info: Arc>>, - pub(crate) binding: RwLock>, - pub(crate) server: PhantomData, - pub(crate) http_config: HttpConfig, pub(crate) runtime: ServerType::Runtime, + pub(crate) server_config: ServerConfig, } impl Config @@ -84,8 +81,8 @@ where /// outside of trillium's web server. For applications that embed a /// trillium server inside of an already-running async runtime, use /// [`Config::run_async`] - pub fn run(self, h: H) { - ServerType::run(self, h) + pub fn run(self, handler: impl Handler) { + self.runtime.clone().block_on(self.run_async(handler)); } /// Runs the provided handler with this config, in an @@ -93,10 +90,74 @@ where /// for an application that needs to spawn async tasks that are /// unrelated to the trillium application. If you do not need to spawn /// other tasks, [`Config::run`] is the preferred entrypoint - pub async fn run_async(self, handler: impl Handler) { - let swansong = self.swansong.clone(); - ServerType::run_async(self, handler).await; - swansong.shut_down().await; + pub async fn run_async(self, mut handler: impl Handler) { + let Self { + runtime, + acceptor, + max_connections, + nodelay, + binding, + host, + port, + register_signals, + server_config, + server_config_cell, + } = self; + let host = host + .or_else(|| std::env::var("HOST").ok()) + .unwrap_or_else(|| "localhost".into()); + let port = port + .or_else(|| { + std::env::var("PORT") + .ok() + .map(|x| x.parse().expect("PORT must be an unsigned integer")) + }) + .unwrap_or(8080); + + let listener = binding + .inspect(|_| log::debug!("taking prebound listener")) + .unwrap_or_else(|| ServerType::from_host_and_port(&host, port)); + + let swansong = server_config.swansong().clone(); + + if register_signals { + runtime.spawn({ + let runtime = runtime.clone(); + let swansong = swansong.clone(); + async move { + let mut signals = pin!(runtime.hook_signals([2, 3, 15])); + while signals.next().await.is_some() { + if swansong.state().is_shutting_down() { + eprintln!("\nSecond interrupt, shutting down harshly"); + std::process::exit(1); + } else { + println!("\nShutting down gracefully.\nControl-C again to force."); + swansong.shut_down(); + } + } + } + }); + } + + let mut info = Info::from(server_config) + .with_state(runtime.clone().into()) + .with_state(runtime.clone()); + listener.init(&mut info); + insert_url(info.as_mut(), acceptor.is_secure()); + handler.init(&mut info).await; + + let server_config = Arc::new(ServerConfig::from(info)); + server_config_cell.set(server_config.clone()); + + let running_config = Arc::new(RunningConfig { + acceptor, + max_connections, + server_config, + runtime, + nodelay, + }); + + running_config.run_async(listener, handler).await; } /// Spawns the server onto the async runtime, returning a @@ -113,9 +174,9 @@ where /// when spawning the server onto a runtime. pub fn handle(&self) -> ServerHandle { ServerHandle { - swansong: self.swansong.clone(), - info: self.info.clone(), - received_info: OnceCell::new(), + swansong: self.server_config.swansong().clone(), + server_config: self.server_config_cell.clone(), + received_server_config: OnceCell::new(), runtime: self.runtime().into(), } } @@ -176,20 +237,18 @@ where host: self.host, port: self.port, nodelay: self.nodelay, - server: PhantomData, - swansong: self.swansong, register_signals: self.register_signals, max_connections: self.max_connections, - info: self.info, + server_config_cell: self.server_config_cell, + server_config: self.server_config, binding: self.binding, - http_config: self.http_config, runtime: self.runtime, } } /// use the specific [`Swansong`] provided pub fn with_swansong(mut self, swansong: Swansong) -> Self { - self.swansong = swansong; + self.server_config.set_swansong(swansong); self } @@ -207,7 +266,7 @@ where /// /// See [`HttpConfig`] for documentation pub fn with_http_config(mut self, http_config: HttpConfig) -> Self { - self.http_config = http_config; + *self.server_config.http_config_mut() = http_config; self } @@ -231,21 +290,28 @@ where eprintln!("constructing a config with both a port and a pre-bound listener will ignore the port. this may be a panic in the future"); } - self.binding = RwLock::new(Some(server.into())); + self.binding = Some(server.into()); self } fn has_binding(&self) -> bool { - self.binding - .read() - .as_deref() - .map_or(false, Option::is_some) + self.binding.is_some() } /// retrieve the runtime pub fn runtime(&self) -> ServerType::Runtime { self.runtime.clone() } + + /// return the configured port + pub fn port(&self) -> Option { + self.port + } + + /// return the configured host + pub fn host(&self) -> Option<&str> { + self.host.as_deref() + } } impl Config { @@ -274,15 +340,28 @@ impl Default for Config { acceptor: (), port: None, host: None, - server: PhantomData, nodelay: false, - swansong: Swansong::new(), register_signals: cfg!(unix), max_connections, - info: AsyncCell::shared(), - binding: RwLock::new(None), - http_config: HttpConfig::default(), + server_config_cell: AsyncCell::shared(), + binding: None, runtime: ServerType::runtime(), + server_config: Default::default(), } } } + +fn insert_url(state: &mut TypeSet, secure: bool) -> Option<()> { + let socket_addr = state.get::().copied()?; + let vacant_entry = state.entry::().into_vacant()?; + let scheme = if secure { "https" } else { "http" }; + let url = Url::parse(&if socket_addr.ip().is_loopback() { + format!("{scheme}://localhost:{}/", socket_addr.port()) + } else { + format!("{scheme}://{socket_addr}/") + }) + .ok()?; + + vacant_entry.insert(url); + Some(()) +} diff --git a/server-common/src/config_ext.rs b/server-common/src/config_ext.rs deleted file mode 100644 index 491cdeadd4..0000000000 --- a/server-common/src/config_ext.rs +++ /dev/null @@ -1,232 +0,0 @@ -use crate::{Acceptor, Config, Server, Transport}; -use futures_lite::prelude::*; -use std::{ - io::ErrorKind, - net::{SocketAddr, TcpListener, ToSocketAddrs}, - sync::Arc, -}; -use trillium::Handler; -use trillium_http::{transport::BoxedTransport, Error, Swansong, SERVICE_UNAVAILABLE}; -/// # Server-implementer interfaces to Config -/// -/// These functions are intended for use by authors of trillium servers, -/// and should not be necessary to build an application. Please open -/// an issue if you find yourself using this trait directly in an -/// application. - -pub trait ConfigExt -where - ServerType: Server, -{ - /// resolve a port for this application, either directly - /// configured, from the environmental variable `PORT`, or a default - /// of `8080` - fn port(&self) -> u16; - - /// resolve the host for this application, either directly from - /// configuration, from the `HOST` env var, or `"localhost"` - fn host(&self) -> String; - - /// use the [`ConfigExt::port`] and [`ConfigExt::host`] to resolve - /// a vec of potential socket addrs - fn socket_addrs(&self) -> Vec; - - /// returns whether this server should register itself for - /// operating system signals. this flag does nothing aside from - /// communicating to the server implementer that this is - /// desired. defaults to true on `cfg(unix)` systems, and false - /// elsewhere. - fn should_register_signals(&self) -> bool; - - /// returns whether the server should set TCP_NODELAY on the - /// TcpListener, if that is applicable - fn nodelay(&self) -> bool; - - /// returns a clone of the [`Swansong`] associated with - /// this server, to be used in conjunction with signals or other - /// service interruption methods - fn swansong(&self) -> Swansong; - - /// returns the tls acceptor for this server - fn acceptor(&self) -> &AcceptorType; - - /// waits for all requests to complete - fn graceful_shutdown(&self) -> impl Future + Send; - - /// apply the provided handler to the transport, using - /// [`trillium_http`]'s http implementation. this is the default inner - /// loop for most trillium servers - fn handle_stream( - self: Arc, - stream: ServerType::Transport, - handler: impl Handler, - ) -> impl Future + Send; - - /// builds any type that is TryFrom and - /// configures it for use. most trillium servers should use this if - /// possible instead of using [`ConfigExt::port`], - /// [`ConfigExt::host`], or [`ConfigExt::socket_addrs`]. - /// - /// this function also contains logic that sets nonblocking to - /// true and on unix systems will build a tcp listener from the - /// `LISTEN_FD` env var. - fn build_listener(&self) -> Listener - where - Listener: TryFrom, - >::Error: std::fmt::Debug; - - /// determines if the server is currently responding to more than - /// the maximum number of connections set by - /// `Config::with_max_connections`. - fn over_capacity(&self) -> bool; -} - -impl ConfigExt - for Config -where - ServerType: Server + Send + ?Sized, - AcceptorType: Acceptor<::Transport>, -{ - fn port(&self) -> u16 { - self.port - .or_else(|| std::env::var("PORT").ok().and_then(|p| p.parse().ok())) - .unwrap_or(8080) - } - - fn host(&self) -> String { - self.host - .as_ref() - .map(String::from) - .or_else(|| std::env::var("HOST").ok()) - .unwrap_or_else(|| String::from("localhost")) - } - - fn socket_addrs(&self) -> Vec { - (self.host(), self.port()) - .to_socket_addrs() - .unwrap() - .collect() - } - - fn should_register_signals(&self) -> bool { - self.register_signals - } - - fn nodelay(&self) -> bool { - self.nodelay - } - - fn swansong(&self) -> Swansong { - self.swansong.clone() - } - - fn acceptor(&self) -> &AcceptorType { - &self.acceptor - } - - async fn graceful_shutdown(&self) { - self.swansong.shut_down().await - } - - async fn handle_stream( - self: Arc, - mut stream: ServerType::Transport, - handler: impl Handler, - ) { - if self.over_capacity() { - let mut byte = [0u8]; // wait for the client to start requesting - trillium::log_error!(stream.read(&mut byte).await); - trillium::log_error!(stream.write_all(SERVICE_UNAVAILABLE).await); - return; - } - - let guard = self.swansong.guard(); - - trillium::log_error!(stream.set_nodelay(self.nodelay)); - - let peer_ip = stream.peer_addr().ok().flatten().map(|addr| addr.ip()); - - let transport = match self.acceptor.accept(stream).await { - Ok(stream) => stream, - Err(e) => { - log::error!("acceptor error: {:?}", e); - return; - } - }; - - let handler = &handler; - let result = trillium_http::Conn::map_with_config( - self.http_config, - transport, - self.swansong.clone(), - |mut conn| async { - conn.set_peer_ip(peer_ip); - let conn = handler.run(conn.into()).await; - let conn = handler.before_send(conn).await; - - conn.into_inner() - }, - ) - .await; - - match result { - Ok(Some(upgrade)) => { - let upgrade = upgrade.map_transport(BoxedTransport::new); - if handler.has_upgrade(&upgrade) { - log::debug!("upgrading..."); - handler.upgrade(upgrade).await; - } else { - log::error!("upgrade specified but no upgrade handler provided"); - } - } - - Err(Error::Closed) | Ok(None) => { - log::debug!("closing connection"); - } - - Err(Error::Io(e)) - if e.kind() == ErrorKind::ConnectionReset || e.kind() == ErrorKind::BrokenPipe => - { - log::debug!("closing connection"); - } - - Err(e) => { - log::error!("http error: {:?}", e); - } - }; - - drop(guard); - } - - fn build_listener(&self) -> Listener - where - Listener: TryFrom, - >::Error: std::fmt::Debug, - { - #[cfg(unix)] - let listener = { - use std::os::unix::prelude::FromRawFd; - - if let Some(fd) = std::env::var("LISTEN_FD") - .ok() - .and_then(|fd| fd.parse().ok()) - { - log::debug!("using fd {} from LISTEN_FD", fd); - unsafe { TcpListener::from_raw_fd(fd) } - } else { - TcpListener::bind((self.host(), self.port())).unwrap() - } - }; - - #[cfg(not(unix))] - let listener = TcpListener::bind((self.host(), self.port())).unwrap(); - - listener.set_nonblocking(true).unwrap(); - listener.try_into().unwrap() - } - - fn over_capacity(&self) -> bool { - self.max_connections - .map_or(false, |m| self.swansong.guard_count() >= m) - } -} diff --git a/server-common/src/lib.rs b/server-common/src/lib.rs index ca755b301c..e0fb60c078 100644 --- a/server-common/src/lib.rs +++ b/server-common/src/lib.rs @@ -1,3 +1,4 @@ +#![forbid(unsafe_code)] #![deny( clippy::dbg_macro, missing_copy_implementations, @@ -31,9 +32,6 @@ pub use url::Url; mod config; pub use config::Config; -mod config_ext; -pub use config_ext::ConfigExt; - mod server; pub use server::Server; @@ -56,3 +54,5 @@ pub use swansong::Swansong; mod runtime; pub use runtime::{DroppableFuture, Runtime, RuntimeTrait}; + +mod running_config; diff --git a/server-common/src/running_config.rs b/server-common/src/running_config.rs new file mode 100644 index 0000000000..f10d11bca7 --- /dev/null +++ b/server-common/src/running_config.rs @@ -0,0 +1,135 @@ +use crate::{Acceptor, ArcHandler, RuntimeTrait, Server}; +use futures_lite::{AsyncReadExt, AsyncWriteExt}; +use std::{io::ErrorKind, sync::Arc}; +use trillium::Handler; +use trillium_http::{ + transport::{BoxedTransport, Transport}, + Error, ServerConfig, SERVICE_UNAVAILABLE, +}; + +#[derive(Debug)] +pub struct RunningConfig { + pub(crate) acceptor: AcceptorType, + pub(crate) max_connections: Option, + pub(crate) nodelay: bool, + pub(crate) runtime: ServerType::Runtime, + pub(crate) server_config: Arc, +} + +impl::Transport>> RunningConfig { + pub(crate) async fn run_async(self: Arc, mut listener: S, handler: impl Handler) { + let swansong = self.server_config.as_ref().swansong(); + let runtime = self.runtime.clone(); + let handler = ArcHandler::new(handler); + while let Some(transport) = swansong.interrupt(listener.accept()).await { + match transport { + Ok(stream) => { + runtime.spawn( + Arc::clone(&self).handle_stream(stream, ArcHandler::clone(&handler)), + ); + } + Err(e) => log::error!("tcp error: {}", e), + } + } + + self.server_config.swansong().shut_down().await; + listener.clean_up().await; + } + + async fn handle_stream(self: Arc, mut stream: S::Transport, handler: impl Handler) { + if self.over_capacity() { + let mut byte = [0u8]; // wait for the client to start requesting + trillium::log_error!(stream.read(&mut byte).await); + trillium::log_error!(stream.write_all(SERVICE_UNAVAILABLE).await); + return; + } + + let guard = self.server_config.swansong().guard(); + + trillium::log_error!(stream.set_nodelay(self.nodelay)); + + let peer_ip = stream.peer_addr().ok().flatten().map(|addr| addr.ip()); + + let transport = match self.acceptor.accept(stream).await { + Ok(stream) => stream, + Err(e) => { + log::error!("acceptor error: {:?}", e); + return; + } + }; + + let handler = &handler; + + let result = self + .server_config + .clone() + .run(transport, |mut conn| async { + conn.set_peer_ip(peer_ip); + let conn = handler.run(conn.into()).await; + let conn = handler.before_send(conn).await; + + conn.into_inner() + }) + .await; + + match result { + Ok(Some(upgrade)) => { + let upgrade = upgrade.map_transport(BoxedTransport::new); + if handler.has_upgrade(&upgrade) { + log::debug!("upgrading..."); + handler.upgrade(upgrade).await; + } else { + log::error!("upgrade specified but no upgrade handler provided"); + } + } + + Err(Error::Closed) | Ok(None) => { + log::debug!("closing connection"); + } + + Err(Error::Io(e)) + if e.kind() == ErrorKind::ConnectionReset || e.kind() == ErrorKind::BrokenPipe => + { + log::debug!("closing connection"); + } + + Err(e) => { + log::error!("http error: {:?}", e); + } + }; + + drop(guard); + } + + // fn build_listener(&self) -> Listener + // where + // Listener: TryFrom, + // >::Error: std::fmt::Debug, + // { + // #[cfg(unix)] + // let listener = { + // use std::os::unix::prelude::FromRawFd; + + // if let Some(fd) = std::env::var("LISTEN_FD") + // .ok() + // .and_then(|fd| fd.parse().ok()) + // { + // log::debug!("using fd {} from LISTEN_FD", fd); + // unsafe { TcpListener::from_raw_fd(fd) } + // } else { + // TcpListener::bind((self.host(), self.port())).unwrap() + // } + // }; + + // #[cfg(not(unix))] + // let listener = TcpListener::bind((self.host(), self.port())).unwrap(); + + // listener.set_nonblocking(true).unwrap(); + // listener.try_into().unwrap() + // } + + fn over_capacity(&self) -> bool { + self.max_connections + .map_or(false, |m| self.server_config.swansong().guard_count() >= m) + } +} diff --git a/server-common/src/runtime.rs b/server-common/src/runtime.rs index 82a049c283..c5919e75b9 100644 --- a/server-common/src/runtime.rs +++ b/server-common/src/runtime.rs @@ -22,16 +22,20 @@ pub struct Runtime(Arc); impl Debug for Runtime { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_tuple("Runtime").field(&"..").finish() + f.debug_tuple("Runtime").field(&format_args!("..")).finish() } } impl Runtime { /// Construct a new type-erased runtime object from any [`RuntimeTrait`] implementation. - /// - /// Prefer using [`from`][From::from]/[`into`][Into::into] if you don't have a concrete - /// `RuntimeTrait` in order to avoid double-arc-ing a Runtime. pub fn new(runtime: impl RuntimeTrait) -> Self { + runtime.into() // we avoid re-arcing a Runtime by using Into::into + } + + // in order to avoid re-arcing Runtime in new / into, we use this to actually construct the + // Runtime within From implementations on the runtime trait type + #[doc(hidden)] + pub fn from_trait_impl(runtime: impl RuntimeTrait) -> Self { Self(Arc::new(runtime)) } @@ -117,4 +121,11 @@ impl RuntimeTrait for Runtime { })); receive.recv().unwrap() } + + fn hook_signals( + &self, + signals: impl IntoIterator, + ) -> impl Stream + Send + 'static { + self.0.hook_signals(signals.into_iter().collect()) + } } diff --git a/server-common/src/runtime/object_safe_runtime.rs b/server-common/src/runtime/object_safe_runtime.rs index 5fe22c4dab..42b848952f 100644 --- a/server-common/src/runtime/object_safe_runtime.rs +++ b/server-common/src/runtime/object_safe_runtime.rs @@ -19,6 +19,8 @@ pub(super) trait ObjectSafeRuntime: Send + Sync + 'static { where 'runtime: 'fut, Self: 'fut; + + fn hook_signals(&self, signals: Vec) -> Pin + Send + 'static>>; } impl ObjectSafeRuntime for R @@ -54,4 +56,8 @@ where { RuntimeTrait::block_on(self, fut) } + + fn hook_signals(&self, signals: Vec) -> Pin + Send + 'static>> { + Box::pin(RuntimeTrait::hook_signals(self, signals)) + } } diff --git a/server-common/src/runtime/runtime_trait.rs b/server-common/src/runtime/runtime_trait.rs index 5889e1314e..48139955bd 100644 --- a/server-common/src/runtime/runtime_trait.rs +++ b/server-common/src/runtime/runtime_trait.rs @@ -51,4 +51,13 @@ pub trait RuntimeTrait: Into + Clone + Send + Sync + 'static { None }) } + + /// trap and return a [`Stream`] of signals that match the provided signals + fn hook_signals( + &self, + signals: impl IntoIterator, + ) -> impl Stream + Send + 'static { + let _ = signals; + futures_lite::stream::empty() + } } diff --git a/server-common/src/server.rs b/server-common/src/server.rs index 14f903f0da..30c4d62d3e 100644 --- a/server-common/src/server.rs +++ b/server-common/src/server.rs @@ -1,6 +1,10 @@ -use crate::{Acceptor, ArcHandler, Config, ConfigExt, RuntimeTrait, Swansong, Transport}; -use std::{future::Future, io::Result, sync::Arc}; -use trillium::{Handler, Info}; +use crate::{RuntimeTrait, Swansong, Transport}; +use listenfd::ListenFd; +use std::{future::Future, io::Result, net::TcpListener}; +use trillium::Info; + +#[cfg(unix)] +use std::os::unix::net::UnixListener; /** The server trait, for standard network-based server implementations. @@ -14,16 +18,15 @@ pub trait Server: Sized + Send + Sync + 'static { /// The [`RuntimeTrait`] for this `Server`. type Runtime: RuntimeTrait; - /// The description of this server, to be appended to the Info and potentially logged. - const DESCRIPTION: &'static str; - /// Asynchronously return a single `Self::Transport` from a /// `Self::Listener`. Must be implemented. fn accept(&mut self) -> impl Future> + Send; /// Build an [`Info`] from the Self::Listener type. See [`Info`] /// for more details. - fn info(&self) -> Info; + fn init(&self, info: &mut Info) { + let _ = info; + } /// After the server has shut down, perform any housekeeping, eg /// unlinking a unix socket. @@ -38,61 +41,59 @@ pub trait Server: Sized + Send + Sync + 'static { /// is described elsewhere. To override the default logic, server /// implementations could potentially implement this directly. To /// use this default logic, implement - /// [`Server::listener_from_tcp`] and - /// [`Server::listener_from_unix`]. - #[cfg(unix)] - fn build_listener(config: &Config) -> Self - where - A: Acceptor, - { - if let Some(listener) = config.binding.write().unwrap().take() { - log::debug!("taking prebound listener"); - return listener; + /// [`Server::from_tcp`] and + /// [`Server::from_unix`]. + fn from_host_and_port(host: &str, port: u16) -> Self { + #[cfg(unix)] + if host.starts_with(|c| c == '/' || c == '.' || c == '~') { + log::debug!("using unix listener at {host}"); + return UnixListener::bind(host) + .inspect(|unix_listener| { + log::debug!("listening at {:?}", unix_listener.local_addr().unwrap()); + }) + .map(Self::from_unix) + .unwrap(); } - use std::os::unix::prelude::FromRawFd; - let host = config.host(); - if host.starts_with(|c| c == '/' || c == '.' || c == '~') { - Self::listener_from_unix(std::os::unix::net::UnixListener::bind(host).unwrap()) - } else { - let tcp_listener = if let Some(fd) = std::env::var("LISTEN_FD") - .ok() - .and_then(|fd| fd.parse().ok()) - { - log::debug!("using fd {} from LISTEN_FD", fd); - unsafe { std::net::TcpListener::from_raw_fd(fd) } - } else { - std::net::TcpListener::bind((host, config.port())).unwrap() - }; + let mut listen_fd = ListenFd::from_env(); - tcp_listener.set_nonblocking(true).unwrap(); - Self::listener_from_tcp(tcp_listener) + #[cfg(unix)] + if let Ok(Some(unix_listener)) = listen_fd.take_unix_listener(0) { + log::debug!( + "using unix listener from systemfd environment {:?}", + unix_listener.local_addr().unwrap() + ); + return Self::from_unix(unix_listener); } - } - /// Build a listener from the config. The default logic for this - /// is described elsewhere. To override the default logic, server - /// implementations could potentially implement this directly. To - /// use this default logic, implement [`Server::listener_from_tcp`] - #[cfg(not(unix))] - fn build_listener(config: &Config) -> Self - where - A: Acceptor, - { - if let Some(listener) = config.binding.write().unwrap().take() { - log::debug!("taking prebound listener"); - return listener; - } + let tcp_listener = listen_fd + .take_tcp_listener(0) + .ok() + .flatten() + .inspect(|tcp_listener| { + log::debug!( + "using tcp listener from systemfd environment, listening at {:?}", + tcp_listener.local_addr() + ) + }) + .unwrap_or_else(|| { + log::debug!("using tcp listener at {host}:{port}"); + TcpListener::bind((host, port)) + .inspect(|tcp_listener| { + log::debug!("listening at {:?}", tcp_listener.local_addr().unwrap()) + }) + .unwrap() + }); - let tcp_listener = std::net::TcpListener::bind((config.host(), config.port())).unwrap(); tcp_listener.set_nonblocking(true).unwrap(); - Self::listener_from_tcp(tcp_listener) + Self::from_tcp(tcp_listener) } /// Build a Self::Listener from a tcp listener. This is called by /// the [`Server::build_listener`] default implementation, and /// is mandatory if the default implementation is used. - fn listener_from_tcp(_tcp: std::net::TcpListener) -> Self { + fn from_tcp(tcp_listener: TcpListener) -> Self { + let _ = tcp_listener; unimplemented!() } @@ -100,7 +101,8 @@ pub trait Server: Sized + Send + Sync + 'static { /// the [`Server::build_listener`] default implementation. You /// will want to tag an implementation of this with #[cfg(unix)]. #[cfg(unix)] - fn listener_from_unix(_tcp: std::os::unix::net::UnixListener) -> Self { + fn from_unix(unix_listener: UnixListener) -> Self { + let _ = unix_listener; unimplemented!() } @@ -110,58 +112,4 @@ pub trait Server: Sized + Send + Sync + 'static { fn handle_signals(_swansong: Swansong) -> impl Future + Send { async {} } - - /// Run a trillium application from a sync context - fn run(config: Config, handler: H) - where - A: Acceptor, - H: Handler, - { - config - .runtime - .clone() - .block_on(async move { Self::run_async(config, handler).await }); - } - - /// Run a trillium application from an async context. The default - /// implementation of this method contains the core logic of this - /// Trait. - fn run_async(config: Config, mut handler: H) -> impl Future + Send - where - A: Acceptor, - H: Handler, - { - async move { - let runtime = config.runtime.clone(); - if config.should_register_signals() { - #[cfg(unix)] - runtime.spawn(Self::handle_signals(config.swansong())); - - #[cfg(not(unix))] - log::error!("signals handling not supported on windows yet"); - } - let mut listener = Self::build_listener(&config); - let mut info = Self::info(&listener); - info.server_description_mut().push_str(Self::DESCRIPTION); - handler.init(&mut info).await; - config.info.set(Arc::new(info)); - let config = Arc::new(config); - let handler = ArcHandler::new(handler); - let swansong = &config.swansong; - - while let Some(transport) = swansong.interrupt(Self::accept(&mut listener)).await { - match transport { - Ok(stream) => { - let config = Arc::clone(&config); - let handler = ArcHandler::clone(&handler); - runtime.spawn(config.handle_stream(stream, handler)); - } - Err(e) => log::error!("tcp error: {}", e), - } - } - - config.graceful_shutdown().await; - Self::clean_up(listener).await; - } - } } diff --git a/server-common/src/server_handle.rs b/server-common/src/server_handle.rs index 3062c63f5c..c4934d0cf4 100644 --- a/server-common/src/server_handle.rs +++ b/server-common/src/server_handle.rs @@ -1,8 +1,8 @@ use crate::Runtime; use async_cell::sync::AsyncCell; -use std::{cell::OnceCell, future::IntoFuture, sync::Arc}; +use std::{cell::OnceCell, future::IntoFuture, net::SocketAddr, sync::Arc}; use swansong::{ShutdownCompletion, Swansong}; -use trillium::Info; +use trillium_http::ServerConfig; /// A handle for a spawned trillium server. Returned by /// [`Config::handle`][crate::Config::handle] and @@ -10,19 +10,48 @@ use trillium::Info; #[derive(Clone, Debug)] pub struct ServerHandle { pub(crate) swansong: Swansong, - pub(crate) info: Arc>>, - pub(crate) received_info: OnceCell>, + pub(crate) server_config: Arc>>, + pub(crate) received_server_config: OnceCell>, pub(crate) runtime: Runtime, } +#[derive(Debug)] +pub struct BoundInfo(Arc); + +impl BoundInfo { + /// Borrow a type from the [`TypeSet`] on this `BoundInfo`. + pub fn state(&self) -> Option<&T> { + self.0.shared_state().get() + } + + /// Returns the `local_addr` of a bound tcp listener, if such a thing exists for this server + pub fn tcp_socket_addr(&self) -> Option<&SocketAddr> { + self.state() + } + + pub fn url(&self) -> Option<&url::Url> { + self.state() + } + + /// Returns the `local_addr` of a bound unix listener, if such a thing exists for this server + #[cfg(unix)] + pub fn unix_socket_addr(&self) -> Option<&std::os::unix::net::SocketAddr> { + self.state() + } +} + impl ServerHandle { /// await server start and retrieve the server's [`Info`] - pub async fn info(&self) -> &Info { - if let Some(info) = self.received_info.get() { - return info; + pub async fn info(&self) -> BoundInfo { + if let Some(server_config) = self.received_server_config.get().cloned() { + return BoundInfo(server_config); } - let arc_info = self.info.get().await; - self.received_info.get_or_init(|| arc_info) + let arc_server_config = self.server_config.get().await; + let server_config = self + .received_server_config + .get_or_init(|| arc_server_config); + + BoundInfo(Arc::clone(server_config)) } /// stop server and return a future that can be awaited for it to shut down gracefully diff --git a/smol/examples/smol.rs b/smol/examples/smol.rs index 2d32e8ff1e..180617b5ce 100644 --- a/smol/examples/smol.rs +++ b/smol/examples/smol.rs @@ -1,11 +1,11 @@ use std::time::Duration; use trillium::{Conn, Handler}; use trillium_logger::Logger; -use trillium_smol::SmolRuntime; +use trillium_server_common::Runtime; pub fn app() -> impl Handler { (Logger::new(), |conn: Conn| async move { - let runtime = SmolRuntime::default(); + let runtime = conn.shared_state::().cloned().unwrap(); let response = runtime .clone() .spawn(async move { @@ -23,13 +23,3 @@ pub fn main() { env_logger::init(); trillium_smol::run(app()); } - -#[cfg(test)] -mod tests { - use trillium_testing::prelude::*; - #[test] - fn test() { - let app = super::app(); - assert_ok!(get("/").on(&app), "successfully spawned a task"); - } -} diff --git a/smol/src/runtime.rs b/smol/src/runtime.rs index 12396a3e36..cd34382236 100644 --- a/smol/src/runtime.rs +++ b/smol/src/runtime.rs @@ -54,6 +54,14 @@ impl RuntimeTrait for SmolRuntime { fn block_on(&self, fut: Fut) -> Fut::Output { async_global_executor::block_on(fut) } + + #[cfg(unix)] + fn hook_signals( + &self, + signals: impl IntoIterator, + ) -> impl Stream + Send + 'static { + signal_hook_async_std::Signals::new(signals).unwrap() + } } impl SmolRuntime { @@ -105,6 +113,6 @@ impl SmolRuntime { impl From for Runtime { fn from(value: SmolRuntime) -> Self { - Runtime::new(value) + Runtime::from_trait_impl(value) } } diff --git a/smol/src/server/tcp.rs b/smol/src/server/tcp.rs index 7ff8dd2711..b200959354 100644 --- a/smol/src/server/tcp.rs +++ b/smol/src/server/tcp.rs @@ -1,8 +1,8 @@ use crate::{SmolRuntime, SmolTransport}; use async_net::{TcpListener, TcpStream}; -use std::{convert::TryInto, env, io::Result, net}; +use std::{convert::TryInto, io::Result, net}; use trillium::Info; -use trillium_server_common::{Server, Url}; +use trillium_server_common::Server; #[derive(Debug)] pub struct SmolTcpServer(TcpListener); @@ -16,29 +16,18 @@ impl Server for SmolTcpServer { type Transport = SmolTransport; type Runtime = SmolRuntime; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); - async fn accept(&mut self) -> Result { self.0.accept().await.map(|(t, _)| t.into()) } - fn listener_from_tcp(tcp: net::TcpListener) -> Self { + fn from_tcp(tcp: net::TcpListener) -> Self { Self(tcp.try_into().unwrap()) } - fn info(&self) -> Info { - let local_addr = self.0.local_addr().unwrap(); - let mut info = Info::from(local_addr); - if let Ok(url) = Url::parse(&format!("http://{local_addr}")) { - info.state_mut().insert(url); + fn init(&self, info: &mut Info) { + if let Ok(socket_addr) = self.0.local_addr() { + info.insert_state(socket_addr); } - info } fn runtime() -> Self::Runtime { diff --git a/smol/src/server/unix.rs b/smol/src/server/unix.rs index 36cea416cb..64810c355a 100644 --- a/smol/src/server/unix.rs +++ b/smol/src/server/unix.rs @@ -3,12 +3,11 @@ use async_net::{ unix::{UnixListener, UnixStream}, TcpListener, TcpStream, }; -use futures_lite::prelude::*; -use std::{env, io::Result}; +use std::io::Result; use trillium::{log_error, Info}; use trillium_server_common::{ Binding::{self, *}, - Server, Swansong, Url, + Server, }; #[derive(Debug, Clone)] @@ -28,35 +27,11 @@ impl From for SmolServer { impl Server for SmolServer { type Transport = Binding, SmolTransport>; type Runtime = SmolRuntime; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); fn runtime() -> Self::Runtime { SmolRuntime::default() } - async fn handle_signals(swansong: Swansong) { - use signal_hook::consts::signal::*; - use signal_hook_async_std::Signals; - - let signals = Signals::new([SIGINT, SIGTERM, SIGQUIT]).unwrap(); - let mut signals = signals.fuse(); - while signals.next().await.is_some() { - if swansong.state().is_shutting_down() { - eprintln!("\nSecond interrupt, shutting down harshly"); - std::process::exit(1); - } else { - println!("\nShutting down gracefully.\nControl-C again to force."); - swansong.shut_down(); - } - } - } - async fn accept(&mut self) -> Result { match &self.0 { Tcp(t) => t.accept().await.map(|(t, _)| Tcp(SmolTransport::from(t))), @@ -64,25 +39,26 @@ impl Server for SmolServer { } } - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { + fn from_tcp(tcp: std::net::TcpListener) -> Self { Self(Tcp(tcp.try_into().unwrap())) } - fn listener_from_unix(tcp: std::os::unix::net::UnixListener) -> Self { + fn from_unix(tcp: std::os::unix::net::UnixListener) -> Self { Self(Unix(tcp.try_into().unwrap())) } - fn info(&self) -> Info { + fn init(&self, info: &mut Info) { match &self.0 { Tcp(t) => { - let local_addr = t.local_addr().unwrap(); - let mut info = Info::from(local_addr); - if let Ok(url) = Url::parse(&format!("http://{local_addr}")) { - info.state_mut().insert(url); + if let Ok(socket_addr) = t.local_addr() { + info.insert_state(socket_addr); + } + } + Unix(u) => { + if let Ok(socket_addr) = u.local_addr() { + info.insert_state(socket_addr); } - info } - Unix(u) => u.local_addr().unwrap().into(), } } diff --git a/static/src/handler.rs b/static/src/handler.rs index da51a189b3..6b502e44a4 100644 --- a/static/src/handler.rs +++ b/static/src/handler.rs @@ -81,7 +81,7 @@ impl StaticFileHandler { use trillium_testing::prelude::*; let mut handler = StaticFileHandler::new(crate_relative_path!("examples/files")); - # handler.init(&mut "testing".into()).await; + # init(&mut handler); assert_not_handled!(get("/").run_async(&handler).await); // no index file configured @@ -127,7 +127,7 @@ impl StaticFileHandler { let mut handler = StaticFileHandler::new(crate_relative_path!("examples/files")) .with_index_file("index.html"); - # handler.init(&mut "testing".into()).await; + # init(&mut handler); use trillium_testing::prelude::*; assert_ok!( diff --git a/static/src/lib.rs b/static/src/lib.rs index a78d8d96a6..9a72ccf7d8 100644 --- a/static/src/lib.rs +++ b/static/src/lib.rs @@ -37,7 +37,7 @@ let mut handler = StaticFileHandler::new(crate_relative_path!("examples/files")) use trillium_testing::prelude::*; # use trillium::Handler; -handler.init(&mut "testing".into()).await; +# init(&mut handler); assert_ok!( get("/").run_async(&handler).await, diff --git a/tera/src/tera_handler.rs b/tera/src/tera_handler.rs index f819ed0166..4f547d4c26 100644 --- a/tera/src/tera_handler.rs +++ b/tera/src/tera_handler.rs @@ -22,13 +22,13 @@ impl From<&str> for TeraHandler { impl From<&String> for TeraHandler { fn from(dir: &String) -> Self { - (**dir).into() + Tera::new(&dir).unwrap().into() } } impl From for TeraHandler { fn from(dir: String) -> Self { - dir.into() + Tera::new(&dir).unwrap().into() } } diff --git a/testing/Cargo.toml b/testing/Cargo.toml index 7ae2fb3c23..f2c229cf47 100644 --- a/testing/Cargo.toml +++ b/testing/Cargo.toml @@ -32,6 +32,7 @@ trillium-macros = { version = "0.0.6", path = "../macros" } dashmap = "5.5.3" once_cell = "1.19.0" fastrand = "2.0.1" +env_logger = "0.11.3" log = "0.4.21" [dependencies.trillium-smol] diff --git a/testing/src/lib.rs b/testing/src/lib.rs index df2ccbfe47..cd111ef967 100644 --- a/testing/src/lib.rs +++ b/testing/src/lib.rs @@ -65,8 +65,7 @@ trillium-testing = { version = "0.2", features = ["smol"] } mod assertions; mod test_transport; -use std::future::Future; -use std::process::Termination; +use std::{future::Future, process::Termination, sync::Arc}; pub use test_transport::TestTransport; @@ -86,8 +85,10 @@ pub mod prelude { pub use trillium::{Conn, Method, Status}; } +use trillium::{Handler, Info}; pub use trillium::{Method, Status}; +use trillium_http::ServerConfig; pub use url::Url; /// runs the future to completion on the current thread @@ -96,9 +97,12 @@ pub fn block_on(fut: Fut) -> Fut::Output { } /// initialize a handler -pub fn init(handler: &mut impl trillium::Handler) { - let mut info = "testing".into(); - block_on(async move { handler.init(&mut info).await }) +pub fn init(handler: &mut impl Handler) -> Arc { + let mut info = Info::from(ServerConfig::default()); + info.insert_state(runtime()); + info.insert_state(runtime().into()); + block_on(handler.init(&mut info)); + Arc::new(info.into()) } // these exports are used by macros @@ -198,5 +202,18 @@ where Fut: Future, Output: Termination, { + let _ = env_logger::builder().is_test(true).try_init(); block_on(test()) } + +/// a harness that includes the runtime +#[track_caller] +pub fn with_runtime(test: F) -> Output +where + F: FnOnce(Runtime) -> Fut, + Fut: Future, + Output: Termination, +{ + let runtime = runtime(); + runtime.clone().block_on(test(runtime.into())) +} diff --git a/testing/src/runtimeless/runtime.rs b/testing/src/runtimeless/runtime.rs index 13901df928..495279e75b 100644 --- a/testing/src/runtimeless/runtime.rs +++ b/testing/src/runtimeless/runtime.rs @@ -48,7 +48,7 @@ impl RuntimeTrait for RuntimelessRuntime { } impl From for Runtime { fn from(value: RuntimelessRuntime) -> Self { - Runtime::new(value) + Runtime::from_trait_impl(value) } } impl RuntimelessRuntime { diff --git a/testing/src/runtimeless/server.rs b/testing/src/runtimeless/server.rs index 597b176182..672b115246 100644 --- a/testing/src/runtimeless/server.rs +++ b/testing/src/runtimeless/server.rs @@ -3,7 +3,7 @@ use crate::{RuntimelessRuntime, TestTransport}; use async_channel::Receiver; use std::io::{Error, ErrorKind, Result}; use trillium::Info; -use trillium_server_common::{Acceptor, Config, ConfigExt, Server}; +use trillium_server_common::Server; use url::Url; /// A [`Server`] for testing that does not depend on any runtime @@ -30,8 +30,6 @@ impl Server for RuntimelessServer { type Transport = TestTransport; type Runtime = RuntimelessRuntime; - const DESCRIPTION: &'static str = "test server"; - fn runtime() -> Self::Runtime { RuntimelessRuntime::default() } @@ -43,29 +41,24 @@ impl Server for RuntimelessServer { .map_err(|e| Error::new(ErrorKind::Other, e.to_string())) } - fn build_listener(config: &Config) -> Self - where - A: Acceptor, - { - let mut port = config.port(); - let host = config.host(); + fn from_host_and_port(host: &str, mut port: u16) -> Self { if port == 0 { loop { port = fastrand::u16(..); - if !SERVERS.contains_key(&(host.clone(), port)) { + if !SERVERS.contains_key(&(host.to_string(), port)) { break; } } } let entry = SERVERS - .entry((host.clone(), port)) + .entry((host.to_string(), port)) .or_insert_with(async_channel::unbounded); let (_, channel) = entry.value(); Self { - host, + host: host.to_string(), channel: channel.clone(), port, } @@ -75,10 +68,7 @@ impl Server for RuntimelessServer { SERVERS.remove(&(self.host, self.port)); } - fn info(&self) -> Info { - let mut info = Info::from(&*format!("{}:{}", &self.host, &self.port)); - info.state_mut() - .insert(Url::parse(&format!("http://{}:{}", &self.host, self.port)).unwrap()); - info + fn init(&self, info: &mut Info) { + info.insert_state(Url::parse(&format!("http://{}:{}", &self.host, self.port)).unwrap()); } } diff --git a/testing/src/server_connector.rs b/testing/src/server_connector.rs index 7b007b26f4..082a373ff4 100644 --- a/testing/src/server_connector.rs +++ b/testing/src/server_connector.rs @@ -1,7 +1,7 @@ use crate::{RuntimeType, TestTransport}; use std::{io, sync::Arc}; use trillium::Handler; -use trillium_http::Conn; +use trillium_http::ServerConfig; use trillium_server_common::Connector; use url::Url; @@ -10,6 +10,7 @@ use url::Url; pub struct ServerConnector { handler: Arc, runtime: RuntimeType, + server_config: Arc, } impl ServerConnector { @@ -18,27 +19,36 @@ impl ServerConnector { Self { handler: Arc::new(handler), runtime: RuntimeType::default(), + server_config: Arc::default(), } } + /// use a specific server config + pub fn with_server_config(mut self, server_config: ServerConfig) -> Self { + self.server_config = Arc::new(server_config); + self + } + /// opens a new connection to this virtual server, returning the client transport pub async fn connect(&self, secure: bool) -> TestTransport { let (client_transport, server_transport) = TestTransport::new(); let handler = Arc::clone(&self.handler); + let server_config = Arc::clone(&self.server_config); self.runtime.spawn(async move { - Conn::map(server_transport, Default::default(), |mut conn| { - let handler = Arc::clone(&handler); - async move { - conn.set_secure(secure); - let conn = handler.run(conn.into()).await; - let conn = handler.before_send(conn).await; - conn.into_inner() - } - }) - .await - .unwrap(); + server_config + .run(server_transport, |mut conn| { + let handler = Arc::clone(&handler); + async move { + conn.set_secure(secure); + let conn = handler.run(conn.into()).await; + let conn = handler.before_send(conn).await; + conn.into_inner() + } + }) + .await + .unwrap(); }); client_transport diff --git a/testing/src/test_conn.rs b/testing/src/test_conn.rs index 6afc5f07bb..b3ad9ee15f 100644 --- a/testing/src/test_conn.rs +++ b/testing/src/test_conn.rs @@ -2,9 +2,10 @@ use std::{ fmt::Debug, net::IpAddr, ops::{Deref, DerefMut}, + sync::Arc, }; use trillium::{Conn, Handler, HeaderName, HeaderValues, Method}; -use trillium_http::{Conn as HttpConn, Synthetic}; +use trillium_http::{Conn as HttpConn, ServerConfig, Synthetic}; type SyntheticConn = HttpConn; @@ -35,6 +36,17 @@ impl TestConn { Self(HttpConn::new_synthetic(method.try_into().unwrap(), path.into(), body).into()) } + /// assigns a shared server config to this test conn + pub fn with_server_config(self, server_config: Arc) -> Self { + let inner = self + .0 + .into_inner::() + .with_server_config(server_config) + .into(); + + Self(inner) + } + /** chainable constructor to append a request header to the TestConn ``` diff --git a/testing/src/with_server.rs b/testing/src/with_server.rs index 4b7234acdf..d55447e3a3 100644 --- a/testing/src/with_server.rs +++ b/testing/src/with_server.rs @@ -26,7 +26,7 @@ where runtime.block_on(async move { let handle = config.spawn(handler); let info = handle.info().await; - let url = info.state().get::().cloned().unwrap_or_else(|| { + let url = info.state().cloned().unwrap_or_else(|| { let port = info.tcp_socket_addr().map(|t| t.port()).unwrap_or(0); format!("http://localhost:{port}").parse().unwrap() }); diff --git a/tokio/src/runtime.rs b/tokio/src/runtime.rs index 5520d0ee96..bc69153646 100644 --- a/tokio/src/runtime.rs +++ b/tokio/src/runtime.rs @@ -55,6 +55,14 @@ impl RuntimeTrait for TokioRuntime { Inner::Owned(runtime) => runtime.block_on(fut), } } + + #[cfg(unix)] + fn hook_signals( + &self, + signals: impl IntoIterator, + ) -> impl Stream + Send + 'static { + signal_hook_tokio::Signals::new(signals).unwrap() + } } impl TokioRuntime { @@ -112,6 +120,6 @@ impl TokioRuntime { impl From for Runtime { fn from(value: TokioRuntime) -> Self { - Runtime::new(value) + Runtime::from_trait_impl(value) } } diff --git a/tokio/src/server/tcp.rs b/tokio/src/server/tcp.rs index ddc0201f5c..29e4092967 100644 --- a/tokio/src/server/tcp.rs +++ b/tokio/src/server/tcp.rs @@ -18,13 +18,6 @@ impl From for TokioServer { impl Server for TokioServer { type Runtime = TokioRuntime; type Transport = TokioTransport>; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); async fn accept(&mut self) -> io::Result { self.0 @@ -33,12 +26,14 @@ impl Server for TokioServer { .map(|(t, _)| TokioTransport(Compat::new(t))) } - fn listener_from_tcp(tcp: net::TcpListener) -> Self { + fn from_tcp(tcp: net::TcpListener) -> Self { Self(tcp.try_into().unwrap()) } - fn info(&self) -> Info { - self.0.local_addr().unwrap().into() + fn init(&self, info: &mut Info) { + if let Ok(socket_addr) = self.0.local_addr() { + info.insert_state(socket_addr); + } } fn runtime() -> Self::Runtime { diff --git a/tokio/src/server/unix.rs b/tokio/src/server/unix.rs index 674da9a4c3..a1208c8ad0 100644 --- a/tokio/src/server/unix.rs +++ b/tokio/src/server/unix.rs @@ -5,7 +5,7 @@ use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream}; use trillium::{log_error, Info}; use trillium_server_common::{ Binding::{self, *}, - Server, Swansong, + Server, }; /// Tcp/Unix Trillium server adapter for Tokio @@ -28,31 +28,6 @@ impl Server for TokioServer { type Runtime = TokioRuntime; type Transport = Binding>, TokioTransport>>; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); - - async fn handle_signals(swansong: Swansong) { - use signal_hook::consts::signal::*; - use signal_hook_tokio::Signals; - use tokio_stream::StreamExt; - let signals = Signals::new([SIGINT, SIGTERM, SIGQUIT]).unwrap(); - let mut signals = signals.fuse(); - while signals.next().await.is_some() { - if swansong.state().is_shutting_down() { - eprintln!("\nSecond interrupt, shutting down harshly"); - std::process::exit(1); - } else { - println!("\nShutting down gracefully.\nControl-C again to force."); - swansong.shut_down(); - } - } - } - async fn accept(&mut self) -> Result { match &mut self.0 { Tcp(t) => t @@ -67,19 +42,20 @@ impl Server for TokioServer { } } - fn info(&self) -> Info { + fn init(&self, info: &mut Info) { match &self.0 { - Tcp(t) => t.local_addr().unwrap().into(), - Unix(u) => (*format!("{:?}", u.local_addr().unwrap())).into(), - } - } - - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { - Self(Tcp(tcp.try_into().unwrap())) - } + Tcp(t) => { + if let Ok(socket_addr) = t.local_addr() { + info.insert_state(socket_addr); + } + } - fn listener_from_unix(unix: std::os::unix::net::UnixListener) -> Self { - Self(Unix(unix.try_into().unwrap())) + Unix(u) => { + if let Ok(socket_addr) = u.local_addr() { + info.insert_state(socket_addr); + } + } + } } async fn clean_up(self) { @@ -93,6 +69,14 @@ impl Server for TokioServer { } } + fn from_tcp(tcp_listener: std::net::TcpListener) -> Self { + TcpListener::from_std(tcp_listener).unwrap().into() + } + + fn from_unix(unix_listener: std::os::unix::net::UnixListener) -> Self { + UnixListener::from_std(unix_listener).unwrap().into() + } + fn runtime() -> Self::Runtime { TokioRuntime::default() } diff --git a/trillium/Cargo.toml b/trillium/Cargo.toml index 7c93497014..4605b426e7 100644 --- a/trillium/Cargo.toml +++ b/trillium/Cargo.toml @@ -25,10 +25,11 @@ futures-lite = "2.1.0" async-channel = "2.1.1" async-io = "2.3.1" fastrand = "2.0.1" -test-harness = "0.2.0" trillium-smol = { path = "../smol" } trillium-testing = { path = "../testing" } env_logger = "0.11.3" +trillium-client = { path = "../client" } +test-harness = "0.2.0" [package.metadata.cargo-udeps.ignore] development = ["trillium-testing"] diff --git a/trillium/examples/state.rs b/trillium/examples/state.rs index c3ef379429..e1f6b13db3 100644 --- a/trillium/examples/state.rs +++ b/trillium/examples/state.rs @@ -36,14 +36,28 @@ mod conn_counter { } } +use std::time::Instant; + use conn_counter::{ConnCounterConnExt, ConnCounterHandler}; -use trillium::{Conn, Handler}; +use trillium::{Conn, Handler, Init}; + +struct ServerStart(Instant); fn handler() -> impl Handler { - (ConnCounterHandler::new(), |conn: Conn| async move { - let conn_number = conn.conn_number(); - conn.ok(format!("conn number was {conn_number}")) - }) + ( + Init::new(|info| async move { info.with_state(ServerStart(Instant::now())) }), + ConnCounterHandler::new(), + |conn: Conn| async move { + let uptime = conn + .shared_state() + .map(|ServerStart(instant)| instant.elapsed()) + .unwrap_or_default(); + let conn_number = conn.conn_number(); + conn.ok(format!( + "conn number was {conn_number}, server has been up {uptime:?}" + )) + }, + ) } fn main() { diff --git a/trillium/src/conn.rs b/trillium/src/conn.rs index 6ef132eb62..6972ea6855 100644 --- a/trillium/src/conn.rs +++ b/trillium/src/conn.rs @@ -240,14 +240,6 @@ impl Conn { self.inner.state().get() } - /// Attempts to receive a &T from the shared state set - /// - /// Note that shared state may not currently be mutated after server start, so there is no - /// `shared_state_mut` or `shared_state_entry` - pub fn shared_state(&self) -> Option<&T> { - self.inner.shared_state().and_then(TypeSet::get) - } - /// Attempts to retrieve a &mut T from the state set pub fn state_mut(&mut self) -> Option<&mut T> { self.inner.state_mut().get_mut() @@ -275,14 +267,17 @@ impl Conn { self.inner.state_mut().take() } - /// Returns an [`Entry`] type that represents the presence or absence of a type in this state. - /// - /// Use this for chainable combinators like [`Entry::or_default`], [`Entry::or_insert`], - /// [`Entry::or_insert_with`], and [`Entry::and_modify`] as well as matching on it as an enum. + /// Returns an [`Entry`] for the state typeset that can be used with functions like + /// [`Entry::or_insert`], [`Entry::or_insert_with`], [`Entry::and_modify`], and others. pub fn state_entry(&mut self) -> Entry<'_, T> { self.inner.state_mut().entry() } + /// Attempts to borrow a T from the immutable shared state set + pub fn shared_state(&self) -> Option<&T> { + self.inner.shared_state().get() + } + /** Returns a [`ReceivedBody`] that references this `Conn`. The `Conn` retains all data and holds the singular transport, but the diff --git a/trillium/src/handler.rs b/trillium/src/handler.rs index dfbfc32c91..ea316845c6 100644 --- a/trillium/src/handler.rs +++ b/trillium/src/handler.rs @@ -60,59 +60,46 @@ For most application code and even trillium-packaged framework code, */ pub trait Handler: Send + Sync + 'static { - /// Executes this handler, performing any modifications to the - /// Conn that are desired. - fn run(&self, conn: Conn) -> impl Future + Send; - - /** - Performs one-time async set up on a mutable borrow of the - Handler before the server starts accepting requests. This - allows a Handler to be defined in synchronous code but perform - async setup such as establishing a database connection or - fetching some state from an external source. This is optional, - and chances are high that you do not need this. - - It also receives a mutable borrow of the [`Info`] that represents - the current connection. - - **stability note:** This may go away at some point. Please open an - **issue if you have a use case which requires it. - */ - fn init(&mut self, _info: &mut Info) -> impl Future + Send { + /// Executes this handler, performing any modifications to the Conn that are desired. + fn run(&self, conn: Conn) -> impl Future + Send { + async { conn } + } + + /// Performs one-time async set up on a mutable borrow of the Handler before the server starts + /// accepting requests. This allows a Handler to be defined in synchronous code but perform + /// async setup such as establishing a database connection or fetching some state from an + /// external source. This is optional, and chances are high that you do not need this. + /// + /// It also receives a mutable borrow of the [`Info`] that represents the current connection. + fn init(&mut self, info: &mut Info) -> impl Future + Send { + let _ = info; std::future::ready(()) } - /** - Performs any final modifications to this conn after all handlers - have been run. Although this is a slight deviation from the simple - conn->conn->conn chain represented by most Handlers, it provides - an easy way for libraries to effectively inject a second handler - into a response chain. This is useful for loggers that need to - record information both before and after other handlers have run, - as well as database transaction handlers and similar library code. - - **❗IMPORTANT NOTE FOR LIBRARY AUTHORS:** Please note that this - will run __whether or not the conn has was halted before - [`Handler::run`] was called on a given conn__. This means that if - you want to make your `before_send` callback conditional on - whether `run` was called, you need to put a unit type into the - conn's state and check for that. - - stability note: I don't love this for the exact reason that it - breaks the simplicity of the conn->conn->model, but it is - currently the best compromise between that simplicity and - convenience for the application author, who should not have to add - two Handlers to achieve an "around" effect. - */ + /// Performs any final modifications to this conn after all handlers have been run. Although + /// this is a slight deviation from the simple conn->conn->conn chain represented by most + /// Handlers, it provides an easy way for libraries to effectively inject a second handler into + /// a response chain. This is useful for loggers that need to record information both before and + /// after other handlers have run, as well as database transaction handlers and similar library + /// code. + /// + /// **❗IMPORTANT NOTE FOR LIBRARY AUTHORS:** Please note that this will run __whether or not + /// the conn has was halted before [`Handler::run`] was called on a given conn__. This means + /// that if you want to make your `before_send` callback conditional on whether `run` was + /// called, you need to put a unit type into the conn's state and check for that. + /// + /// stability note: I don't love this for the exact reason that it breaks the simplicity of the + /// conn->conn->model, but it is currently the best compromise between that simplicity and + /// convenience for the application author, who should not have to add two Handlers to achieve + /// an "around" effect. fn before_send(&self, conn: Conn) -> impl Future + Send { std::future::ready(conn) } /** - predicate function answering the question of whether this Handler - would like to take ownership of the negotiated Upgrade. If this - returns true, you must implement [`Handler::upgrade`]. The first - handler that responds true to this will receive ownership of the + predicate function answering the question of whether this Handler would like to take ownership + of the negotiated Upgrade. If this returns true, you must implement [`Handler::upgrade`]. The + first handler that responds true to this will receive ownership of the [`trillium::Upgrade`][crate::Upgrade] in a subsequent call to [`Handler::upgrade`] */ fn has_upgrade(&self, upgrade: &Upgrade) -> bool { @@ -121,15 +108,12 @@ pub trait Handler: Send + Sync + 'static { } /** - This will only be called if the handler reponds true to - [`Handler::has_upgrade`] and will only be called once for this - upgrade. There is no return value, and this function takes - exclusive ownership of the underlying transport once this is - called. You can downcast the transport to whatever the source - transport type is and perform any non-http protocol communication - that has been negotiated. You probably don't want this unless - you're implementing something like websockets. Please note that - for many transports such as `TcpStreams`, dropping the transport + This will only be called if the handler reponds true to [`Handler::has_upgrade`] and will only + be called once for this upgrade. There is no return value, and this function takes exclusive + ownership of the underlying transport once this is called. You can downcast the transport to + whatever the source transport type is and perform any non-http protocol communication that has + been negotiated. You probably don't want this unless you're implementing something like + websockets. Please note that for many transports such as `TcpStreams`, dropping the transport (and therefore the Upgrade) will hang up / disconnect. */ fn upgrade(&self, upgrade: Upgrade) -> impl Future + Send { @@ -138,41 +122,14 @@ pub trait Handler: Send + Sync + 'static { } /** - Customize the name of your handler. This is used in Debug - implementations. The default is the type name of this handler. + Customize the name of your handler. This is used in Debug implementations. The default is the + type name of this handler. */ fn name(&self) -> Cow<'static, str> { std::any::type_name::().into() } } -// -// impl Handler for Box { -// async fn run(&self, conn: Conn) -> Conn { -// self.as_ref().run(conn).await -// } - -// async fn init(&mut self, info: &mut Info) { -// self.as_mut().init(info).await; -// } - -// async fn before_send(&self, conn: Conn) -> Conn { -// self.as_ref().before_send(conn).await -// } - -// fn name(&self) -> Cow<'static, str> { -// self.as_ref().name() -// } - -// fn has_upgrade(&self, upgrade: &Upgrade) -> bool { -// self.as_ref().has_upgrade(upgrade) -// } - -// async fn upgrade(&self, upgrade: Upgrade) { -// self.as_ref().upgrade(upgrade).await; -// } -// } - impl Handler for Status { async fn run(&self, conn: Conn) -> Conn { conn.with_status(*self) diff --git a/trillium/src/info.rs b/trillium/src/info.rs index b15d598a49..65022b9018 100644 --- a/trillium/src/info.rs +++ b/trillium/src/info.rs @@ -1,10 +1,5 @@ -use std::{ - fmt::{Display, Formatter, Result}, - net::SocketAddr, -}; -use trillium_http::TypeSet; - -const DEFAULT_SERVER_DESCRIPTION: &str = concat!("trillium v", env!("CARGO_PKG_VERSION")); +use std::net::SocketAddr; +use trillium_http::{type_set::entry::Entry, ServerConfig, Swansong, TypeSet}; /** This struct represents information about the currently connected @@ -13,130 +8,73 @@ server. It is passed to [`Handler::init`](crate::Handler::init). */ -#[derive(Debug)] -pub struct Info { - server_description: String, - listener_description: String, - tcp_socket_addr: Option, - state: TypeSet, -} - -impl Default for Info { - fn default() -> Self { - Self { - server_description: DEFAULT_SERVER_DESCRIPTION.into(), - listener_description: String::new(), - tcp_socket_addr: None, - state: TypeSet::new(), - } +#[derive(Debug, Default)] +pub struct Info(ServerConfig); +impl From for Info { + fn from(value: ServerConfig) -> Self { + Self(value) } } - -impl Display for Info { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - f.write_fmt(format_args!( - "{} listening on {}", - self.server_description(), - self.listener_description(), - )) +impl From for ServerConfig { + fn from(value: Info) -> Self { + value.0 } } -impl Info { - /// Returns a user-displayable description of the server. This - /// might be a string like "trillium x.y.z (trillium-tokio x.y.z)" or "my - /// special application". - pub fn server_description(&self) -> &str { - &self.server_description +impl AsRef for Info { + fn as_ref(&self) -> &TypeSet { + self.0.as_ref() } - - /// Returns a user-displayable string description of the location - /// or port the listener is bound to, potentially as a url. Do not - /// rely on the format of this string, as it will vary between - /// server implementations and is intended for user - /// display. Instead, use [`Info::tcp_socket_addr`] for any - /// processing. - pub fn listener_description(&self) -> &str { - &self.listener_description +} +impl AsMut for Info { + fn as_mut(&mut self) -> &mut TypeSet { + self.0.as_mut() } +} +impl Info { /// Returns the `local_addr` of a bound tcp listener, if such a /// thing exists for this server - pub const fn tcp_socket_addr(&self) -> Option<&SocketAddr> { - self.tcp_socket_addr.as_ref() - } - - /// obtain a mutable borrow of the server description, suitable - /// for appending information or replacing it - pub fn server_description_mut(&mut self) -> &mut String { - &mut self.server_description - } - - /// obtain a mutable borrow of the listener description, suitable - /// for appending information or replacing it - pub fn listener_description_mut(&mut self) -> &mut String { - &mut self.listener_description + pub fn tcp_socket_addr(&self) -> Option<&SocketAddr> { + self.state() } - /// borrow the [`TypeSet`] on this `Info`. This can be useful for passing initialization data - /// between handlers - #[allow(clippy::missing_const_for_fn)] // Info isn't useful in a const context - pub fn state(&self) -> &TypeSet { - &self.state + /// Returns the `local_addr` of a bound unix listener, if such a + /// thing exists for this server + #[cfg(unix)] + pub fn unix_socket_addr(&self) -> Option<&std::os::unix::net::SocketAddr> { + self.state() } - /// attempt to mutably borrow the [`TypeSet`] on this `Info`. - pub fn state_mut(&mut self) -> &mut TypeSet { - &mut self.state + /// Borrow a type from the shared state [`TypeSet`] on this `Info`. + pub fn state(&self) -> Option<&T> { + self.0.shared_state().get() } -} -impl AsRef for Info { - fn as_ref(&self) -> &TypeSet { - self.state() + /// Insert a type into the shared state typeset, returning the previous value if any + pub fn insert_state(&mut self, value: T) -> Option { + self.0.shared_state_mut().insert(value) } -} -impl AsMut for Info { - fn as_mut(&mut self) -> &mut TypeSet { - self.state_mut() + /// Mutate a type in the shared state typeset + pub fn state_mut(&mut self) -> Option<&mut T> { + self.0.shared_state_mut().get_mut() } -} -impl From<&str> for Info { - fn from(description: &str) -> Self { - Self { - server_description: String::from(DEFAULT_SERVER_DESCRIPTION), - listener_description: String::from(description), - ..Self::default() - } + /// Returns an [`Entry`] into the shared state typeset. + pub fn state_entry(&mut self) -> Entry<'_, T> { + self.0.shared_state_mut().entry() } -} - -impl From for Info { - fn from(socket_addr: SocketAddr) -> Self { - let mut info = Self { - server_description: String::from(DEFAULT_SERVER_DESCRIPTION), - listener_description: socket_addr.to_string(), - tcp_socket_addr: Some(socket_addr), - ..Self::default() - }; - info.state_mut().insert(socket_addr); - info + /// chainable interface to insert a type into the shared state typeset + #[must_use] + pub fn with_state(mut self, value: T) -> Self { + self.insert_state(value); + self } -} - -#[cfg(unix)] -impl From for Info { - fn from(s: std::os::unix::net::SocketAddr) -> Self { - let mut info = Self { - server_description: String::from(DEFAULT_SERVER_DESCRIPTION), - listener_description: format!("{s:?}"), - ..Self::default() - }; - info.state_mut().insert(s); - info + /// Borrow the [`Swansong`] graceful shutdown interface for this server + pub fn swansong(&self) -> &Swansong { + self.0.swansong() } } diff --git a/trillium/src/init.rs b/trillium/src/init.rs new file mode 100644 index 0000000000..78436baf0d --- /dev/null +++ b/trillium/src/init.rs @@ -0,0 +1,86 @@ +use crate::{Conn, Handler, Info}; +use std::{future::Future, mem}; + +/** + +Provides support for asynchronous initialization of a handler after +the server is started. + +``` +use trillium::{Conn, State, Init}; + +#[derive(Debug, Clone)] +struct MyDatabaseConnection(String); +impl MyDatabaseConnection { + async fn connect(uri: &str) -> std::io::Result { + Ok(Self(uri.into())) + } + async fn query(&self, query: &str) -> String { + format!("you queried `{}` against {}", query, &self.0) + } +} + +let mut handler = ( + Init::new(|mut info| async move { + let db = MyDatabaseConnection::connect("db://db").await.expect("1"); + info.with_state(db) + }), + |conn: Conn| async move { + dbg!(&conn); + let db = conn.shared_state::().expect("2"); + let response = db.query("select * from users limit 1").await; + conn.ok(response) + } +); + +use trillium_testing::prelude::*; + +let server_config = init(&mut handler); +assert_ok!( + get("/").with_server_config(server_config).on(&handler), + "you queried `select * from users limit 1` against db://db" +); + +``` +*/ +#[derive(Debug)] +pub struct Init(Option); + +impl Init +where + F: FnOnce(Info) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, +{ + /// Constructs a new Init handler with an async function that receives and returns [`Info`]. + #[must_use] + pub const fn new(init: F) -> Self { + Self(Some(init)) + } +} + +impl Handler for Init +where + F: FnOnce(Info) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, +{ + async fn run(&self, conn: Conn) -> Conn { + conn + } + + async fn init(&mut self, info: &mut Info) { + if let Some(init) = self.0.take() { + *info = init(mem::take(info)).await; + } else { + log::warn!("called init more than once"); + } + } +} + +/// alias for [`Init::new`] +pub const fn init(init: F) -> Init +where + F: FnOnce(Info) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, +{ + Init::new(init) +} diff --git a/trillium/src/lib.rs b/trillium/src/lib.rs index 3e4ff774b4..f142953416 100644 --- a/trillium/src/lib.rs +++ b/trillium/src/lib.rs @@ -60,3 +60,6 @@ pub use info::Info; mod boxed_handler; pub use boxed_handler::BoxedHandler; + +mod init; +pub use init::{init, Init}; diff --git a/trillium/src/shared_state.rs b/trillium/src/shared_state.rs new file mode 100644 index 0000000000..a3adfa93a2 --- /dev/null +++ b/trillium/src/shared_state.rs @@ -0,0 +1,36 @@ +use crate::{Handler, Info}; +use std::{any::type_name, borrow::Cow}; + +/// This handler populates a type into the immutable server-shared state type-set. Note that unlike +/// [`State`], this handler does not require [`Clone`], as the single allocation provided to the +/// constructor is held in an Arc and shared with every Conn. +/// +#[derive(Debug)] +pub struct SharedState(Option); +impl SharedState +where + T: Send + Sync + 'static, +{ + /// Constructs a new State handler from any `Clone` + `Send` + `Sync` + + /// `'static` + pub const fn new(t: T) -> Self { + Self(Some(t)) + } +} + +/// Constructs a new [`SharedState`] handler from any Send + Sync + +/// 'static. Alias for [`SharedState::new`] +#[allow(clippy::missing_const_for_fn)] +pub fn shared_state(t: T) -> SharedState { + SharedState::new(t) +} + +impl Handler for SharedState { + async fn init(&mut self, info: &mut Info) { + info.insert_state(self.0.take().unwrap()); + } + + fn name(&self) -> Cow<'static, str> { + format!("SharedState<{}>", type_name::()).into() + } +} diff --git a/trillium/src/state.rs b/trillium/src/state.rs index 7325748808..bd3873f436 100644 --- a/trillium/src/state.rs +++ b/trillium/src/state.rs @@ -1,5 +1,5 @@ use crate::{Conn, Handler}; -use std::fmt::{self, Debug, Formatter}; +use std::fmt::Debug; /** # A handler for sharing state across an application. @@ -55,28 +55,9 @@ with whatever cross thread synchronization mechanisms are appropriate for your application. There will be one clones of the contained T type in memory for each http connection, and any locks should be held as briefly as possible so as to minimize impact on other conns. - -**Stability note:** This is a common enough pattern that it currently -exists in the public api, but may be removed at some point for -simplicity. */ - -pub struct State(T); - -impl Debug for State { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_tuple("State").field(&self.0).finish() - } -} - -impl Default for State -where - T: Default + Clone + Send + Sync + 'static, -{ - fn default() -> Self { - Self::new(T::default()) - } -} +#[derive(Debug, Default)] +pub struct State(T); impl State where @@ -84,16 +65,14 @@ where { /// Constructs a new State handler from any `Clone` + `Send` + `Sync` + /// `'static` - #[allow(clippy::missing_const_for_fn)] - pub fn new(t: T) -> Self { + pub const fn new(t: T) -> Self { Self(t) } } /// Constructs a new [`State`] handler from any Clone + Send + Sync + /// 'static. Alias for [`State::new`] -#[allow(clippy::missing_const_for_fn)] -pub fn state(t: T) -> State { +pub const fn state(t: T) -> State { State::new(t) } diff --git a/trillium/tests/init.rs b/trillium/tests/init.rs new file mode 100644 index 0000000000..66216c1dca --- /dev/null +++ b/trillium/tests/init.rs @@ -0,0 +1,56 @@ +use std::io; + +use test_harness::test; +use trillium::Handler; +use trillium_client::Client; +use trillium_http::ServerConfig; +use trillium_testing::{harness, ServerConnector, TestResult}; + +async fn test_client(mut handler: impl Handler) -> Client { + let mut info = ServerConfig::default().into(); + handler.init(&mut info).await; + let connector = ServerConnector::new(handler).with_server_config(info.into()); + Client::new(connector).with_base("http://test.host") +} + +#[test(harness)] +async fn init_doctest() -> TestResult { + use trillium::{Conn, Init}; + + #[derive(Debug, Clone)] + struct MyDatabaseConnection(&'static str); + impl MyDatabaseConnection { + async fn connect(uri: &'static str) -> io::Result { + Ok(Self(uri)) + } + async fn query(&self, query: &str) -> String { + format!("you queried `{}` against {}", query, &self.0) + } + } + + let client = test_client(( + Init::new(|info| async move { + let db = MyDatabaseConnection::connect("mydatabase://...") + .await + .unwrap(); + info.with_state(db) + }), + |conn: Conn| async move { + let Some(db) = conn.shared_state::() else { + return conn.with_status(500); + }; + let response = db.query("select * from users limit 1").await; + conn.ok(response) + }, + )) + .await; + + let mut conn = client.get("/").await?; + + assert_eq!( + conn.response_body().read_string().await?, + "you queried `select * from users limit 1` against mydatabase://..." + ); + + Ok(()) +} diff --git a/trillium/tests/liveness.rs b/trillium/tests/liveness.rs index d6fa2b372e..864a899aa5 100644 --- a/trillium/tests/liveness.rs +++ b/trillium/tests/liveness.rs @@ -23,11 +23,8 @@ async fn infinitely_pending_task() -> TestResult { }); let info = handle.info().await; - - let url = format!("http://{}", info.listener_description()) - .parse() - .unwrap(); - let mut client = connector.connect(&url).await?; + let url = info.url().unwrap(); + let mut client = connector.connect(url).await?; client .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") @@ -48,7 +45,6 @@ async fn infinitely_pending_task() -> TestResult { #[test(harness)] async fn is_disconnected() -> TestResult { - let _ = env_logger::builder().is_test(true).try_init(); let connector = ArcedConnector::new(client_config()); let (delay_sender, delay_receiver) = async_channel::unbounded(); let (disconnected_sender, disconnected_receiver) = async_channel::unbounded(); @@ -71,10 +67,8 @@ async fn is_disconnected() -> TestResult { let info = handle.info().await; let runtime = handle.runtime(); - let url = format!("http://{}", info.listener_description()) - .parse() - .unwrap(); - let mut client = connector.connect(&url).await?; + let url = info.url().unwrap(); + let mut client = connector.connect(url).await?; client .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") @@ -88,7 +82,7 @@ async fn is_disconnected() -> TestResult { assert!(s.starts_with("HTTP/1.1 200 OK\r\n")); client.close().await?; - let mut client = connector.connect(&url).await?; + let mut client = connector.connect(url).await?; client .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") .await?; diff --git a/websockets/src/bidirectional_stream.rs b/websockets/src/bidirectional_stream.rs index 1ec60c0b8d..9ca77f6b57 100644 --- a/websockets/src/bidirectional_stream.rs +++ b/websockets/src/bidirectional_stream.rs @@ -23,7 +23,7 @@ impl Debug for BidirectionalStream { None => "None", }, ) - .field("outbound", &"..") + .field("outbound", &format_args!("..")) .finish() } } diff --git a/websockets/src/websocket_connection.rs b/websockets/src/websocket_connection.rs index 472fb9f35c..b141532859 100644 --- a/websockets/src/websocket_connection.rs +++ b/websockets/src/websocket_connection.rs @@ -10,11 +10,12 @@ use futures_util::{ use std::{ net::IpAddr, pin::Pin, + sync::Arc, task::{Context, Poll}, }; use swansong::{Interrupt, Swansong}; use trillium::{Headers, Method, TypeSet, Upgrade}; -use trillium_http::{transport::BoxedTransport, type_set::entry::Entry}; +use trillium_http::{transport::BoxedTransport, type_set::entry::Entry, ServerConfig}; /** A struct that represents an specific websocket connection. @@ -34,7 +35,7 @@ pub struct WebSocketConn { method: Method, state: TypeSet, peer_ip: Option, - swansong: Swansong, + server_config: Arc, sink: SplitSink, stream: Option, } @@ -77,7 +78,7 @@ impl WebSocketConn { state, buffer, transport, - swansong, + server_config, peer_ip, .. } = upgrade; @@ -90,7 +91,7 @@ impl WebSocketConn { let (sink, stream) = wss.split(); let stream = Some(WStream { - stream: swansong.interrupt(stream), + stream: server_config.swansong().interrupt(stream), }); Self { @@ -101,13 +102,13 @@ impl WebSocketConn { peer_ip, sink, stream, - swansong, + server_config, } } /// retrieve a clone of the server's [`Swansong`] pub fn swansong(&self) -> Swansong { - self.swansong.clone() + self.server_config.swansong().clone() } /// close the websocket connection gracefully @@ -178,6 +179,12 @@ impl WebSocketConn { self.state.insert(state) } + /// Returns an [`Entry`] for the state typeset that can be used with functions like + /// [`Entry::or_insert`], [`Entry::or_insert_with`], [`Entry::and_modify`], and others. + pub fn state_entry(&mut self) -> Entry<'_, T> { + self.state.entry() + } + /** take some type T out of the state set that has been accumulated by trillium handlers run on the [`trillium::Conn`] @@ -188,12 +195,6 @@ impl WebSocketConn { self.state.take() } - /// Returns an [`Entry`] for the state typeset that can be used with functions like - /// [`Entry::or_insert`], [`Entry::or_insert_with`], [`Entry::and_modify`], and others. - pub fn state_entry(&mut self) -> Entry<'_, T> { - self.state.entry() - } - /// take the inbound Message stream from this conn pub fn take_inbound_stream(&mut self) -> Option> { self.stream.take()