diff --git a/src/server/mod.rs b/src/server/mod.rs index 1cfa4e9..71c7787 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -162,6 +162,8 @@ where let bytes_written = io::copy(&mut encoder, &mut self.io).await?; log::trace!("wrote {} response bytes", bytes_written); + async_std::task::sleep(Duration::from_millis(1)).await; + let body_bytes_discarded = io::copy(&mut body, &mut io::sink()).await?; log::trace!( "discarded {} unread request body bytes", diff --git a/tests/accept.rs b/tests/accept.rs index 92283a8..a57ff0c 100644 --- a/tests/accept.rs +++ b/tests/accept.rs @@ -1,7 +1,10 @@ mod test_utils; mod accept { + use std::time::Duration; + use super::test_utils::TestServer; use async_h1::{client::Encoder, server::ConnectionStatus}; + use async_std::future::timeout; use async_std::io::{self, prelude::WriteExt, Cursor}; use http_types::{headers::CONNECTION, Body, Request, Response, Result}; @@ -17,7 +20,7 @@ mod accept { let content_length = 10; let request_str = format!( - "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n", + "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}", content_length, std::str::from_utf8(&vec![b'|'; content_length]).unwrap() ); @@ -33,6 +36,39 @@ mod accept { Ok(()) } + #[async_std::test] + async fn pipelined() -> Result<()> { + let mut server = TestServer::new(|req| async { + let mut response = Response::new(200); + let len = req.len(); + response.set_body(Body::from_reader(req, len)); + Ok(response) + }); + + let content_length = 10; + + let request_str = format!( + "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}", + content_length, + std::str::from_utf8(&vec![b'|'; content_length]).unwrap() + ); + + server.write_all(request_str.as_bytes()).await?; + server.write_all(request_str.as_bytes()).await?; + assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive); + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + + server.close(); + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) + } + #[async_std::test] async fn request_close() -> Result<()> { let mut server = TestServer::new(|_| async { Ok(Response::new(200)) }); @@ -74,7 +110,7 @@ mod accept { let content_length = 10; let request_str = format!( - "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n", + "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}", content_length, std::str::from_utf8(&vec![b'|'; content_length]).unwrap() ); @@ -130,7 +166,7 @@ mod accept { let content_length = 10000; let request_str = format!( - "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n", + "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}", content_length, std::str::from_utf8(&vec![b'|'; content_length]).unwrap() ); diff --git a/tests/continue.rs b/tests/continue.rs index 933fbfe..ad54ea8 100644 --- a/tests/continue.rs +++ b/tests/continue.rs @@ -1,9 +1,12 @@ mod test_utils; +use async_h1::server::ConnectionStatus; +use async_std::future::timeout; +use async_std::io::BufReader; use async_std::{io, prelude::*, task}; -use http_types::Result; +use http_types::{Response, Result}; use std::time::Duration; -use test_utils::TestIO; +use test_utils::{TestIO, TestServer}; const REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\ Host: example.com\r\n\ @@ -52,3 +55,183 @@ async fn test_without_expect_when_not_reading_body() -> Result<()> { Ok(()) } + +#[async_std::test] +async fn test_accept_unread_body() -> Result<()> { + let mut server = TestServer::new(|_| async { Ok(Response::new(200)) }); + + server.write_all(REQUEST_WITH_EXPECT).await?; + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + + server.write_all(REQUEST_WITH_EXPECT).await?; + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + + server.close(); + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) +} + +#[async_std::test] +async fn test_echo_server() -> Result<()> { + let mut server = TestServer::new(|mut req| async move { + let mut resp = Response::new(200); + resp.set_body(req.take_body()); + Ok(resp) + }); + + server.write_all(REQUEST_WITH_EXPECT).await?; + server.write_all(b"0123456789").await?; + assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive); + + task::sleep(SLEEP_DURATION).await; // wait for "continue" to be sent + + server.close(); + + assert!(server + .client + .read + .to_string() + .starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); + + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) +} + +#[async_std::test] +async fn test_delayed_read() -> Result<()> { + let mut server = TestServer::new(|mut req| async move { + let mut body = req.take_body(); + task::spawn(async move { + let mut buf = Vec::new(); + body.read_to_end(&mut buf).await.unwrap(); + }); + Ok(Response::new(200)) + }); + + server.write_all(REQUEST_WITH_EXPECT).await?; + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + server.write_all(b"0123456789").await?; + + server.write_all(REQUEST_WITH_EXPECT).await?; + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + server.write_all(b"0123456789").await?; + + server.close(); + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) +} + +#[async_std::test] +async fn test_accept_fast_unread_sequential_requests() -> Result<()> { + let mut server = TestServer::new(|_| async move { Ok(Response::new(200)) }); + let mut client = server.client.clone(); + + task::spawn(async move { + let mut reader = BufReader::new(client.clone()); + for _ in 0..10 { + let mut buf = String::new(); + client.write_all(REQUEST_WITH_EXPECT).await.unwrap(); + + while !buf.ends_with("\r\n\r\n") { + reader.read_line(&mut buf).await.unwrap(); + } + + assert!(buf.starts_with("HTTP/1.1 200 OK\r\n")); + } + client.close(); + }); + + for _ in 0..10 { + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + } + + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) +} + +#[async_std::test] +async fn test_accept_partial_read_sequential_requests() -> Result<()> { + const LARGE_REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Content-Length: 1000\r\n\ + Expect: 100-continue\r\n\r\n"; + + let mut server = TestServer::new(|mut req| async move { + let mut body = req.take_body(); + let mut buf = [0]; + body.read(&mut buf).await.unwrap(); + Ok(Response::new(200)) + }); + let mut client = server.client.clone(); + + task::spawn(async move { + let mut reader = BufReader::new(client.clone()); + for _ in 0..10 { + let mut buf = String::new(); + client.write_all(LARGE_REQUEST_WITH_EXPECT).await.unwrap(); + + // Wait for body to be requested + while !buf.ends_with("\r\n\r\n") { + reader.read_line(&mut buf).await.unwrap(); + } + assert!(buf.starts_with("HTTP/1.1 100 Continue\r\n")); + + // Write body + for _ in 0..100 { + client.write_all(b"0123456789").await.unwrap(); + } + + // Wait for response + buf.clear(); + while !buf.ends_with("\r\n\r\n") { + reader.read_line(&mut buf).await.unwrap(); + } + + assert!(buf.starts_with("HTTP/1.1 200 OK\r\n")); + } + client.close(); + }); + + for _ in 0..10 { + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + } + + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::Close + ); + + assert!(server.all_read()); + + Ok(()) +} diff --git a/tests/test_utils.rs b/tests/test_utils.rs index 8194590..034d4cd 100644 --- a/tests/test_utils.rs +++ b/tests/test_utils.rs @@ -19,7 +19,7 @@ use async_dup::Arc; pub struct TestServer { server: Server, #[pin] - client: TestIO, + pub(crate) client: TestIO, } impl TestServer