Skip to content

Commit

Permalink
feat: Add option to set a request timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
threema-donat committed Apr 26, 2024
1 parent a3860ae commit 7186b08
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 37 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ ring = { version = "0.17", features = ["std"], optional = true }
hyper-rustls = { version = "0.26.0", default-features = false, features = ["http2", "webpki-roots", "ring"] }
rustls-pemfile = "2.1.1"
rustls = "0.22.0"
tokio = { version = "1", features = ["time"] }

[dev-dependencies]
argparse = "0.2"
Expand Down
168 changes: 131 additions & 37 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::error::Error;
use crate::error::Error::ResponseError;
use crate::signer::Signer;
use tokio::time::timeout;

use crate::request::payload::PayloadLike;
use crate::response::Response;
Expand All @@ -20,6 +21,8 @@ use std::io::Read;
use std::time::Duration;
use std::{fmt, io};

const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 20;

type HyperConnector = HttpsConnector<HttpConnector>;

/// The APNs service endpoint to connect.
Expand Down Expand Up @@ -52,23 +55,96 @@ impl fmt::Display for Endpoint {
/// holds the response for handling.
#[derive(Debug, Clone)]
pub struct Client {
endpoint: Endpoint,
signer: Option<Signer>,
options: ConnectionOptions,
http_client: HttpClient<HyperConnector, BoxBody<Bytes, Infallible>>,
}

impl Client {
fn new(connector: HyperConnector, signer: Option<Signer>, endpoint: Endpoint) -> Client {
let mut builder = HttpClient::builder(TokioExecutor::new());
builder.pool_idle_timeout(Some(Duration::from_secs(600)));
builder.http2_only(true);
/// Uses [`Endpoint::Production`] by default.
#[derive(Debug, Clone)]
pub struct ClientOptions {
/// The timeout of the HTTP requests
pub request_timeout_secs: Option<u64>,
/// The timeout for idle sockets being kept alive
pub pool_idle_timeout_secs: Option<u64>,
/// The endpoint where the requests are sent to
pub endpoint: Endpoint,
/// See [`crate::signer::Signer`]
pub signer: Option<Signer>,
}

impl Default for ClientOptions {
fn default() -> Self {
Self {
pool_idle_timeout_secs: Some(600),
request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS),
endpoint: Endpoint::Production,
signer: None,
}
}
}

Client {
http_client: builder.build(connector),
impl ClientOptions {
pub fn new(endpoint: Endpoint) -> Self {
Self {
endpoint,
..Default::default()
}
}

pub fn with_signer(mut self, signer: Signer) -> Self {
self.signer = Some(signer);
self
}

pub fn with_request_timeout(mut self, seconds: u64) -> Self {
self.request_timeout_secs = Some(seconds);
self
}

pub fn with_pool_idle_timeout(mut self, seconds: u64) -> Self {
self.pool_idle_timeout_secs = Some(seconds);
self
}
}

#[derive(Debug, Clone)]
struct ConnectionOptions {
endpoint: Endpoint,
request_timeout: Duration,
signer: Option<Signer>,
}

impl From<ClientOptions> for ConnectionOptions {
fn from(value: ClientOptions) -> Self {
let ClientOptions {
endpoint,
pool_idle_timeout_secs: _,
signer,
request_timeout_secs,
} = value;
let request_timeout = Duration::from_secs(request_timeout_secs.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS));
Self {
endpoint,
request_timeout,
signer,
}
}
}

impl Client {
/// If `options` is not set, a default using [`Endpoint::Production`] will
/// be initialized.
fn new(connector: HyperConnector, options: Option<ClientOptions>) -> Client {
let options = options.unwrap_or_default();
let http_client = HttpClient::builder(TokioExecutor::new())
.pool_idle_timeout(options.pool_idle_timeout_secs.map(Duration::from_secs))
.http2_only(true)
.build(connector);

let options = options.into();

Client { http_client, options }
}

/// Create a connection to APNs using the provider client certificate which
/// you obtain from your [Apple developer
Expand All @@ -89,7 +165,7 @@ impl Client {
};
let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?;

Ok(Self::new(connector, None, endpoint))
Ok(Self::new(connector, Some(ClientOptions::new(endpoint))))
}

/// Create a connection to APNs using the raw PEM-formatted certificate and
Expand All @@ -98,7 +174,7 @@ impl Client {
pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result<Client, Error> {
let connector = client_cert_connector(cert_pem, key_pem)?;

Ok(Self::new(connector, None, endpoint))
Ok(Self::new(connector, Some(ClientOptions::new(endpoint))))
}

/// Create a connection to APNs using system certificates, signing every
Expand All @@ -113,9 +189,16 @@ impl Client {
{
let connector = default_connector();
let signature_ttl = Duration::from_secs(60 * 55);
let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?;
let signer = Some(Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?);

Ok(Self::new(connector, Some(signer), endpoint))
Ok(Self::new(
connector,
Some(ClientOptions {
endpoint,
signer,
..Default::default()
}),
))
}

/// Send a notification payload.
Expand All @@ -126,7 +209,11 @@ impl Client {
let request = self.build_request(payload);
let requesting = self.http_client.request(request);

let response = requesting.await?;
let Ok(response_result) = timeout(self.options.request_timeout, requesting).await else {
return Err(Error::RequestTimeout(self.options.request_timeout.as_secs()));
};

let response = response_result?;

let apns_id = response
.headers()
Expand All @@ -153,7 +240,11 @@ impl Client {
}

fn build_request<T: PayloadLike>(&self, payload: T) -> hyper::Request<BoxBody<Bytes, Infallible>> {
let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token());
let path = format!(
"https://{}/3/device/{}",
self.options.endpoint,
payload.get_device_token()
);

let mut builder = hyper::Request::builder()
.uri(&path)
Expand All @@ -179,7 +270,7 @@ impl Client {
if let Some(apns_topic) = options.apns_topic {
builder = builder.header("apns-topic", apns_topic.as_bytes());
}
if let Some(ref signer) = self.signer {
if let Some(ref signer) = self.options.signer {
let auth = signer
.with_signature(|signature| format!("Bearer {}", signature))
.unwrap();
Expand Down Expand Up @@ -246,7 +337,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_production_request_uri() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let uri = format!("{}", request.uri());

Expand All @@ -257,7 +348,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_sandbox_request_uri() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Sandbox);
let client = Client::new(default_connector(), Some(ClientOptions::new(Endpoint::Sandbox)));
let request = client.build_request(payload);
let uri = format!("{}", request.uri());

Expand All @@ -268,7 +359,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_method() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);

assert_eq!(&Method::POST, request.method());
Expand All @@ -278,7 +369,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_content_type() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);

assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap());
Expand All @@ -288,7 +379,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_content_length() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload.clone());
let payload_json = payload.to_json_string().unwrap();
let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap();
Expand All @@ -300,7 +391,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_authorization_with_no_signer() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);

assert_eq!(None, request.headers().get(AUTHORIZATION));
Expand All @@ -318,7 +409,10 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), Some(signer), Endpoint::Production);
let client = Client::new(
default_connector(),
Some(ClientOptions::new(Endpoint::Production).with_signer(signer)),
);
let request = client.build_request(payload);

