Skip to content

Commit

Permalink
use a concrete error type in async-h1
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr committed Jan 25, 2021
1 parent d30fa6d commit 695ba39
Show file tree
Hide file tree
Showing 11 changed files with 189 additions and 95 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ log = "0.4.11"
pin-project = "1.0.2"
async-channel = "1.5.1"
async-dup = "1.2.2"
thiserror = "1.0.22"

[dev-dependencies]
pretty_assertions = "0.6.1"
Expand Down
59 changes: 33 additions & 26 deletions src/client/decode.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
use async_std::io::{BufReader, Read};
use async_std::prelude::*;
use http_types::{ensure, ensure_eq, format_err};
use http_types::content::ContentLength;
use http_types::{
headers::{CONTENT_LENGTH, DATE, TRANSFER_ENCODING},
headers::{DATE, TRANSFER_ENCODING},
Body, Response, StatusCode,
};

use std::convert::TryFrom;

use crate::chunked::ChunkedDecoder;
use crate::date::fmt_http_date;
use crate::{chunked::ChunkedDecoder, Error};
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};

const CR: u8 = b'\r';
const LF: u8 = b'\n';

/// Decode an HTTP response on the client.
pub async fn decode<R>(reader: R) -> http_types::Result<Response>
pub async fn decode<R>(reader: R) -> crate::Result<Option<Response>>
where
R: Read + Unpin + Send + Sync + 'static,
{
Expand All @@ -29,13 +29,14 @@ where
loop {
let bytes_read = reader.read_until(LF, &mut buf).await?;
// No more bytes are yielded from the stream.
assert!(bytes_read != 0, "Empty response"); // TODO: ensure?
if bytes_read == 0 {
return Ok(None);
}

// Prevent CWE-400 DDOS with large HTTP Headers.
ensure!(
buf.len() < MAX_HEAD_LENGTH,
"Head byte length should be less than 8kb"
);
if buf.len() >= MAX_HEAD_LENGTH {
return Err(Error::HeadersTooLong);
}

// We've hit the end delimiter of the stream.
let idx = buf.len() - 1;
Expand All @@ -49,17 +50,23 @@ where

// Convert our header buf into an httparse instance, and validate.
let status = httparse_res.parse(&buf)?;
ensure!(!status.is_partial(), "Malformed HTTP head");
if status.is_partial() {
return Err(Error::PartialHead);
}

let code = httparse_res.code;
let code = code.ok_or_else(|| format_err!("No status code found"))?;
let code = httparse_res.code.ok_or(Error::MissingStatusCode)?;

// Convert httparse headers + body into a `http_types::Response` type.
let version = httparse_res.version;
let version = version.ok_or_else(|| format_err!("No version found"))?;
ensure_eq!(version, 1, "Unsupported HTTP version");
let version = httparse_res.version.ok_or(Error::MissingVersion)?;

if version != 1 {
return Err(Error::UnsupportedVersion(version));
}

let status_code =
StatusCode::try_from(code).map_err(|_| Error::UnrecognizedStatusCode(code))?;
let mut res = Response::new(status_code);

let mut res = Response::new(StatusCode::try_from(code)?);
for header in httparse_res.headers.iter() {
res.append_header(header.name, std::str::from_utf8(header.value)?);
}
Expand All @@ -69,13 +76,13 @@ where
res.insert_header(DATE, &format!("date: {}\r\n", date)[..]);
}

let content_length = res.header(CONTENT_LENGTH);
let content_length =
ContentLength::from_headers(&res).map_err(|_| Error::MalformedHeader("content-length"))?;
let transfer_encoding = res.header(TRANSFER_ENCODING);

ensure!(
content_length.is_none() || transfer_encoding.is_none(),
"Unexpected Content-Length header"
);
if content_length.is_some() && transfer_encoding.is_some() {
return Err(Error::UnexpectedHeader("content-length"));
}

if let Some(encoding) = transfer_encoding {
if encoding.last().as_str() == "chunked" {
Expand All @@ -84,16 +91,16 @@ where
res.set_body(Body::from_reader(reader, None));

// Return the response.
return Ok(res);
return Ok(Some(res));
}
}

// Check for Content-Length.
if let Some(len) = content_length {
let len = len.last().as_str().parse::<usize>()?;
res.set_body(Body::from_reader(reader.take(len as u64), Some(len)));
if let Some(content_length) = content_length {
let len = content_length.len();
res.set_body(Body::from_reader(reader.take(len), Some(len as usize)));
}

// Return the response.
Ok(res)
Ok(Some(res))
}
2 changes: 1 addition & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub use decode::decode;
pub use encode::Encoder;

/// Opens an HTTP/1.1 connection to a remote host.
pub async fn connect<RW>(mut stream: RW, req: Request) -> http_types::Result<Response>
pub async fn connect<RW>(mut stream: RW, req: Request) -> crate::Result<Option<Response>>
where
RW: Read + Write + Send + Sync + Unpin + 'static,
{
Expand Down
84 changes: 84 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use std::str::Utf8Error;

use http_types::url;
use thiserror::Error;

/// Concrete errors that occur within async-h1
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum Error {
/// [`std::io::Error`]
#[error(transparent)]
IO(#[from] std::io::Error),

/// [`url::ParseError`]
#[error(transparent)]
Url(#[from] url::ParseError),

/// this error describes a malformed request with a path that does
/// not start with / or http:// or https://
#[error("unexpected uri format")]
UnexpectedURIFormat,

/// this error describes a http 1.1 request that is missing a Host
/// header
#[error("mandatory host header missing")]
HostHeaderMissing,

/// this error describes a request that does not specify a path
#[error("request path missing")]
RequestPathMissing,

/// [`httparse::Error`]
#[error(transparent)]
Httparse(#[from] httparse::Error),

/// an incomplete http head
#[error("partial http head")]
PartialHead,

/// we were unable to parse a header
#[error("malformed http header {0}")]
MalformedHeader(&'static str),

/// async-h1 doesn't speak this http version
/// this error is deprecated
#[error("unsupported http version 1.{0}")]
UnsupportedVersion(u8),

/// we were unable to parse this http method
#[error("unsupported http method {0}")]
UnrecognizedMethod(String),

/// this request did not have a method
#[error("missing method")]
MissingMethod,

/// this request did not have a status code
#[error("missing status code")]
MissingStatusCode,

/// we were unable to parse this http method
#[error("unrecognized http status code {0}")]
UnrecognizedStatusCode(u16),

/// this request did not have a version, but we expect one
/// this error is deprecated
#[error("missing version")]
MissingVersion,

/// we expected utf8, but there was an encoding error
#[error(transparent)]
EncodingError(#[from] Utf8Error),

/// we received a header that does not make sense in context
#[error("unexpected header: {0}")]
UnexpectedHeader(&'static str),

/// for security reasons, we do not allow request headers beyond 8kb.
#[error("Head byte length should be less than 8kb")]
HeadersTooLong,
}

/// this crate's result type
pub type Result<T> = std::result::Result<T, Error>;
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ use async_std::io::Cursor;
use body_encoder::BodyEncoder;
pub use client::connect;
pub use server::{accept, accept_with_opts, ServerOptions};
mod error;
pub use error::{Error, Result};

#[derive(Debug)]
pub(crate) enum EncoderState {
Expand Down
Loading

0 comments on commit 695ba39

Please sign in to comment.