diff --git a/mountpoint-s3-client/CHANGELOG.md b/mountpoint-s3-client/CHANGELOG.md index 254f0ee15..50ee788b8 100644 --- a/mountpoint-s3-client/CHANGELOG.md +++ b/mountpoint-s3-client/CHANGELOG.md @@ -8,6 +8,9 @@ * `get_object` method now requires a `GetObjectParams` parameter. Two of the existing parameters, `range` and `if_match` have been moved to `GetObjectParams`. ([#1121](https://github.com/awslabs/mountpoint-s3/pull/1121)) +* `increment_read_window` and `read_window_end_offset` methods have been removed from `GetObjectResponse`. + `ClientBackpressureHandle` can be used to interact with flow-control window instead, it can be retrieved from `backpressure_handle` method. + ([#1200](https://github.com/awslabs/mountpoint-s3/pull/1200)) * `head_object` method now requires a `HeadObjectParams` parameter. The structure itself is not required to specify anything to achieve the existing behavior. ([#1083](https://github.com/awslabs/mountpoint-s3/pull/1083)) diff --git a/mountpoint-s3-client/src/failure_client.rs b/mountpoint-s3-client/src/failure_client.rs index efe04f226..f873ab4df 100644 --- a/mountpoint-s3-client/src/failure_client.rs +++ b/mountpoint-s3-client/src/failure_client.rs @@ -217,8 +217,13 @@ pub struct FailureGetResponse { impl GetObjectResponse for FailureGetResponse { + type BackpressureHandle = <::GetObjectResponse as GetObjectResponse>::BackpressureHandle; type ClientError = Client::ClientError; + fn backpressure_handle(&mut self) -> Option<&mut Self::BackpressureHandle> { + self.request.backpressure_handle() + } + fn get_object_metadata(&self) -> ObjectMetadata { self.request.get_object_metadata() } @@ -226,16 +231,6 @@ impl GetObjectRespon fn get_object_checksum(&self) -> Result { self.request.get_object_checksum() } - - fn increment_read_window(self: Pin<&mut Self>, len: usize) { - let this = self.project(); - this.request.increment_read_window(len); - } - - fn read_window_end_offset(self: Pin<&Self>) -> u64 { - let this = self.project_ref(); - this.request.read_window_end_offset() - } } impl Stream for FailureGetResponse { diff --git a/mountpoint-s3-client/src/lib.rs b/mountpoint-s3-client/src/lib.rs index dcf107e16..68df66785 100644 --- a/mountpoint-s3-client/src/lib.rs +++ b/mountpoint-s3-client/src/lib.rs @@ -72,11 +72,11 @@ pub mod config { /// Types used by all object clients pub mod types { pub use super::object_client::{ - Checksum, ChecksumAlgorithm, ChecksumMode, CopyObjectParams, CopyObjectResult, DeleteObjectResult, ETag, - GetBodyPart, GetObjectAttributesParts, GetObjectAttributesResult, GetObjectParams, GetObjectResponse, - HeadObjectParams, HeadObjectResult, ListObjectsResult, ObjectAttribute, ObjectClientResult, ObjectInfo, - ObjectPart, PutObjectParams, PutObjectResult, PutObjectSingleParams, PutObjectTrailingChecksums, RestoreStatus, - UploadChecksum, UploadReview, UploadReviewPart, + Checksum, ChecksumAlgorithm, ChecksumMode, ClientBackpressureHandle, CopyObjectParams, CopyObjectResult, + DeleteObjectResult, ETag, GetBodyPart, GetObjectAttributesParts, GetObjectAttributesResult, GetObjectParams, + GetObjectResponse, HeadObjectParams, HeadObjectResult, ListObjectsResult, ObjectAttribute, ObjectClientResult, + ObjectInfo, ObjectPart, PutObjectParams, PutObjectResult, PutObjectSingleParams, PutObjectTrailingChecksums, + RestoreStatus, UploadChecksum, UploadReview, UploadReviewPart, }; } diff --git a/mountpoint-s3-client/src/mock_client.rs b/mountpoint-s3-client/src/mock_client.rs index 2a9460c8b..2d3dccf3a 100644 --- a/mountpoint-s3-client/src/mock_client.rs +++ b/mountpoint-s3-client/src/mock_client.rs @@ -6,6 +6,7 @@ use std::borrow::Cow; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt::Write; use std::pin::Pin; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, RwLock}; use std::task::{Context, Poll}; use std::time::{Duration, SystemTime}; @@ -26,13 +27,13 @@ use crate::checksums::{ }; use crate::error_metadata::{ClientErrorMetadata, ProvideErrorMetadata}; use crate::object_client::{ - Checksum, ChecksumAlgorithm, ChecksumMode, CopyObjectError, CopyObjectParams, CopyObjectResult, DeleteObjectError, - DeleteObjectResult, ETag, GetBodyPart, GetObjectAttributesError, GetObjectAttributesParts, - GetObjectAttributesResult, GetObjectError, GetObjectParams, GetObjectResponse, HeadObjectError, HeadObjectParams, - HeadObjectResult, ListObjectsError, ListObjectsResult, ObjectAttribute, ObjectChecksumError, ObjectClient, - ObjectClientError, ObjectClientResult, ObjectInfo, ObjectMetadata, ObjectPart, PutObjectError, PutObjectParams, - PutObjectRequest, PutObjectResult, PutObjectSingleParams, PutObjectTrailingChecksums, RestoreStatus, - UploadChecksum, UploadReview, UploadReviewPart, + Checksum, ChecksumAlgorithm, ChecksumMode, ClientBackpressureHandle, CopyObjectError, CopyObjectParams, + CopyObjectResult, DeleteObjectError, DeleteObjectResult, ETag, GetBodyPart, GetObjectAttributesError, + GetObjectAttributesParts, GetObjectAttributesResult, GetObjectError, GetObjectParams, GetObjectResponse, + HeadObjectError, HeadObjectParams, HeadObjectResult, ListObjectsError, ListObjectsResult, ObjectAttribute, + ObjectChecksumError, ObjectClient, ObjectClientError, ObjectClientResult, ObjectInfo, ObjectMetadata, ObjectPart, + PutObjectError, PutObjectParams, PutObjectRequest, PutObjectResult, PutObjectSingleParams, + PutObjectTrailingChecksums, RestoreStatus, UploadChecksum, UploadReview, UploadReviewPart, }; mod leaky_bucket; @@ -668,6 +669,25 @@ fn validate_checksum( } Ok(provided_checksum) } +#[derive(Clone, Debug)] +pub struct MockBackpressureHandle { + read_window_end_offset: Arc, +} + +impl ClientBackpressureHandle for MockBackpressureHandle { + fn increment_read_window(&mut self, len: usize) { + self.read_window_end_offset.fetch_add(len as u64, Ordering::SeqCst); + } + + fn ensure_read_window(&mut self, desired_end_offset: u64) { + let diff = desired_end_offset.saturating_sub(self.read_window_end_offset()) as usize; + self.increment_read_window(diff); + } + + fn read_window_end_offset(&self) -> u64 { + self.read_window_end_offset.load(Ordering::SeqCst) + } +} #[derive(Debug)] pub struct MockGetObjectResponse { @@ -675,8 +695,7 @@ pub struct MockGetObjectResponse { next_offset: u64, length: usize, part_size: usize, - enable_backpressure: bool, - read_window_end_offset: u64, + backpressure_handle: Option, } impl MockGetObjectResponse { @@ -696,8 +715,13 @@ impl MockGetObjectResponse { #[cfg_attr(not(docsrs), async_trait)] impl GetObjectResponse for MockGetObjectResponse { + type BackpressureHandle = MockBackpressureHandle; type ClientError = MockClientError; + fn backpressure_handle(&mut self) -> Option<&mut Self::BackpressureHandle> { + self.backpressure_handle.as_mut() + } + fn get_object_metadata(&self) -> ObjectMetadata { self.object.object_metadata.clone() } @@ -705,14 +729,6 @@ impl GetObjectResponse for MockGetObjectResponse { fn get_object_checksum(&self) -> Result { Ok(self.object.checksum.clone()) } - - fn increment_read_window(mut self: Pin<&mut Self>, len: usize) { - self.read_window_end_offset += len as u64; - } - - fn read_window_end_offset(self: Pin<&Self>) -> u64 { - self.read_window_end_offset - } } impl Stream for MockGetObjectResponse { @@ -726,10 +742,12 @@ impl Stream for MockGetObjectResponse { let next_read_size = self.part_size.min(self.length); // Simulate backpressure mechanism - if self.enable_backpressure && self.next_offset >= self.read_window_end_offset { - return Poll::Ready(Some(Err(ObjectClientError::ClientError(MockClientError( - "empty read window".into(), - ))))); + if let Some(handle) = &self.backpressure_handle { + if self.next_offset >= handle.read_window_end_offset() { + return Poll::Ready(Some(Err(ObjectClientError::ClientError(MockClientError( + "empty read window".into(), + ))))); + } } let next_part = self.object.read(self.next_offset, next_read_size); @@ -855,13 +873,20 @@ impl ObjectClient for MockClient { (0, object.len()) }; + let backpressure_handle = if self.config.enable_backpressure { + let read_window_end_offset = Arc::new(AtomicU64::new( + next_offset + self.config.initial_read_window_size as u64, + )); + Some(MockBackpressureHandle { read_window_end_offset }) + } else { + None + }; Ok(MockGetObjectResponse { object: object.clone(), next_offset, length, part_size: self.config.part_size, - enable_backpressure: self.config.enable_backpressure, - read_window_end_offset: next_offset + self.config.initial_read_window_size as u64, + backpressure_handle, }) } else { Err(ObjectClientError::ServiceError(GetObjectError::NoSuchKey)) @@ -1199,7 +1224,7 @@ enum MockObjectParts { #[cfg(test)] mod tests { - use futures::{pin_mut, StreamExt}; + use futures::StreamExt; use rand::{Rng, RngCore, SeedableRng}; use rand_chacha::ChaChaRng; use std::ops::Range; @@ -1295,11 +1320,14 @@ mod tests { rng.fill_bytes(&mut body); client.add_object(key, MockObject::from_bytes(&body, ETag::for_tests())); - let get_request = client + let mut get_request = client .get_object("test_bucket", key, &GetObjectParams::new().range(range.clone())) .await .expect("should not fail"); - pin_mut!(get_request); + let mut backpressure_handle = get_request + .backpressure_handle() + .cloned() + .expect("should be able to get a backpressure handle"); let mut accum = vec![]; let mut next_offset = range.as_ref().map(|r| r.start).unwrap_or(0); @@ -1309,10 +1337,8 @@ mod tests { next_offset += body.len() as u64; accum.extend_from_slice(&body[..]); - while next_offset >= get_request.as_ref().read_window_end_offset() { - get_request - .as_mut() - .increment_read_window(backpressure_read_window_size); + while next_offset >= backpressure_handle.read_window_end_offset() { + backpressure_handle.increment_read_window(backpressure_read_window_size); } } let expected_range = range.unwrap_or(0..size as u64); diff --git a/mountpoint-s3-client/src/mock_client/throughput_client.rs b/mountpoint-s3-client/src/mock_client/throughput_client.rs index 953a19565..8b7af9d82 100644 --- a/mountpoint-s3-client/src/mock_client/throughput_client.rs +++ b/mountpoint-s3-client/src/mock_client/throughput_client.rs @@ -20,6 +20,8 @@ use crate::object_client::{ PutObjectResult, PutObjectSingleParams, }; +use super::MockBackpressureHandle; + /// A [MockClient] that rate limits overall download throughput to simulate a target network /// performance without the jitter or service latency of targeting a real service. Note that while /// the rate limit is shared by all downloading streams, there is no fairness, so some streams can @@ -60,16 +62,21 @@ impl ThroughputMockClient { } #[pin_project] -pub struct ThroughputGetObjectRequest { +pub struct ThroughputGetObjectResponse { #[pin] request: MockGetObjectResponse, rate_limiter: LeakyBucket, } #[cfg_attr(not(docsrs), async_trait)] -impl GetObjectResponse for ThroughputGetObjectRequest { +impl GetObjectResponse for ThroughputGetObjectResponse { + type BackpressureHandle = MockBackpressureHandle; type ClientError = MockClientError; + fn backpressure_handle(&mut self) -> Option<&mut Self::BackpressureHandle> { + self.request.backpressure_handle() + } + fn get_object_metadata(&self) -> ObjectMetadata { self.request.object.object_metadata.clone() } @@ -77,19 +84,9 @@ impl GetObjectResponse for ThroughputGetObjectRequest { fn get_object_checksum(&self) -> Result { Ok(self.request.object.checksum.clone()) } - - fn increment_read_window(self: Pin<&mut Self>, len: usize) { - let this = self.project(); - this.request.increment_read_window(len); - } - - fn read_window_end_offset(self: Pin<&Self>) -> u64 { - let this = self.project_ref(); - this.request.read_window_end_offset() - } } -impl Stream for ThroughputGetObjectRequest { +impl Stream for ThroughputGetObjectResponse { type Item = ObjectClientResult; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -107,7 +104,7 @@ impl Stream for ThroughputGetObjectRequest { #[async_trait] impl ObjectClient for ThroughputMockClient { - type GetObjectResponse = ThroughputGetObjectRequest; + type GetObjectResponse = ThroughputGetObjectResponse; type PutObjectRequest = MockPutObjectRequest; type ClientError = MockClientError; @@ -156,7 +153,7 @@ impl ObjectClient for ThroughputMockClient { ) -> ObjectClientResult { let request = self.inner.get_object(bucket, key, params).await?; let rate_limiter = self.rate_limiter.clone(); - Ok(ThroughputGetObjectRequest { request, rate_limiter }) + Ok(ThroughputGetObjectResponse { request, rate_limiter }) } async fn list_objects( diff --git a/mountpoint-s3-client/src/object_client.rs b/mountpoint-s3-client/src/object_client.rs index 592820d05..f23042043 100644 --- a/mountpoint-s3-client/src/object_client.rs +++ b/mountpoint-s3-client/src/object_client.rs @@ -1,6 +1,5 @@ use std::fmt::{self, Debug}; use std::ops::Range; -use std::pin::Pin; use std::time::SystemTime; use async_trait::async_trait; @@ -586,6 +585,29 @@ impl UploadChecksum { } } +/// A handle for controlling backpressure enabled requests. +/// +/// If the client was created with `enable_read_backpressure` set true, +/// each meta request has a flow-control window that shrinks as response +/// body data is downloaded (headers do not affect the size of the window). +/// The client's `initial_read_window` determines the starting size of each meta request's window. +/// If a meta request's flow-control window reaches 0, no further data will be downloaded. +/// If the `initial_read_window` is 0, the request will not start until the window is incremented. +/// Maintain a larger window to keep up a high download throughput, +/// parts cannot download in parallel unless the window is large enough to hold multiple parts. +/// Maintain a smaller window to limit the amount of data buffered in memory. +pub trait ClientBackpressureHandle { + /// Increment the flow-control read window, so that response data continues downloading. + fn increment_read_window(&mut self, len: usize); + + /// Move the upper bound of the read window to the given offset if it's not already there. + fn ensure_read_window(&mut self, desired_end_offset: u64); + + /// Get the upper bound of the read window. When backpressure is enabled, [GetObjectRequest] can + /// return data up to this offset *exclusively*. + fn read_window_end_offset(&self) -> u64; +} + /// A streaming response to a GetObject request. /// /// This struct implements [`futures::Stream`], which you can use to read the body of the object. @@ -595,33 +617,20 @@ impl UploadChecksum { pub trait GetObjectResponse: Stream> + Send + Sync { + type BackpressureHandle: ClientBackpressureHandle + Clone + Send + Sync; type ClientError: std::error::Error + Send + Sync + 'static; + /// Take the backpressure handle from the response. + /// + /// If `enable_read_backpressure` is false this call will return `None`, + /// no backpressure is being applied and data is being downloaded as fast as possible. + fn backpressure_handle(&mut self) -> Option<&mut Self::BackpressureHandle>; + /// Get the object's user defined metadata. fn get_object_metadata(&self) -> ObjectMetadata; /// Get the object's checksum, if uploaded with one fn get_object_checksum(&self) -> Result; - - /// Increment the flow-control window, so that response data continues downloading. - /// - /// If the client was created with `enable_read_backpressure` set true, - /// each meta request has a flow-control window that shrinks as response - /// body data is downloaded (headers do not affect the size of the window). - /// The client's `initial_read_window` determines the starting size of each meta request's window. - /// If a meta request's flow-control window reaches 0, no further data will be downloaded. - /// If the `initial_read_window` is 0, the request will not start until the window is incremented. - /// Maintain a larger window to keep up a high download throughput, - /// parts cannot download in parallel unless the window is large enough to hold multiple parts. - /// Maintain a smaller window to limit the amount of data buffered in memory. - /// - /// If `enable_read_backpressure` is false this call will have no effect, - /// no backpressure is being applied and data is being downloaded as fast as possible. - fn increment_read_window(self: Pin<&mut Self>, len: usize); - - /// Get the upper bound of the current read window. When backpressure is enabled, [GetObjectRequest] can - /// return data up to this offset *exclusively*. - fn read_window_end_offset(self: Pin<&Self>) -> u64; } /// Failures to return object checksum diff --git a/mountpoint-s3-client/src/s3_crt_client/get_object.rs b/mountpoint-s3-client/src/s3_crt_client/get_object.rs index 15d00426a..1efb3a1de 100644 --- a/mountpoint-s3-client/src/s3_crt_client/get_object.rs +++ b/mountpoint-s3-client/src/s3_crt_client/get_object.rs @@ -2,6 +2,8 @@ use std::future::Future; use std::ops::Deref; use std::os::unix::prelude::OsStrExt; use std::pin::Pin; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; use std::task::{Context, Poll}; use async_trait::async_trait; @@ -9,12 +11,13 @@ use futures::channel::mpsc::UnboundedReceiver; use futures::future::FusedFuture; use futures::{select_biased, Stream}; use mountpoint_s3_crt::http::request_response::{Header, Headers}; -use mountpoint_s3_crt::s3::client::MetaRequestResult; +use mountpoint_s3_crt::s3::client::{MetaRequest, MetaRequestResult}; use pin_project::pin_project; use thiserror::Error; use crate::object_client::{ - Checksum, GetBodyPart, GetObjectError, GetObjectParams, ObjectClientError, ObjectClientResult, ObjectMetadata, + Checksum, ClientBackpressureHandle, GetBodyPart, GetObjectError, GetObjectParams, ObjectClientError, + ObjectClientResult, ObjectMetadata, }; use crate::s3_crt_client::{ parse_checksum, GetObjectResponse, S3CrtClient, S3CrtClientInner, S3HttpRequest, S3Operation, S3RequestError, @@ -34,7 +37,6 @@ impl S3CrtClient { ) -> Result> { let requested_checksums = params.checksum_mode.as_ref() == Some(&ChecksumMode::Enabled); let next_offset = params.range.as_ref().map(|r| r.start).unwrap_or(0); - let read_window_end_offset = next_offset + self.inner.initial_read_window_size as u64; let (part_sender, part_receiver) = futures::channel::mpsc::unbounded(); let (headers_sender, mut headers_receiver) = futures::channel::oneshot::channel(); @@ -127,14 +129,23 @@ impl S3CrtClient { // Guaranteed when select_biased! executes the headers branch. assert!(!request.is_terminated()); + let backpressure_handle = if self.inner.enable_backpressure { + let read_window_end_offset = + Arc::new(AtomicU64::new(next_offset + self.inner.initial_read_window_size as u64)); + Some(S3BackpressureHandle { + read_window_end_offset, + meta_request: request.meta_request.clone(), + }) + } else { + None + }; Ok(S3GetObjectResponse { request, part_receiver, requested_checksums, - enable_backpressure: self.inner.enable_backpressure, + backpressure_handle, headers, next_offset, - read_window_end_offset, }) } } @@ -146,6 +157,30 @@ enum ObjectHeadersError { MissingHeaders, } +#[derive(Clone, Debug)] +pub struct S3BackpressureHandle { + /// Upper bound of the current read window. When backpressure is enabled, [S3GetObjectRequest] + /// can return data up to this offset *exclusively*. + read_window_end_offset: Arc, + meta_request: MetaRequest, +} + +impl ClientBackpressureHandle for S3BackpressureHandle { + fn increment_read_window(&mut self, len: usize) { + self.read_window_end_offset.fetch_add(len as u64, Ordering::SeqCst); + self.meta_request.increment_read_window(len as u64); + } + + fn ensure_read_window(&mut self, desired_end_offset: u64) { + let diff = desired_end_offset.saturating_sub(self.read_window_end_offset()) as usize; + self.increment_read_window(diff); + } + + fn read_window_end_offset(&self) -> u64 { + self.read_window_end_offset.load(Ordering::SeqCst) + } +} + /// A streaming response to a GetObject request. /// /// This struct implements [`futures::Stream`], which you can use to read the body of the object. @@ -159,19 +194,21 @@ pub struct S3GetObjectResponse { #[pin] part_receiver: UnboundedReceiver, requested_checksums: bool, - enable_backpressure: bool, + backpressure_handle: Option, headers: Headers, /// Next offset of the data to be polled from [poll_next] next_offset: u64, - /// Upper bound of the current read window. When backpressure is enabled, [S3GetObjectRequest] - /// can return data up to this offset *exclusively*. - read_window_end_offset: u64, } #[cfg_attr(not(docsrs), async_trait)] impl GetObjectResponse for S3GetObjectResponse { + type BackpressureHandle = S3BackpressureHandle; type ClientError = S3RequestError; + fn backpressure_handle(&mut self) -> Option<&mut Self::BackpressureHandle> { + self.backpressure_handle.as_mut() + } + fn get_object_metadata(&self) -> ObjectMetadata { self.headers .iter() @@ -190,15 +227,6 @@ impl GetObjectResponse for S3GetObjectResponse { parse_checksum(&self.headers).map_err(|e| ObjectChecksumError::HeadersError(Box::new(e))) } - - fn increment_read_window(mut self: Pin<&mut Self>, len: usize) { - self.read_window_end_offset += len as u64; - self.request.meta_request.increment_read_window(len as u64); - } - - fn read_window_end_offset(self: Pin<&Self>) -> u64 { - self.read_window_end_offset - } } impl Stream for S3GetObjectResponse { @@ -224,10 +252,12 @@ impl Stream for S3GetObjectResponse { // the next chunk we want to return error instead of keeping the request blocked. // This prevents a risk of deadlock from using the [S3CrtClient], users must implement // their own logic to block the request if they really want to block a [GetObjectRequest]. - if *this.enable_backpressure && this.read_window_end_offset <= this.next_offset { - return Poll::Ready(Some(Err(ObjectClientError::ClientError( - S3RequestError::EmptyReadWindow, - )))); + if let Some(handle) = &this.backpressure_handle { + if *this.next_offset >= handle.read_window_end_offset() { + return Poll::Ready(Some(Err(ObjectClientError::ClientError( + S3RequestError::EmptyReadWindow, + )))); + } } Poll::Pending } diff --git a/mountpoint-s3-client/tests/common/mod.rs b/mountpoint-s3-client/tests/common/mod.rs index 8bf2b8dfb..d1c97dbd5 100644 --- a/mountpoint-s3-client/tests/common/mod.rs +++ b/mountpoint-s3-client/tests/common/mod.rs @@ -10,7 +10,7 @@ use aws_smithy_runtime_api::client::orchestrator::HttpResponse; use bytes::Bytes; use futures::{pin_mut, Stream, StreamExt}; use mountpoint_s3_client::config::{EndpointConfig, S3ClientConfig}; -use mountpoint_s3_client::types::GetObjectResponse; +use mountpoint_s3_client::types::{ClientBackpressureHandle, GetObjectResponse}; use mountpoint_s3_client::S3CrtClient; use mountpoint_s3_crt::common::allocator::Allocator; use mountpoint_s3_crt::common::rust_log_adapter::RustLogAdapter; @@ -221,14 +221,18 @@ pub async fn check_get_result( /// Check the result of a GET against expected bytes. pub async fn check_backpressure_get_result( read_window: usize, - result: impl GetObjectResponse, + mut response: impl GetObjectResponse, range: Option>, expected: &[u8], ) { let mut accum = vec![]; let mut next_offset = range.map(|r| r.start).unwrap_or(0); - pin_mut!(result); - while let Some(r) = result.next().await { + let mut backpressure_handle = response + .backpressure_handle() + .cloned() + .expect("should be able to get a backpressure handle"); + pin_mut!(response); + while let Some(r) = response.next().await { let (offset, body) = r.expect("get_object body part failed"); assert_eq!(offset, next_offset, "wrong body part offset"); next_offset += body.len() as u64; @@ -236,8 +240,8 @@ pub async fn check_backpressure_get_result( // We run out of data to read if read window is smaller than accum length of data, // so we keeping adding window size, otherwise the request will be blocked. - while next_offset >= result.as_ref().read_window_end_offset() { - result.as_mut().increment_read_window(read_window); + while next_offset >= backpressure_handle.read_window_end_offset() { + backpressure_handle.increment_read_window(read_window); } } assert_eq!(&accum[..], expected, "body does not match"); diff --git a/mountpoint-s3-client/tests/get_object.rs b/mountpoint-s3-client/tests/get_object.rs index c6cd95340..a0cea6851 100644 --- a/mountpoint-s3-client/tests/get_object.rs +++ b/mountpoint-s3-client/tests/get_object.rs @@ -14,7 +14,9 @@ use common::*; use futures::pin_mut; use futures::stream::StreamExt; use mountpoint_s3_client::error::{GetObjectError, ObjectClientError}; -use mountpoint_s3_client::types::{Checksum, ChecksumMode, ETag, GetObjectParams, GetObjectResponse}; +use mountpoint_s3_client::types::{ + Checksum, ChecksumMode, ClientBackpressureHandle, ETag, GetObjectParams, GetObjectResponse, +}; use mountpoint_s3_client::{ObjectClient, S3CrtClient, S3RequestError}; use test_case::test_case; @@ -183,8 +185,10 @@ async fn test_mutated_during_get_object_backpressure() { .await .unwrap(); - pin_mut!(get_request); - get_request.as_mut().increment_read_window(part_size); + get_request + .backpressure_handle() + .unwrap() + .increment_read_window(part_size); // Verify that the next part is error let next = get_request.next().await.expect("result should not be empty"); diff --git a/mountpoint-s3/src/data_cache/express_data_cache.rs b/mountpoint-s3/src/data_cache/express_data_cache.rs index 2922348be..cf59359af 100644 --- a/mountpoint-s3/src/data_cache/express_data_cache.rs +++ b/mountpoint-s3/src/data_cache/express_data_cache.rs @@ -9,7 +9,8 @@ use bytes::{Bytes, BytesMut}; use futures::{pin_mut, StreamExt}; use mountpoint_s3_client::error::{GetObjectError, ObjectClientError, PutObjectError}; use mountpoint_s3_client::types::{ - ChecksumMode, GetObjectParams, GetObjectResponse, ObjectClientResult, PutObjectSingleParams, UploadChecksum, + ChecksumMode, ClientBackpressureHandle, GetObjectParams, GetObjectResponse, ObjectClientResult, + PutObjectSingleParams, UploadChecksum, }; use mountpoint_s3_client::ObjectClient; use mountpoint_s3_crt::checksums::crc32c::{self, Crc32c}; @@ -97,6 +98,13 @@ where Ok(()) } + + // Ensure the flow-control window is large enough for reading a block of data if backpressure is enabled. + fn ensure_read_window(&self, backpressure_handle: Option<&mut impl ClientBackpressureHandle>) { + if let Some(backpressure_handle) = backpressure_handle { + backpressure_handle.increment_read_window(self.config.block_size as usize); + } + } } #[async_trait] @@ -125,7 +133,7 @@ where } let object_key = get_s3_key(&self.prefix, cache_key, block_idx); - let result = match self + let mut result = match self .client .get_object( &self.bucket_name, @@ -144,12 +152,13 @@ where return Err(DataCacheError::IoFailure(e.into())); } }; + let mut backpressure_handle = result.backpressure_handle().cloned(); - pin_mut!(result); // Guarantee that the request will start even in case of `initial_read_window == 0`. - result.as_mut().increment_read_window(self.config.block_size as usize); + self.ensure_read_window(backpressure_handle.as_mut()); let mut buffer: Bytes = Bytes::new(); + pin_mut!(result); while let Some(chunk) = result.next().await { match chunk { Ok((offset, body)) => { @@ -168,7 +177,7 @@ where }; // Ensure the flow-control window is large enough. - result.as_mut().increment_read_window(self.config.block_size as usize); + self.ensure_read_window(backpressure_handle.as_mut()); } Err(ObjectClientError::ServiceError(GetObjectError::NoSuchKey)) => { metrics::counter!("express_data_cache.block_hit").increment(0); diff --git a/mountpoint-s3/src/prefetch/part_stream.rs b/mountpoint-s3/src/prefetch/part_stream.rs index d925c3a1c..b4354e648 100644 --- a/mountpoint-s3/src/prefetch/part_stream.rs +++ b/mountpoint-s3/src/prefetch/part_stream.rs @@ -2,7 +2,10 @@ use async_stream::try_stream; use bytes::Bytes; use futures::task::{Spawn, SpawnExt}; use futures::{pin_mut, Stream, StreamExt}; -use mountpoint_s3_client::{types::GetObjectParams, types::GetObjectResponse, ObjectClient}; +use mountpoint_s3_client::{ + types::{ClientBackpressureHandle, GetObjectParams, GetObjectResponse}, + ObjectClient, +}; use std::marker::{Send, Sync}; use std::sync::Arc; use std::{fmt::Debug, ops::Range}; @@ -359,18 +362,18 @@ fn read_from_request<'a, Client: ObjectClient + 'a>( request_range: Range, ) -> impl Stream> + 'a { try_stream! { - let request = client + let mut request = client .get_object(&bucket, id.key(), &GetObjectParams::new().range(Some(request_range.clone())).if_match(Some(id.etag().clone()))) .await .inspect_err(|e| error!(key=id.key(), error=?e, "GetObject request failed")) .map_err(PrefetchReadError::GetRequestFailed)?; - pin_mut!(request); - let read_window_size_diff = backpressure_limiter - .read_window_end_offset() - .saturating_sub(request.as_ref().read_window_end_offset()) as usize; - request.as_mut().increment_read_window(read_window_size_diff); + let mut backpressure_handle = request.backpressure_handle().cloned(); + if let Some(handle) = backpressure_handle.as_mut() { + handle.ensure_read_window(backpressure_limiter.read_window_end_offset()); + } + pin_mut!(request); while let Some(next) = request.next().await { let (offset, body) = next .inspect_err(|e| error!(key=id.key(), error=?e, "GetObject body part failed")) @@ -394,9 +397,10 @@ fn read_from_request<'a, Client: ObjectClient + 'a>( metrics::histogram!("s3.client.read_window_excess_bytes").record(excess_bytes as f64); } // Blocks if read window increment if it's not enough to read the next offset - if let Some(next_read_window_offset) = backpressure_limiter.wait_for_read_window_increment(next_offset).await? { - let diff = next_read_window_offset.saturating_sub(request.as_ref().read_window_end_offset()) as usize; - request.as_mut().increment_read_window(diff); + if let Some(next_read_window_end_offset) = backpressure_limiter.wait_for_read_window_increment(next_offset).await? { + if let Some(handle) = backpressure_handle.as_mut() { + handle.ensure_read_window(next_read_window_end_offset); + } } } trace!("request finished");