diff --git a/mountpoint-s3-client/src/object_client.rs b/mountpoint-s3-client/src/object_client.rs index d6b501e34..cc9c51ed2 100644 --- a/mountpoint-s3-client/src/object_client.rs +++ b/mountpoint-s3-client/src/object_client.rs @@ -274,6 +274,11 @@ pub struct PutObjectParams { pub trailing_checksums: bool, /// Storage class to be used when creating new S3 object pub storage_class: Option, + /// The server-side encryption algorithm to be used for this object in Amazon S3 (for example, AES256, aws:kms, aws:kms:dsse) + pub server_side_encryption: Option, + /// If `server_side_encryption` has a valid value of aws:kms or aws:kms:dsse, this value may be used to specify AWS KMS key ID to be used + /// when creating new S3 object + pub ssekms_key_id: Option, } impl PutObjectParams { @@ -293,6 +298,18 @@ impl PutObjectParams { self.storage_class = Some(value); self } + + /// Set server-side encryption type. + pub fn server_side_encryption(mut self, value: Option) -> Self { + self.server_side_encryption = value; + self + } + + /// Set KMS key ID to be used for server-side encryption. + pub fn ssekms_key_id(mut self, value: Option) -> Self { + self.ssekms_key_id = value; + self + } } /// Info for the caller to review before an upload completes. diff --git a/mountpoint-s3-client/src/s3_crt_client.rs b/mountpoint-s3-client/src/s3_crt_client.rs index d1216e36c..745bcc878 100644 --- a/mountpoint-s3-client/src/s3_crt_client.rs +++ b/mountpoint-s3-client/src/s3_crt_client.rs @@ -570,7 +570,7 @@ impl S3CrtClientInner { on_error: impl FnOnce(&MetaRequestResult) -> Option + Send + 'static, ) -> Result, E>, S3RequestError> { let options = Self::new_meta_request_options(message, request_type); - self.make_simple_http_request_from_options(options, request_span, on_error) + self.make_simple_http_request_from_options(options, request_span, on_error, |_, _| ()) } /// Make an HTTP request using this S3 client that returns the body on success or invokes the @@ -580,6 +580,7 @@ impl S3CrtClientInner { options: MetaRequestOptions, request_span: Span, on_error: impl FnOnce(&MetaRequestResult) -> Option + Send + 'static, + on_headers: impl FnMut(&Headers, i32) + Send + 'static, ) -> Result, E>, S3RequestError> { // Accumulate the body of the response into this Vec let body: Arc>> = Default::default(); @@ -588,7 +589,7 @@ impl S3CrtClientInner { self.make_meta_request_from_options( options, request_span, - |_, _| (), + on_headers, move |offset, data| { let mut body = body_clone.lock().unwrap(); assert_eq!(offset as usize, body.len()); diff --git a/mountpoint-s3-client/src/s3_crt_client/put_object.rs b/mountpoint-s3-client/src/s3_crt_client/put_object.rs index 7de568e05..0421732df 100644 --- a/mountpoint-s3-client/src/s3_crt_client/put_object.rs +++ b/mountpoint-s3-client/src/s3_crt_client/put_object.rs @@ -4,13 +4,16 @@ use std::time::Instant; use crate::object_client::{ObjectClientResult, PutObjectError, PutObjectParams, PutObjectRequest, PutObjectResult}; use crate::s3_crt_client::{emit_throughput_metric, S3CrtClient, S3RequestError}; use async_trait::async_trait; -use mountpoint_s3_crt::http::request_response::Header; +use mountpoint_s3_crt::http::request_response::{Header, Headers}; use mountpoint_s3_crt::io::async_stream::{self, AsyncStreamWriter}; use mountpoint_s3_crt::s3::client::{ChecksumConfig, MetaRequestType, UploadReview}; use tracing::error; use super::{S3CrtClientInner, S3HttpRequest}; +const SSE_TYPE_HEADER_NAME: &str = "x-amz-server-side-encryption"; +const SSE_KEY_ID_HEADER_NAME: &str = "x-amz-server-side-encryption-aws-kms-key-id"; + impl S3CrtClient { pub(super) async fn put_object( &self, @@ -45,12 +48,29 @@ impl S3CrtClient { .set_header(&Header::new("x-amz-storage-class", storage_class)) .map_err(S3RequestError::construction_failure)?; } - + if let Some(sse) = params.server_side_encryption.as_ref() { + message + .set_header(&Header::new(SSE_TYPE_HEADER_NAME, sse)) + .map_err(S3RequestError::construction_failure)?; + } + if let Some(key_id) = params.ssekms_key_id.as_ref() { + message + .set_header(&Header::new(SSE_KEY_ID_HEADER_NAME, key_id)) + .map_err(S3RequestError::construction_failure)?; + } + // Variable `response_headers` will be accessed from different threads: from CRT thread which executes `on_headers` callback + // and from our thread which executes `review_and_complete`. Callback `on_headers` is guaranteed to finish before this + // variable is accessed in `review_and_complete` (see `S3HttpRequest::poll` implementation). + let response_headers: Arc>> = Default::default(); + let response_headers_writer = response_headers.clone(); + let on_headers = move |headers: &Headers, _: i32| { + *response_headers_writer.lock().unwrap() = Some(headers.clone()); + }; let mut options = S3CrtClientInner::new_meta_request_options(message, MetaRequestType::PutObject); options.on_upload_review(move |review| callback.invoke(review)); let body = self .inner - .make_simple_http_request_from_options(options, span, |_| None)?; + .make_simple_http_request_from_options(options, span, |_| None, on_headers)?; Ok(S3PutObjectRequest { body, @@ -58,6 +78,9 @@ impl S3CrtClient { review_callback, start_time: Instant::now(), total_bytes: 0, + response_headers, + server_side_encryption: params.server_side_encryption.clone(), + ssekms_key_id: params.ssekms_key_id.clone(), }) } } @@ -106,6 +129,35 @@ pub struct S3PutObjectRequest { review_callback: ReviewCallbackBox, start_time: Instant, total_bytes: u64, + /// Headers of the CompleteMultipartUpload response, available after the request was finished + response_headers: Arc>>, + /// Server-side encryption type which is expected to be found in response_headers + server_side_encryption: Option, + /// Server-side encryption KMS key ID which is expected to be found in response_headers + ssekms_key_id: Option, +} + +/// If non empty `server_side_encryption` or `ssekms_key_id` were used, this function checks headers +/// of the CompleteMultipartUpload response to contain the expected values +fn check_response_headers(response_headers: &Headers, expected_sse: Option<&str>, expected_key_id: Option<&str>) { + if let Some(sse_type) = expected_sse { + let actual_header = response_headers.get(SSE_TYPE_HEADER_NAME).ok(); + let actual_value = actual_header.as_ref().and_then(|header| header.value().to_str()); + assert_eq!( + actual_value, + Some(sse_type), + "SSE type provided in CompleteMultipartUpload response does not match the requested value", + ); + } + if let Some(sse_key_id) = expected_key_id { + let actual_header = response_headers.get(SSE_KEY_ID_HEADER_NAME).ok(); + let actual_value = actual_header.as_ref().and_then(|header| header.value().to_str()); + assert_eq!( + actual_value, + Some(sse_key_id), + "SSE KMS key ID provided in CompleteMultipartUpload response does not match the requested value", + ); + } } #[cfg_attr(not(docs_rs), async_trait)] @@ -124,6 +176,9 @@ impl PutObjectRequest for S3PutObjectRequest { self.review_and_complete(|_| true).await } + /// Note: this function will panic if an SSE was requested to be applied to the object + /// and we failed to check that this actually happened. This may be caused by a bug in + /// CRT code or HTTP headers being corrupted in transit between us and the S3 server. async fn review_and_complete( mut self, review_callback: impl FnOnce(UploadReview) -> bool + Send + 'static, @@ -137,11 +192,68 @@ impl PutObjectRequest for S3PutObjectRequest { self.body }; - let result = body.await; + let _ = body.await?; let elapsed = self.start_time.elapsed(); emit_throughput_metric(self.total_bytes, elapsed, "put_object"); - result.map(|_| PutObjectResult {}) + check_response_headers( + self.response_headers + .lock() + .expect("must be able to acquire headers lock") + .as_ref() + .expect("PUT response headers must be available at this point"), + self.server_side_encryption.as_deref(), + self.ssekms_key_id.as_deref(), + ); + Ok(PutObjectResult {}) + } +} + +#[cfg(test)] +mod tests { + use super::{check_response_headers, Header, Headers, SSE_KEY_ID_HEADER_NAME, SSE_TYPE_HEADER_NAME}; + use mountpoint_s3_crt::common::allocator::Allocator; + use test_case::test_case; + + #[test_case(Some("sse:kms"), Some("some_key_alias"))] + #[test_case(Some("sse:kms:dsse"), Some("some_key_alias"))] + #[test_case(Some("sse:kms"), None)] + #[test_case(None, None)] + fn test_check_headers_ok(sse_type: Option<&str>, sse_kms_key_id: Option<&str>) { + let mut headers = Headers::new(&Allocator::default()).unwrap(); + if let Some(sse_type) = sse_type { + let header = Header::new(SSE_TYPE_HEADER_NAME, sse_type); + headers.add_header(&header).unwrap(); + } + if let Some(sse_kms_key_id) = sse_kms_key_id { + let header = Header::new(SSE_KEY_ID_HEADER_NAME, sse_kms_key_id); + headers.add_header(&header).unwrap(); + } + check_response_headers(&headers, sse_type, sse_kms_key_id); + } + + #[test] + #[should_panic( + expected = "SSE type provided in CompleteMultipartUpload response does not match the requested value" + )] + fn test_check_headers_bad_sse_type() { + let mut headers = Headers::new(&Allocator::default()).unwrap(); + let header = Header::new(SSE_TYPE_HEADER_NAME, "wrong"); + headers.add_header(&header).unwrap(); + let header = Header::new(SSE_KEY_ID_HEADER_NAME, "some_key_alias"); + headers.add_header(&header).unwrap(); + check_response_headers(&headers, Some("sse:kms"), Some("some_key_alias")); + } + + #[test] + #[should_panic( + expected = "SSE KMS key ID provided in CompleteMultipartUpload response does not match the requested value" + )] + fn test_check_headers_bad_sse_key() { + let mut headers = Headers::new(&Allocator::default()).unwrap(); + let header = Header::new(SSE_TYPE_HEADER_NAME, "sse:kms"); + headers.add_header(&header).unwrap(); + check_response_headers(&headers, Some("sse:kms"), Some("some_key_alias")); } } diff --git a/mountpoint-s3-client/tests/common/mod.rs b/mountpoint-s3-client/tests/common/mod.rs index 6eeef6a30..768690ee8 100644 --- a/mountpoint-s3-client/tests/common/mod.rs +++ b/mountpoint-s3-client/tests/common/mod.rs @@ -47,6 +47,10 @@ pub fn get_test_bucket() -> String { } } +pub fn get_test_kms_key_id() -> String { + std::env::var("KMS_TEST_KEY_ID").expect("Set KMS_TEST_KEY_ID to run integration tests") +} + pub fn get_test_client() -> S3CrtClient { let endpoint_config = EndpointConfig::new(&get_test_region()); S3CrtClient::new(S3ClientConfig::new().endpoint_config(endpoint_config)).expect("could not create test client") @@ -190,6 +194,7 @@ macro_rules! object_client_test { mod $test_fn_identifier { use super::$test_fn_identifier; use mountpoint_s3_client::mock_client::{MockClient, MockClientConfig}; + use mountpoint_s3_client::types::PutObjectParams; use $crate::{get_test_bucket_and_prefix, get_test_client}; #[tokio::test] @@ -202,7 +207,8 @@ macro_rules! object_client_test { unordered_list_seed: None, }); - $test_fn_identifier(&client, &bucket, &prefix).await; + let key = format!("{prefix}hello"); + $test_fn_identifier(&client, &bucket, &key, PutObjectParams::new()).await; } #[tokio::test] @@ -211,7 +217,8 @@ macro_rules! object_client_test { let client = get_test_client(); - $test_fn_identifier(&client, &bucket, &prefix).await; + let key = format!("{prefix}hello"); + $test_fn_identifier(&client, &bucket, &key, PutObjectParams::new()).await; } } }; diff --git a/mountpoint-s3-client/tests/put_object.rs b/mountpoint-s3-client/tests/put_object.rs index c1f068ca0..417a231e8 100644 --- a/mountpoint-s3-client/tests/put_object.rs +++ b/mountpoint-s3-client/tests/put_object.rs @@ -15,16 +15,14 @@ use test_case::test_case; // Simple test for PUT object. Puts a single, small object as a single part and checks that the // contents are correct with a GET. -async fn test_put_object(client: &impl ObjectClient, bucket: &str, prefix: &str) { +async fn test_put_object(client: &impl ObjectClient, bucket: &str, key: &str, request_params: PutObjectParams) { let mut rng = rand::thread_rng(); - let key = format!("{prefix}hello"); - let mut contents = vec![0u8; 32]; rng.fill(&mut contents[..]); let mut request = client - .put_object(bucket, &key, &Default::default()) + .put_object(bucket, key, &request_params) .await .expect("put_object should succeed"); @@ -32,7 +30,7 @@ async fn test_put_object(client: &impl ObjectClient, bucket: &str, prefix: &str) request.complete().await.unwrap(); let result = client - .get_object(bucket, &key, None, None) + .get_object(bucket, key, None, None) .await .expect("get_object should succeed"); check_get_result(result, None, &contents[..]).await; @@ -42,18 +40,16 @@ object_client_test!(test_put_object); // Simple test for PUT object. Puts a single, empty object and checks that the (empty) // contents are correct with a GET. -async fn test_put_object_empty(client: &impl ObjectClient, bucket: &str, prefix: &str) { - let key = format!("{prefix}hello"); - +async fn test_put_object_empty(client: &impl ObjectClient, bucket: &str, key: &str, request_params: PutObjectParams) { let request = client - .put_object(bucket, &key, &Default::default()) + .put_object(bucket, key, &request_params) .await .expect("put_object should succeed"); request.complete().await.unwrap(); let result = client - .get_object(bucket, &key, None, None) + .get_object(bucket, key, None, None) .await .expect("get_object should succeed"); check_get_result(result, None, &[]).await; @@ -63,16 +59,19 @@ object_client_test!(test_put_object_empty); // Test for multi-part PUT interface. Splits up a small object into a number of pieces, and streams // the pieces to the object client. Checks contents are correct using a GET. -async fn test_put_object_multi_part(client: &impl ObjectClient, bucket: &str, prefix: &str) { +async fn test_put_object_multi_part( + client: &impl ObjectClient, + bucket: &str, + key: &str, + request_params: PutObjectParams, +) { let mut rng = rand::thread_rng(); - let key = format!("{prefix}hello"); - let mut contents = [0u8; 32]; rng.fill(&mut contents[..]); let mut request = client - .put_object(bucket, &key, &Default::default()) + .put_object(bucket, key, &request_params) .await .expect("put_object failed"); @@ -83,7 +82,7 @@ async fn test_put_object_multi_part(client: &impl ObjectClient, bucket: &str, pr request.complete().await.unwrap(); let result = client - .get_object(bucket, &key, None, None) + .get_object(bucket, key, None, None) .await .expect("get_object failed"); check_get_result(result, None, &contents[..]).await; @@ -93,11 +92,9 @@ object_client_test!(test_put_object_multi_part); // Test for multi-part PUT interface. Splits up a large object into a number of pieces, and streams // the pieces to the object client. Checks contents are correct using a GET. -async fn test_put_object_large(client: &impl ObjectClient, bucket: &str, prefix: &str) { +async fn test_put_object_large(client: &impl ObjectClient, bucket: &str, key: &str, request_params: PutObjectParams) { let mut rng = rand::thread_rng(); - let key = format!("{prefix}hello"); - const OBJECT_SIZE: usize = 32 * 1024 * 1024; const CHUNK_SIZE: usize = 1024 * 1024 + 1; @@ -105,7 +102,7 @@ async fn test_put_object_large(client: &impl ObjectClient, bucket: &str, prefix: rng.fill(&mut contents[..]); let mut request = client - .put_object(bucket, &key, &Default::default()) + .put_object(bucket, key, &request_params) .await .expect("put_object failed"); @@ -115,7 +112,7 @@ async fn test_put_object_large(client: &impl ObjectClient, bucket: &str, prefix: request.complete().await.unwrap(); let result = client - .get_object(bucket, &key, None, None) + .get_object(bucket, key, None, None) .await .expect("get_object failed"); check_get_result(result, None, &contents[..]).await; @@ -124,23 +121,21 @@ async fn test_put_object_large(client: &impl ObjectClient, bucket: &str, prefix: object_client_test!(test_put_object_large); // Test for dropped PUT object. Checks that the GET fails. -async fn test_put_object_dropped(client: &impl ObjectClient, bucket: &str, prefix: &str) { +async fn test_put_object_dropped(client: &impl ObjectClient, bucket: &str, key: &str, request_params: PutObjectParams) { let mut rng = rand::thread_rng(); - let key = format!("{prefix}hello"); - let mut contents = vec![0u8; 32]; rng.fill(&mut contents[..]); let mut request = client - .put_object(bucket, &key, &Default::default()) + .put_object(bucket, key, &request_params) .await .expect("put_object should succeed"); request.write(&contents).await.unwrap(); drop(request); // Drop without calling complete(). - let result = check_get_object(client, bucket, &key).await; + let result = check_get_object(client, bucket, key).await; assert!(result.is_err(), "get_object should fail for dropped PUT"); } @@ -330,3 +325,70 @@ async fn test_put_object_storage_class(storage_class: &str) { assert_eq!(storage_class, attributes.storage_class.unwrap().as_str()); } + +#[cfg(not(feature = "s3express_tests"))] +async fn check_sse(bucket: &String, key: &String, expected_sse: Option<&str>, expected_key: &Option) { + let sdk_client = get_test_sdk_client().await; + let mut request = sdk_client.head_object(); + if cfg!(not(feature = "s3express_tests")) { + request = request.bucket(bucket); + } + let head_object_resp: aws_sdk_s3::operation::head_object::HeadObjectOutput = + request.key(key).send().await.expect("head object should succeed"); + let expected_sse = match expected_sse { + None => aws_sdk_s3::types::ServerSideEncryption::Aes256, + Some("aws:kms") => aws_sdk_s3::types::ServerSideEncryption::AwsKms, + Some("aws:kms:dsse") => aws_sdk_s3::types::ServerSideEncryption::AwsKmsDsse, + _ => panic!("unexpected sse type was used in a test"), + }; + let actual_sse = head_object_resp + .server_side_encryption + .expect("SSE field should always have a value for this test"); + assert_eq!(actual_sse, expected_sse, "unexpected sse type"); + if !matches!(expected_sse, aws_sdk_s3::types::ServerSideEncryption::Aes256) { + assert!( + head_object_resp.ssekms_key_id.is_some(), + "must have a key for non-default encryption methods", + ); + } + if expected_key.is_some() { + // do not check the value of AWS managed key + assert_eq!(&head_object_resp.ssekms_key_id, expected_key, "unexpected sse key") + } +} + +#[test_case(Some("aws:kms"), Some(get_test_kms_key_id()))] +#[test_case(Some("aws:kms"), None)] +#[test_case(Some("aws:kms:dsse"), Some(get_test_kms_key_id()))] +#[test_case(Some("aws:kms:dsse"), None)] +#[test_case(None, None)] +#[tokio::test] +#[cfg(not(feature = "s3express_tests"))] +async fn test_put_object_sse(sse_type: Option<&str>, kms_key_id: Option) { + let bucket = get_test_bucket(); + let client_config = S3ClientConfig::new().endpoint_config(EndpointConfig::new(&get_test_region())); + let client = S3CrtClient::new(client_config).expect("could not create test client"); + let request_params = PutObjectParams::new() + .server_side_encryption(sse_type.map(|value| value.to_owned())) + .ssekms_key_id(kms_key_id.to_owned()); + + let prefix = get_unique_test_prefix("test_put_object_sse"); + let key = format!("{prefix}hello"); + test_put_object(&client, &bucket, &key, request_params.clone()).await; + check_sse(&bucket, &key, sse_type, &kms_key_id).await; + + let prefix = get_unique_test_prefix("test_put_object_sse"); + let key = format!("{prefix}hello"); + test_put_object_empty(&client, &bucket, &key, request_params.clone()).await; + check_sse(&bucket, &key, sse_type, &kms_key_id).await; + + let prefix = get_unique_test_prefix("test_put_object_sse"); + let key = format!("{prefix}hello"); + test_put_object_multi_part(&client, &bucket, &key, request_params.clone()).await; + check_sse(&bucket, &key, sse_type, &kms_key_id).await; + + let prefix = get_unique_test_prefix("test_put_object_sse"); + let key = format!("{prefix}hello"); + test_put_object_large(&client, &bucket, &key, request_params.clone()).await; + check_sse(&bucket, &key, sse_type, &kms_key_id).await; +} diff --git a/mountpoint-s3-crt/src/http/request_response.rs b/mountpoint-s3-crt/src/http/request_response.rs index 7a684e512..7d9ac384d 100644 --- a/mountpoint-s3-crt/src/http/request_response.rs +++ b/mountpoint-s3-crt/src/http/request_response.rs @@ -206,6 +206,14 @@ impl Headers { } } +impl Clone for Headers { + fn clone(&self) -> Self { + // SAFETY: `self.inner` is a valid `aws_http_headers`, and on Clone it's safe and required to increment + // the reference count, dropping new Headers object will decrement it + unsafe { Headers::from_crt(self.inner) } + } +} + impl Drop for Headers { fn drop(&mut self) { // SAFETY: `self.inner` is a valid `aws_http_headers`, and on Drop it's safe to decrement