assert_ne!(None, request.headers().get(AUTHORIZATION));
Expand All @@ -332,7 +426,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
..Default::default()
};
let payload = builder.build("a_test_id", options);
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_push_type = request.headers().get("apns-push-type").unwrap();

Expand All @@ -343,7 +437,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_with_default_priority() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_priority = request.headers().get("apns-priority");

Expand All @@ -362,7 +456,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_priority = request.headers().get("apns-priority").unwrap();

Expand All @@ -381,7 +475,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_priority = request.headers().get("apns-priority").unwrap();

Expand All @@ -394,7 +488,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_id = request.headers().get("apns-id");

Expand All @@ -413,7 +507,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_id = request.headers().get("apns-id").unwrap();

Expand All @@ -426,7 +520,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_expiration = request.headers().get("apns-expiration");

Expand All @@ -445,7 +539,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_expiration = request.headers().get("apns-expiration").unwrap();

Expand All @@ -458,7 +552,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_collapse_id = request.headers().get("apns-collapse-id");

Expand All @@ -477,7 +571,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap();

Expand All @@ -490,7 +584,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_topic = request.headers().get("apns-topic");

Expand All @@ -509,7 +603,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload);
let apns_topic = request.headers().get("apns-topic").unwrap();

Expand All @@ -520,7 +614,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
async fn test_request_body() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::new(default_connector(), None);
let request = client.build_request(payload.clone());

let body = request.into_body().collect().await.unwrap().to_bytes();
Expand All @@ -538,7 +632,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let cert: Vec<u8> = include_str!("../test_cert/test.crt").bytes().collect();

let c = Client::certificate_parts(&cert, &key, Endpoint::Sandbox)?;
assert!(c.signer.is_none());
assert!(c.options.signer.is_none());
Ok(())
}
}
Loading

0 comments on commit 7186b08

Please sign in to comment.