diff --git a/mountpoint-s3-client/CHANGELOG.md b/mountpoint-s3-client/CHANGELOG.md index 3ccd79365..17494b899 100644 --- a/mountpoint-s3-client/CHANGELOG.md +++ b/mountpoint-s3-client/CHANGELOG.md @@ -5,6 +5,9 @@ * Add parameter to request checksum information as part of a `HeadObject` request. If specified, the result should contain the checksum for the object if available in the S3 response. ([#1083](https://github.com/awslabs/mountpoint-s3/pull/1083)) +* Add parameter to request checksum information as part of a `GetObject` request. + If specified, calling `get_object_checksum` on `GetObjectRequest` will return the checksum information. + ([#1123](https://github.com/awslabs/mountpoint-s3/pull/1123)) * Expose checksum algorithm in `ListObjectsResult`'s `ObjectInfo` struct. ([#1086](https://github.com/awslabs/mountpoint-s3/pull/1086), [#1093](https://github.com/awslabs/mountpoint-s3/pull/1093)) diff --git a/mountpoint-s3-client/src/failure_client.rs b/mountpoint-s3-client/src/failure_client.rs index 4db79dbde..90d752e8a 100644 --- a/mountpoint-s3-client/src/failure_client.rs +++ b/mountpoint-s3-client/src/failure_client.rs @@ -14,7 +14,7 @@ use mountpoint_s3_crt::s3::client::BufferPoolUsageStats; use pin_project::pin_project; use crate::object_client::{ - CopyObjectError, CopyObjectParams, CopyObjectResult, DeleteObjectError, DeleteObjectResult, GetBodyPart, + Checksum, CopyObjectError, CopyObjectParams, CopyObjectResult, DeleteObjectError, DeleteObjectResult, GetBodyPart, GetObjectAttributesError, GetObjectAttributesResult, GetObjectError, GetObjectParams, GetObjectRequest, HeadObjectError, HeadObjectParams, HeadObjectResult, ListObjectsError, ListObjectsResult, ObjectAttribute, ObjectClient, ObjectClientError, ObjectClientResult, ObjectMetadata, PutObjectError, PutObjectParams, @@ -223,6 +223,10 @@ impl GetObjectReques self.request.get_object_metadata().await } + async fn get_object_checksum(&self) -> ObjectClientResult { + self.request.get_object_checksum().await + } + fn increment_read_window(self: Pin<&mut Self>, len: usize) { let this = self.project(); this.request.increment_read_window(len); diff --git a/mountpoint-s3-client/src/mock_client.rs b/mountpoint-s3-client/src/mock_client.rs index 7cc5b13cb..c2efad642 100644 --- a/mountpoint-s3-client/src/mock_client.rs +++ b/mountpoint-s3-client/src/mock_client.rs @@ -538,6 +538,10 @@ impl GetObjectRequest for MockGetObjectRequest { Ok(self.object.object_metadata.clone()) } + async fn get_object_checksum(&self) -> ObjectClientResult { + Ok(self.object.checksum.clone()) + } + fn increment_read_window(mut self: Pin<&mut Self>, len: usize) { self.read_window_end_offset += len as u64; } diff --git a/mountpoint-s3-client/src/mock_client/throughput_client.rs b/mountpoint-s3-client/src/mock_client/throughput_client.rs index dc4671503..a64d1b31e 100644 --- a/mountpoint-s3-client/src/mock_client/throughput_client.rs +++ b/mountpoint-s3-client/src/mock_client/throughput_client.rs @@ -13,7 +13,7 @@ use crate::mock_client::{ MockClient, MockClientConfig, MockClientError, MockGetObjectRequest, MockObject, MockPutObjectRequest, }; use crate::object_client::{ - CopyObjectError, CopyObjectParams, CopyObjectResult, DeleteObjectError, DeleteObjectResult, GetBodyPart, + Checksum, CopyObjectError, CopyObjectParams, CopyObjectResult, DeleteObjectError, DeleteObjectResult, GetBodyPart, GetObjectAttributesError, GetObjectAttributesResult, GetObjectError, GetObjectParams, GetObjectRequest, HeadObjectError, HeadObjectParams, HeadObjectResult, ListObjectsError, ListObjectsResult, ObjectAttribute, ObjectClient, ObjectClientResult, ObjectMetadata, PutObjectError, PutObjectParams, PutObjectResult, @@ -74,6 +74,10 @@ impl GetObjectRequest for ThroughputGetObjectRequest { Ok(self.request.object.object_metadata.clone()) } + async fn get_object_checksum(&self) -> ObjectClientResult { + Ok(self.request.object.checksum.clone()) + } + fn increment_read_window(self: Pin<&mut Self>, len: usize) { let this = self.project(); this.request.increment_read_window(len); diff --git a/mountpoint-s3-client/src/object_client.rs b/mountpoint-s3-client/src/object_client.rs index e64e3fd4c..5240ae952 100644 --- a/mountpoint-s3-client/src/object_client.rs +++ b/mountpoint-s3-client/src/object_client.rs @@ -189,6 +189,7 @@ pub enum GetObjectError { pub struct GetObjectParams { pub range: Option>, pub if_match: Option, + pub checksum_mode: Option, } impl GetObjectParams { @@ -208,6 +209,12 @@ impl GetObjectParams { self.if_match = value; self } + + /// Set option to retrieve checksum as part of the GetObject request + pub fn checksum_mode(mut self, value: Option) -> Self { + self.checksum_mode = value; + self + } } /// Result of a [`list_objects`](ObjectClient::list_objects) request @@ -256,7 +263,7 @@ impl HeadObjectParams { /// Enable [ChecksumMode] to retrieve object checksums #[non_exhaustive] -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum ChecksumMode { /// Retrieve checksums Enabled, @@ -562,6 +569,9 @@ pub trait GetObjectRequest: /// when they're read. async fn get_object_metadata(&self) -> ObjectClientResult; + /// Get the object's checksum, if uploaded with one + async fn get_object_checksum(&self) -> ObjectClientResult; + /// Increment the flow-control window, so that response data continues downloading. /// /// If the client was created with `enable_read_backpressure` set true, diff --git a/mountpoint-s3-client/src/s3_crt_client.rs b/mountpoint-s3-client/src/s3_crt_client.rs index a9817f71e..ed2b9f236 100644 --- a/mountpoint-s3-client/src/s3_crt_client.rs +++ b/mountpoint-s3-client/src/s3_crt_client.rs @@ -18,7 +18,7 @@ use mountpoint_s3_crt::auth::signing_config::SigningConfig; use mountpoint_s3_crt::common::allocator::Allocator; use mountpoint_s3_crt::common::string::AwsString; use mountpoint_s3_crt::common::uri::Uri; -use mountpoint_s3_crt::http::request_response::{Header, Headers, Message}; +use mountpoint_s3_crt::http::request_response::{Header, Headers, HeadersError, Message}; use mountpoint_s3_crt::io::channel_bootstrap::{ClientBootstrap, ClientBootstrapOptions}; use mountpoint_s3_crt::io::event_loop::EventLoopGroup; use mountpoint_s3_crt::io::host_resolver::{AddressKinds, HostResolver, HostResolverDefaultOptions}; @@ -1119,6 +1119,21 @@ fn extract_range_header(headers: &Headers) -> Option> { Some(start..end + 1) } +/// Extract the [Checksum] information from headers +fn parse_checksum(headers: &Headers) -> Result { + let checksum_crc32 = headers.get_as_optional_string("x-amz-checksum-crc32")?; + let checksum_crc32c = headers.get_as_optional_string("x-amz-checksum-crc32c")?; + let checksum_sha1 = headers.get_as_optional_string("x-amz-checksum-sha1")?; + let checksum_sha256 = headers.get_as_optional_string("x-amz-checksum-sha256")?; + + Ok(Checksum { + checksum_crc32, + checksum_crc32c, + checksum_sha1, + checksum_sha256, + }) +} + /// Try to parse a modeled error out of a failing meta request fn try_parse_generic_error(request_result: &MetaRequestResult) -> Option { /// Look for a redirect header pointing to a different region for the bucket @@ -1656,4 +1671,22 @@ mod tests { }; assert_eq!(error, error_code.into()); } + + #[test] + fn test_checksum_sha256() { + let mut headers = Headers::new(&Allocator::default()).unwrap(); + let value = "QwzjTQIHJO11oZbfwq1nx3dy0Wk="; + let header = Header::new("x-amz-checksum-sha256", value.to_owned()); + headers.add_header(&header).unwrap(); + + let checksum = parse_checksum(&headers).expect("failed to parse headers"); + assert_eq!(checksum.checksum_crc32, None, "other checksums shouldn't be set"); + assert_eq!(checksum.checksum_crc32c, None, "other checksums shouldn't be set"); + assert_eq!(checksum.checksum_sha1, None, "other checksums shouldn't be set"); + assert_eq!( + checksum.checksum_sha256, + Some(value.to_owned()), + "sha256 header should match" + ); + } } diff --git a/mountpoint-s3-client/src/s3_crt_client/get_object.rs b/mountpoint-s3-client/src/s3_crt_client/get_object.rs index 710e01769..f2be3c0df 100644 --- a/mountpoint-s3-client/src/s3_crt_client/get_object.rs +++ b/mountpoint-s3-client/src/s3_crt_client/get_object.rs @@ -10,23 +10,26 @@ use std::task::{Context, Poll}; use futures::channel::mpsc::UnboundedReceiver; use futures::Stream; use mountpoint_s3_crt::common::error::Error; -use mountpoint_s3_crt::http::request_response::Header; +use mountpoint_s3_crt::http::request_response::{Header, Headers}; use mountpoint_s3_crt::s3::client::MetaRequestResult; use pin_project::pin_project; use thiserror::Error; use crate::object_client::{ - GetBodyPart, GetObjectError, GetObjectParams, ObjectClientError, ObjectClientResult, ObjectMetadata, + Checksum, GetBodyPart, GetObjectError, GetObjectParams, ObjectClientError, ObjectClientResult, ObjectMetadata, }; use crate::s3_crt_client::{ - GetObjectRequest, S3CrtClient, S3CrtClientInner, S3HttpRequest, S3Operation, S3RequestError, + parse_checksum, GetObjectRequest, S3CrtClient, S3CrtClientInner, S3HttpRequest, S3Operation, S3RequestError, }; +use crate::types::ChecksumMode; /// Failures to return object metadata #[derive(Clone, Error, Debug)] -pub enum ObjectMetadataError { - #[error("error occurred fetching object metadata")] - ObjectMetadataError, +pub enum ObjectHeadersError { + #[error("unknown error occurred receiving object headers")] + UnknownError, + #[error("requested object checksums, but did not specify it in the request")] + DidNotRequestChecksums, } impl S3CrtClient { @@ -50,6 +53,14 @@ impl S3CrtClient { .set_header(&Header::new("accept", "*/*")) .map_err(S3RequestError::construction_failure)?; + let requested_checksums = params.checksum_mode.as_ref() == Some(&ChecksumMode::Enabled); + if requested_checksums { + // Add checksum header to receive object checksums. + message + .set_header(&Header::new("x-amz-checksum-mode", "enabled")) + .map_err(S3RequestError::construction_failure)?; + } + if let Some(etag) = params.if_match.as_ref() { // Return the object only if its entity tag (ETag) is matched message @@ -79,35 +90,24 @@ impl S3CrtClient { let mut options = S3CrtClientInner::new_meta_request_options(message, S3Operation::GetObject); options.part_size(self.inner.read_part_size as u64); - let object_metadata = AsyncCell::shared(); + let object_headers = AsyncCell::shared(); - let object_metadata_setter_on_headers = object_metadata.clone(); - let object_metadata_setter_on_finish = object_metadata.clone(); + let object_headers_setter_on_headers = object_headers.clone(); + let object_headers_setter_on_finish = object_headers.clone(); let request = self.inner.make_meta_request_from_options( options, span, |_| (), move |headers, status| { - // Headers can be returned multiple times, but the object metadata doesn't change. + // Headers can be returned multiple times, but the metadata/checksums don't change. // Explicitly ignore the case where we've already set object metadata. - // Only set metadata if we have a 2xx status code. If we only get other status - // codes, then on_finish cancels. + // Only set headers if we have a 2xx status code. If we only get other status codes, + // then on_finish sets an error. if (200..300).contains(&status) { - // This isn't to do with safety, only minor performance gains. - if !object_metadata_setter_on_headers.is_set() { - let object_metadata = headers - .iter() - .filter_map(|(key, value)| { - let metadata_header = key.to_str()?.strip_prefix("x-amz-meta-")?; - let value = value.to_str()?; - Some((metadata_header.to_string(), value.to_string())) - }) - .collect(); - // Don't overwrite if already set. - object_metadata_setter_on_headers.or_set(Ok(object_metadata)); - } + // Don't overwrite if already set - the first headers are fine. + object_headers_setter_on_headers.or_set(Ok(headers.clone())); } }, move |offset, data| { @@ -115,7 +115,7 @@ impl S3CrtClient { }, move |result| { // FIXME - Ideally we'd include a reason why we failed here. - object_metadata_setter_on_finish.or_set(Err(ObjectMetadataError::ObjectMetadataError)); + object_headers_setter_on_finish.or_set(Err(ObjectHeadersError::UnknownError)); if result.is_err() { Err(parse_get_object_error(result).map(ObjectClientError::ServiceError)) } else { @@ -128,8 +128,9 @@ impl S3CrtClient { request, finish_receiver: receiver, finished: false, + requested_checksums, enable_backpressure: self.inner.enable_backpressure, - object_metadata, + headers: object_headers, initial_read_window_empty: self.inner.initial_read_window_size == 0, next_offset, read_window_end_offset, @@ -150,8 +151,9 @@ pub struct S3GetObjectRequest { #[pin] finish_receiver: UnboundedReceiver>, finished: bool, + requested_checksums: bool, enable_backpressure: bool, - object_metadata: Arc>>, + headers: Arc>>, initial_read_window_empty: bool, /// Next offset of the data to be polled from [poll_next] next_offset: u64, @@ -160,22 +162,47 @@ pub struct S3GetObjectRequest { read_window_end_offset: u64, } -#[cfg_attr(not(docsrs), async_trait)] -impl GetObjectRequest for S3GetObjectRequest { - type ClientError = S3RequestError; - - async fn get_object_metadata(&self) -> ObjectClientResult { - match self.object_metadata.try_get() { +impl S3GetObjectRequest { + async fn get_object_headers(&self) -> ObjectClientResult { + match self.headers.try_get() { Some(result) => result, None => { if self.enable_backpressure && self.initial_read_window_empty { return Err(ObjectClientError::ClientError(S3RequestError::EmptyReadWindow)); } - self.object_metadata.get().await + self.headers.get().await } } .map_err(|_| ObjectClientError::ClientError(S3RequestError::RequestCanceled)) } +} + +#[cfg_attr(not(docsrs), async_trait)] +impl GetObjectRequest for S3GetObjectRequest { + type ClientError = S3RequestError; + + async fn get_object_metadata(&self) -> ObjectClientResult { + let headers = self.get_object_headers().await?; + Ok(headers + .iter() + .filter_map(|(key, value)| { + let metadata_header = key.to_str()?.strip_prefix("x-amz-meta-")?; + let value = value.to_str()?; + Some((metadata_header.to_string(), value.to_string())) + }) + .collect()) + } + + async fn get_object_checksum(&self) -> ObjectClientResult { + if !self.requested_checksums { + return Err(ObjectClientError::ClientError(S3RequestError::InternalError(Box::new( + ObjectHeadersError::DidNotRequestChecksums, + )))); + } + + let headers = self.get_object_headers().await?; + parse_checksum(&headers).map_err(|e| ObjectClientError::ClientError(S3RequestError::InternalError(Box::new(e)))) + } fn increment_read_window(mut self: Pin<&mut Self>, len: usize) { self.read_window_end_offset += len as u64; diff --git a/mountpoint-s3-client/src/s3_crt_client/head_object.rs b/mountpoint-s3-client/src/s3_crt_client/head_object.rs index 41bff222e..e95bd8606 100644 --- a/mountpoint-s3-client/src/s3_crt_client/head_object.rs +++ b/mountpoint-s3-client/src/s3_crt_client/head_object.rs @@ -11,9 +11,9 @@ use time::OffsetDateTime; use tracing::error; use crate::object_client::{ - Checksum, HeadObjectError, HeadObjectParams, HeadObjectResult, ObjectClientError, ObjectClientResult, RestoreStatus, + HeadObjectError, HeadObjectParams, HeadObjectResult, ObjectClientError, ObjectClientResult, RestoreStatus, }; -use crate::s3_crt_client::{S3CrtClient, S3Operation, S3RequestError}; +use crate::s3_crt_client::{parse_checksum, S3CrtClient, S3Operation, S3RequestError}; use super::ChecksumMode; @@ -65,20 +65,6 @@ impl HeadObjectResult { Ok(Some(RestoreStatus::Restored { expiry: expiry.into() })) } - fn parse_checksum(headers: &Headers) -> Result { - let checksum_crc32 = headers.get_as_optional_string("x-amz-checksum-crc32")?; - let checksum_crc32c = headers.get_as_optional_string("x-amz-checksum-crc32c")?; - let checksum_sha1 = headers.get_as_optional_string("x-amz-checksum-sha1")?; - let checksum_sha256 = headers.get_as_optional_string("x-amz-checksum-sha256")?; - - Ok(Checksum { - checksum_crc32, - checksum_crc32c, - checksum_sha1, - checksum_sha256, - }) - } - /// Parse from HeadObject headers fn parse_from_hdr(headers: &Headers) -> Result { let last_modified = OffsetDateTime::parse(&headers.get_as_string("Last-Modified")?, &Rfc2822) @@ -88,7 +74,7 @@ impl HeadObjectResult { let etag = headers.get_as_string("Etag")?; let storage_class = headers.get_as_optional_string("x-amz-storage-class")?; let restore_status = Self::parse_restore_status(headers)?; - let checksum = Self::parse_checksum(headers)?; + let checksum = parse_checksum(headers)?; let result = HeadObjectResult { size, last_modified, @@ -236,24 +222,6 @@ mod tests { }; } - #[test] - fn test_checksum_sha256() { - let mut headers = Headers::new(&Allocator::default()).unwrap(); - let value = "QwzjTQIHJO11oZbfwq1nx3dy0Wk="; - let header = Header::new("x-amz-checksum-sha256", value.to_owned()); - headers.add_header(&header).unwrap(); - - let checksum = HeadObjectResult::parse_checksum(&headers).expect("failed to parse headers"); - assert_eq!(checksum.checksum_crc32, None, "other checksums shouldn't be set"); - assert_eq!(checksum.checksum_crc32c, None, "other checksums shouldn't be set"); - assert_eq!(checksum.checksum_sha1, None, "other checksums shouldn't be set"); - assert_eq!( - checksum.checksum_sha256, - Some(value.to_owned()), - "sha256 header should match" - ); - } - #[test] fn test_parse_restore_empty() { let headers = Headers::new(&Allocator::default()).unwrap(); diff --git a/mountpoint-s3-client/tests/get_object.rs b/mountpoint-s3-client/tests/get_object.rs index 360fde1e2..9abfae092 100644 --- a/mountpoint-s3-client/tests/get_object.rs +++ b/mountpoint-s3-client/tests/get_object.rs @@ -8,12 +8,13 @@ use std::option::Option::None; use std::str::FromStr; use aws_sdk_s3::primitives::ByteStream; +use aws_sdk_s3::types::ChecksumAlgorithm; use bytes::Bytes; use common::*; use futures::pin_mut; use futures::stream::StreamExt; use mountpoint_s3_client::error::{GetObjectError, ObjectClientError}; -use mountpoint_s3_client::types::{ETag, GetObjectParams, GetObjectRequest}; +use mountpoint_s3_client::types::{Checksum, ChecksumMode, ETag, GetObjectParams, GetObjectRequest}; use mountpoint_s3_client::{ObjectClient, S3CrtClient, S3RequestError}; use test_case::test_case; @@ -465,3 +466,88 @@ async fn test_get_object_user_metadata_after_stream(size: usize, metadata: HashM .expect("should return metadata"); assert_eq!(actual_metadata, metadata); } + +#[test_case(ChecksumAlgorithm::Crc32)] +#[test_case(ChecksumAlgorithm::Crc32C)] +#[test_case(ChecksumAlgorithm::Sha1)] +#[test_case(ChecksumAlgorithm::Sha256)] +#[tokio::test] +async fn test_get_object_checksum(checksum_algorithm: ChecksumAlgorithm) { + let sdk_client = get_test_sdk_client().await; + let (bucket, prefix) = get_test_bucket_and_prefix("test_get_object_checksum"); + + let key = format!("{prefix}/test"); + let body = vec![0x42; 42]; + let put_object_output = sdk_client + .put_object() + .bucket(&bucket) + .key(&key) + .body(ByteStream::from(body.clone())) + .checksum_algorithm(checksum_algorithm.clone()) + .send() + .await + .unwrap(); + + let client: S3CrtClient = get_test_client(); + + let result = client + .get_object( + &bucket, + &key, + &GetObjectParams::new().checksum_mode(Some(ChecksumMode::Enabled)), + ) + .await + .expect("get_object should succeed"); + + let checksum: Checksum = result.get_object_checksum().await.expect("should return checksum"); + + match checksum_algorithm { + ChecksumAlgorithm::Crc32 => assert_eq!( + checksum.checksum_crc32, + put_object_output.checksum_crc32().map(|s| s.to_string()) + ), + ChecksumAlgorithm::Crc32C => assert_eq!( + checksum.checksum_crc32c, + put_object_output.checksum_crc32_c().map(|s| s.to_string()) + ), + ChecksumAlgorithm::Sha1 => assert_eq!( + checksum.checksum_sha1, + put_object_output.checksum_sha1().map(|s| s.to_string()) + ), + ChecksumAlgorithm::Sha256 => assert_eq!( + checksum.checksum_sha256, + put_object_output.checksum_sha256().map(|s| s.to_string()) + ), + _ => unimplemented!("This algorithm is not supported"), + } +} + +#[tokio::test] +async fn test_get_object_checksum_checksums_disabled() { + let sdk_client = get_test_sdk_client().await; + let (bucket, prefix) = get_test_bucket_and_prefix("test_get_object_checksum"); + + let key = format!("{prefix}/test"); + let body = vec![0x42; 42]; + sdk_client + .put_object() + .bucket(&bucket) + .key(&key) + .body(ByteStream::from(body.clone())) + .checksum_algorithm(ChecksumAlgorithm::Crc32) + .send() + .await + .unwrap(); + + let client: S3CrtClient = get_test_client(); + + let result = client + .get_object(&bucket, &key, &GetObjectParams::new()) + .await + .expect("get_object should succeed"); + + result + .get_object_checksum() + .await + .expect_err("should not return a checksum object as not requested"); +}