diff --git a/src/endpoints/workerskv/write_bulk.rs b/src/endpoints/workerskv/write_bulk.rs index 5812a82e..23020bda 100644 --- a/src/endpoints/workerskv/write_bulk.rs +++ b/src/endpoints/workerskv/write_bulk.rs @@ -1,5 +1,6 @@ -use crate::framework::endpoint::{Endpoint, Method}; - +use crate::framework::{response::ApiFailure, response::ApiErrors, response::ApiError, endpoint::{Endpoint, Method}}; +use reqwest; +use std::collections::HashMap; /// Write Key-Value Pairs in Bulk /// Writes multiple key-value pairs to Workers KV at once. /// A 404 is returned if a write action is for a namespace ID the account doesn't have. @@ -23,7 +24,37 @@ impl<'a> Endpoint<(), (), Vec> for WriteBulk<'a> { fn body(&self) -> Option> { Some(self.bulk_key_value_pairs.clone()) } - // default content-type is already application/json + fn validate(&self) -> Result<(), ApiFailure> { + if let Some(body) = self.body() { + // this matches the serialization in HttpApiClient + let json = serde_json::to_string(&body).map_err(|e| + ApiFailure::Error( + reqwest::StatusCode::BAD_REQUEST, + ApiErrors { + errors: vec![ApiError { + code: 400, + message: format!("request body is malformed, failed json serialization: {}", e), + other: HashMap::new(), + }], + other: HashMap::new(), + }))?; + + if json.len() >= 100_000_000 { + return Err(ApiFailure::Error( + reqwest::StatusCode::PAYLOAD_TOO_LARGE, + ApiErrors { + errors: vec![ApiError { + code: 413, + message: "request payload too large, must be less than 100MB".to_owned(), + other: HashMap::new(), + }], + other: HashMap::new(), + }, + )); + } + } + Ok(()) + } } #[serde_with::skip_serializing_none] @@ -35,3 +66,54 @@ pub struct KeyValuePair { pub expiration_ttl: Option, pub base64: Option, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn write_bulk_validator_failure() { + let write_bulk_endpoint = WriteBulk{ + account_identifier: "test_account_id", + namespace_identifier: "test_namespace", + bulk_key_value_pairs: vec![ + KeyValuePair { + key: "test".to_owned(), + value: "X".repeat(100_000_000), + expiration: None, + expiration_ttl: None, + base64: None + } + ] + }; + + match write_bulk_endpoint.validate() { + Ok(_) => assert!(false, "payload too large and validator passed incorrectly"), + Err(_) => assert!(true) + } + } + + #[test] + fn write_bulk_validator_success() { + let write_bulk_endpoint = WriteBulk{ + account_identifier: "test_account_id", + namespace_identifier: "test_namespace", + bulk_key_value_pairs: vec![ + KeyValuePair { + key: "test".to_owned(), + // max is 99,999,972 chars for the val + // the other 28 chars are taken by the key, property names, and json formatting chars + value: "x".repeat(99_999_950), + expiration: None, + expiration_ttl: None, + base64: None + } + ] + }; + + match write_bulk_endpoint.validate() { + Ok(_) => assert!(true), + Err(_) => assert!(false, "payload within bounds and validator failed incorrectly") + } + } +} diff --git a/src/framework/endpoint.rs b/src/framework/endpoint.rs index 80712216..7404a836 100644 --- a/src/framework/endpoint.rs +++ b/src/framework/endpoint.rs @@ -1,5 +1,7 @@ use crate::framework::response::ApiResult; +use crate::framework::response::ApiFailure; use crate::framework::Environment; + use serde::Serialize; use url::Url; @@ -11,6 +13,7 @@ pub enum Method { Patch, } + pub trait Endpoint where ResultType: ApiResult, @@ -25,6 +28,10 @@ where fn body(&self) -> Option { None } + /// Some endpoints dont need to validate. That's OK. + fn validate(&self) -> Result<(), ApiFailure> { + Ok(()) + } fn url(&self, environment: &Environment) -> Url { Url::from(environment).join(&self.path()).unwrap() } diff --git a/src/framework/mod.rs b/src/framework/mod.rs index 528494dd..ccd58e37 100644 --- a/src/framework/mod.rs +++ b/src/framework/mod.rs @@ -80,6 +80,8 @@ impl<'a> ApiClient for HttpApiClient { } } + endpoint.validate()?; + // Build the request let mut request = self .http_client @@ -97,7 +99,6 @@ impl<'a> ApiClient for HttpApiClient { request = request.auth(&self.credentials); let response = request.send()?; - map_api_response(response) } }