Skip to content

Commit

Permalink
Refactor ObjectClient.get_object to use an GetObjectParams parame…
Browse files Browse the repository at this point in the history
…ter (#1121)

<!--
The title and description of pull requests will be used when creating a
squash commit to the base branch (usually `main`).
Please keep them both up-to-date as the code change evolves, to ensure
that the commit message is useful for future readers.
-->

## Description of change

Refactor `ObjectClient.get_object` to use an `&GetObjectParams`
parameter.

Migrates the two existing parameters, `range` and `if_match` to
`GetObjectParams` and changes all call sites.

<!--
    Please describe your contribution here.
    What is the change and why are you making it?
-->

Relevant issues: N/A

## Does this change impact existing behavior?

No

<!-- Please confirm there's no breaking change, or call our any behavior
changes you think are necessary. -->

## Does this change need a changelog entry in any of the crates?

Yes. Breaking change in mountpoint-s3-client. 

<!--
    Please confirm yes or no.
    If no, add justification. If unsure, ask a reviewer.

    You can find the changelog for each crate here:
-
https://github.com/awslabs/mountpoint-s3/blob/main/mountpoint-s3/CHANGELOG.md
-
https://github.com/awslabs/mountpoint-s3/blob/main/mountpoint-s3-client/CHANGELOG.md
-
https://github.com/awslabs/mountpoint-s3/blob/main/mountpoint-s3-crt/CHANGELOG.md
-
https://github.com/awslabs/mountpoint-s3/blob/main/mountpoint-s3-crt-sys/CHANGELOG.md
-->

---

By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license and I agree to the terms of
the [Developer Certificate of Origin
(DCO)](https://developercertificate.org/).

---------

Signed-off-by: Simon Beal <[email protected]>
  • Loading branch information
muddyfish authored Nov 8, 2024
1 parent 7d01885 commit 36d386e
Show file tree
Hide file tree
Showing 19 changed files with 181 additions and 114 deletions.
4 changes: 4 additions & 0 deletions mountpoint-s3-client/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
* Both `ObjectInfo` and `ChecksumAlgorithm` structs are now marked `non_exhaustive`, to indicate that new fields may be added in the future.
`ChecksumAlgorithm` no longer implements `Copy`.
([#1086](https://github.com/awslabs/mountpoint-s3/pull/1086))
* `get_object` method now requires a `GetObjectParams` parameter.
Two of the existing parameters, `range` and `if_match` have been moved to `GetObjectParams`.
([#1121](https://github.com/awslabs/mountpoint-s3/pull/1121))


## v0.11.0 (October 17, 2024)

Expand Down
4 changes: 2 additions & 2 deletions mountpoint-s3-client/examples/client_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use futures::StreamExt;
use mountpoint_s3_client::config::{EndpointConfig, S3ClientConfig};
use mountpoint_s3_client::mock_client::throughput_client::ThroughputMockClient;
use mountpoint_s3_client::mock_client::{MockClientConfig, MockObject};
use mountpoint_s3_client::types::ETag;
use mountpoint_s3_client::types::{ETag, GetObjectParams};
use mountpoint_s3_client::{ObjectClient, S3CrtClient};
use mountpoint_s3_crt::common::rust_log_adapter::RustLogAdapter;
use tracing_subscriber::fmt::Subscriber;
Expand Down Expand Up @@ -46,7 +46,7 @@ fn run_benchmark(
scope.spawn(|| {
futures::executor::block_on(async move {
let mut request = client
.get_object(bucket, key, None, None)
.get_object(bucket, key, &GetObjectParams::new())
.await
.expect("couldn't create get request");
let mut request = pin!(request);
Expand Down
3 changes: 2 additions & 1 deletion mountpoint-s3-client/examples/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex};
use clap::{Arg, Command};
use futures::StreamExt;
use mountpoint_s3_client::config::{EndpointConfig, S3ClientConfig};
use mountpoint_s3_client::types::GetObjectParams;
use mountpoint_s3_client::{ObjectClient, S3CrtClient};
use mountpoint_s3_crt::common::rust_log_adapter::RustLogAdapter;
use regex::Regex;
Expand Down Expand Up @@ -58,7 +59,7 @@ fn main() {
let last_offset_clone = Arc::clone(&last_offset);
futures::executor::block_on(async move {
let mut request = client
.get_object(bucket, key, range, None)
.get_object(bucket, key, &GetObjectParams::new().range(range))
.await
.expect("couldn't create get request");
loop {
Expand Down
32 changes: 12 additions & 20 deletions mountpoint-s3-client/src/failure_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

use std::collections::HashMap;
use std::fmt::Debug;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Mutex;
use std::task::{Context, Poll};
Expand All @@ -15,11 +14,11 @@ use mountpoint_s3_crt::s3::client::BufferPoolUsageStats;
use pin_project::pin_project;

use crate::object_client::{
CopyObjectError, CopyObjectParams, CopyObjectResult, DeleteObjectError, DeleteObjectResult, ETag, GetBodyPart,
GetObjectAttributesError, GetObjectAttributesResult, GetObjectError, GetObjectRequest, HeadObjectError,
HeadObjectParams, HeadObjectResult, ListObjectsError, ListObjectsResult, ObjectAttribute, ObjectClient,
ObjectClientError, ObjectClientResult, ObjectMetadata, PutObjectError, PutObjectParams, PutObjectRequest,
PutObjectResult, PutObjectSingleParams, UploadReview,
CopyObjectError, CopyObjectParams, CopyObjectResult, DeleteObjectError, DeleteObjectResult, GetBodyPart,
GetObjectAttributesError, GetObjectAttributesResult, GetObjectError, GetObjectParams, GetObjectRequest,
HeadObjectError, HeadObjectParams, HeadObjectResult, ListObjectsError, ListObjectsResult, ObjectAttribute,
ObjectClient, ObjectClientError, ObjectClientResult, ObjectMetadata, PutObjectError, PutObjectParams,
PutObjectRequest, PutObjectResult, PutObjectSingleParams, UploadReview,
};

// Wrapper for injecting failures into a get stream or a put request
Expand All @@ -36,8 +35,7 @@ pub struct FailureClient<Client: ObjectClient, State, RequestWrapperState> {
&mut State,
&str,
&str,
Option<Range<u64>>,
Option<ETag>,
&GetObjectParams,
) -> Result<
FailureRequestWrapper<Client::ClientError, RequestWrapperState>,
ObjectClientError<GetObjectError, Client::ClientError>,
Expand Down Expand Up @@ -123,17 +121,10 @@ where
&self,
bucket: &str,
key: &str,
range: Option<Range<u64>>,
if_match: Option<ETag>,
params: &GetObjectParams,
) -> ObjectClientResult<Self::GetObjectRequest, GetObjectError, Self::ClientError> {
let wrapper = (self.get_object_cb)(
&mut *self.state.lock().unwrap(),
bucket,
key,
range.clone(),
if_match.clone(),
)?;
let request = self.client.get_object(bucket, key, range, if_match).await?;
let wrapper = (self.get_object_cb)(&mut *self.state.lock().unwrap(), bucket, key, params)?;
let request = self.client.get_object(bucket, key, params).await?;
Ok(FailureGetRequest {
state: wrapper.state,
result_fn: wrapper.result_fn,
Expand Down Expand Up @@ -364,7 +355,7 @@ pub fn countdown_failure_client<Client: ObjectClient>(
FailureClient {
client,
state,
get_object_cb: |state, _bucket, _key, _range, _if_match| {
get_object_cb: |state, _bucket, _key, _get_object_params| {
state.get_count += 1;
let (fail_count, error) = if let Some(result) = state.get_failures.remove(&state.get_count) {
let (fail_count, error) = result?;
Expand Down Expand Up @@ -443,6 +434,7 @@ pub fn countdown_failure_client<Client: ObjectClient>(
mod tests {
use super::*;
use crate::mock_client::{MockClient, MockClientConfig, MockClientError, MockObject};
use crate::types::ETag;
use std::collections::HashSet;

#[tokio::test]
Expand Down Expand Up @@ -486,7 +478,7 @@ mod tests {

let fail_set = HashSet::from([2, 4, 5]);
for i in 1..=6 {
let r = fail_client.get_object(bucket, key, None, None).await;
let r = fail_client.get_object(bucket, key, &GetObjectParams::new()).await;
if fail_set.contains(&i) {
assert!(r.is_err());
} else {
Expand Down
8 changes: 4 additions & 4 deletions mountpoint-s3-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
//!
//! let client = S3CrtClient::new(Default::default()).expect("client construction failed");
//!
//! let response = client.get_object("my-bucket", "my-key", None, None).await.expect("get_object failed");
//! let response = client.get_object("my-bucket", "my-key", &GetObjectParams::new().await.expect("get_object failed"));
//! let body = response.map_ok(|(offset, body)| body.to_vec()).try_concat().await.expect("body streaming failed");
//! # }
//! ```
Expand Down Expand Up @@ -73,9 +73,9 @@ pub mod config {
pub mod types {
pub use super::object_client::{
Checksum, ChecksumAlgorithm, ChecksumMode, CopyObjectParams, CopyObjectResult, DeleteObjectResult, ETag,
GetBodyPart, GetObjectAttributesParts, GetObjectAttributesResult, GetObjectRequest, HeadObjectParams,
HeadObjectResult, ListObjectsResult, ObjectAttribute, ObjectClientResult, ObjectInfo, ObjectPart,
PutObjectParams, PutObjectResult, PutObjectSingleParams, PutObjectTrailingChecksums, RestoreStatus,
GetBodyPart, GetObjectAttributesParts, GetObjectAttributesResult, GetObjectParams, GetObjectRequest,
HeadObjectParams, HeadObjectResult, ListObjectsResult, ObjectAttribute, ObjectClientResult, ObjectInfo,
ObjectPart, PutObjectParams, PutObjectResult, PutObjectSingleParams, PutObjectTrailingChecksums, RestoreStatus,
UploadChecksum, UploadReview, UploadReviewPart,
};
}
Expand Down
60 changes: 36 additions & 24 deletions mountpoint-s3-client/src/mock_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::fmt::Write;
use std::ops::Range;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
Expand All @@ -28,10 +27,11 @@ use crate::error_metadata::{ClientErrorMetadata, ProvideErrorMetadata};
use crate::object_client::{
Checksum, ChecksumAlgorithm, ChecksumMode, CopyObjectError, CopyObjectParams, CopyObjectResult, DeleteObjectError,
DeleteObjectResult, ETag, GetBodyPart, GetObjectAttributesError, GetObjectAttributesParts,
GetObjectAttributesResult, GetObjectError, GetObjectRequest, HeadObjectError, HeadObjectParams, HeadObjectResult,
ListObjectsError, ListObjectsResult, ObjectAttribute, ObjectClient, ObjectClientError, ObjectClientResult,
ObjectInfo, ObjectMetadata, ObjectPart, PutObjectError, PutObjectParams, PutObjectRequest, PutObjectResult,
PutObjectSingleParams, PutObjectTrailingChecksums, RestoreStatus, UploadChecksum, UploadReview, UploadReviewPart,
GetObjectAttributesResult, GetObjectError, GetObjectParams, GetObjectRequest, HeadObjectError, HeadObjectParams,
HeadObjectResult, ListObjectsError, ListObjectsResult, ObjectAttribute, ObjectClient, ObjectClientError,
ObjectClientResult, ObjectInfo, ObjectMetadata, ObjectPart, PutObjectError, PutObjectParams, PutObjectRequest,
PutObjectResult, PutObjectSingleParams, PutObjectTrailingChecksums, RestoreStatus, UploadChecksum, UploadReview,
UploadReviewPart,
};

mod leaky_bucket;
Expand Down Expand Up @@ -660,10 +660,9 @@ impl ObjectClient for MockClient {
&self,
bucket: &str,
key: &str,
range: Option<Range<u64>>,
if_match: Option<ETag>,
params: &GetObjectParams,
) -> ObjectClientResult<Self::GetObjectRequest, GetObjectError, Self::ClientError> {
trace!(bucket, key, ?range, ?if_match, "GetObject");
trace!(bucket, key, ?params.range, ?params.if_match, "GetObject");
self.inc_op_count(Operation::GetObject);

if bucket != self.config.bucket {
Expand All @@ -673,13 +672,13 @@ impl ObjectClient for MockClient {
let objects = self.objects.read().unwrap();

if let Some(object) = objects.get(key) {
if let Some(etag_match) = if_match {
if etag_match != object.etag {
if let Some(etag_match) = params.if_match.as_ref() {
if etag_match != &object.etag {
return Err(ObjectClientError::ServiceError(GetObjectError::PreconditionFailed));
}
}

let (next_offset, length) = if let Some(range) = range {
let (next_offset, length) = if let Some(range) = params.range.as_ref() {
if range.start >= object.len() as u64 || range.end > object.len() as u64 {
return mock_client_error(format!("invalid range, length={}", object.len()));
}
Expand Down Expand Up @@ -1050,6 +1049,7 @@ mod tests {
use futures::{pin_mut, StreamExt};
use rand::{Rng, RngCore, SeedableRng};
use rand_chacha::ChaChaRng;
use std::ops::Range;
use test_case::test_case;

use super::*;
Expand Down Expand Up @@ -1089,7 +1089,7 @@ mod tests {
client.add_object(key, object);

let mut get_request = client
.get_object("test_bucket", key, range.clone(), None)
.get_object("test_bucket", key, &GetObjectParams::new().range(range.clone()))
.await
.expect("should not fail");

Expand Down Expand Up @@ -1143,7 +1143,7 @@ mod tests {
client.add_object(key, MockObject::from_bytes(&body, ETag::for_tests()));

let get_request = client
.get_object("test_bucket", key, range.clone(), None)
.get_object("test_bucket", key, &GetObjectParams::new().range(range.clone()))
.await
.expect("should not fail");
pin_mut!(get_request);
Expand Down Expand Up @@ -1191,33 +1191,45 @@ mod tests {
client.add_object("key1", body[..].into());

assert!(matches!(
client.get_object("wrong_bucket", "key1", None, None).await,
client.get_object("wrong_bucket", "key1", &GetObjectParams::new()).await,
Err(ObjectClientError::ServiceError(GetObjectError::NoSuchBucket))
));

assert!(matches!(
client.get_object("test_bucket", "wrong_key", None, None).await,
client
.get_object("test_bucket", "wrong_key", &GetObjectParams::new())
.await,
Err(ObjectClientError::ServiceError(GetObjectError::NoSuchKey))
));

assert_client_error!(
client.get_object("test_bucket", "key1", Some(0..2001), None).await,
client
.get_object("test_bucket", "key1", &GetObjectParams::new().range(Some(0..2001)))
.await,
"invalid range, length=2000"
);
assert_client_error!(
client.get_object("test_bucket", "key1", Some(2000..2000), None).await,
client
.get_object("test_bucket", "key1", &GetObjectParams::new().range(Some(2000..2000)))
.await,
"invalid range, length=2000"
);
assert_client_error!(
client.get_object("test_bucket", "key1", Some(500..2001), None).await,
client
.get_object("test_bucket", "key1", &GetObjectParams::new().range(Some(500..2001)))
.await,
"invalid range, length=2000"
);
assert_client_error!(
client.get_object("test_bucket", "key1", Some(5000..2001), None).await,
client
.get_object("test_bucket", "key1", &GetObjectParams::new().range(Some(5000..2001)))
.await,
"invalid range, length=2000"
);
assert_client_error!(
client.get_object("test_bucket", "key1", Some(5000..1), None).await,
client
.get_object("test_bucket", "key1", &GetObjectParams::new().range(Some(5000..1)))
.await,
"invalid range, length=2000"
);
}
Expand Down Expand Up @@ -1245,7 +1257,7 @@ mod tests {
client.add_object(key, MockObject::from_bytes(&expected_body, ETag::for_tests()));

let mut get_request = client
.get_object("test_bucket", key, Some(range.clone()), None)
.get_object("test_bucket", key, &GetObjectParams::new().range(Some(range.clone())))
.await
.expect("should not fail");

Expand Down Expand Up @@ -1282,7 +1294,7 @@ mod tests {
.expect("Should not fail");

client
.get_object(bucket, dst_key, None, None)
.get_object(bucket, dst_key, &GetObjectParams::new())
.await
.expect("get_object should succeed");
}
Expand Down Expand Up @@ -1706,7 +1718,7 @@ mod tests {
put_request.complete().await.expect("put_object failed");

let mut get_request = client
.get_object("test_bucket", "key1", None, None)
.get_object("test_bucket", "key1", &GetObjectParams::new())
.await
.expect("get_object failed");

Expand Down Expand Up @@ -1739,7 +1751,7 @@ mod tests {
.expect("put_object failed");

let get_request = client
.get_object("test_bucket", "key1", None, None)
.get_object("test_bucket", "key1", &GetObjectParams::new())
.await
.expect("get_object failed");

Expand Down
21 changes: 12 additions & 9 deletions mountpoint-s3-client/src/mock_client/throughput_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::ops::Range;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
Expand All @@ -14,10 +13,11 @@ use crate::mock_client::{
MockClient, MockClientConfig, MockClientError, MockGetObjectRequest, MockObject, MockPutObjectRequest,
};
use crate::object_client::{
CopyObjectError, CopyObjectParams, CopyObjectResult, DeleteObjectError, DeleteObjectResult, ETag, GetBodyPart,
GetObjectAttributesError, GetObjectAttributesResult, GetObjectError, GetObjectRequest, HeadObjectError,
HeadObjectParams, HeadObjectResult, ListObjectsError, ListObjectsResult, ObjectAttribute, ObjectClient,
ObjectClientResult, ObjectMetadata, PutObjectError, PutObjectParams, PutObjectResult, PutObjectSingleParams,
CopyObjectError, CopyObjectParams, CopyObjectResult, DeleteObjectError, DeleteObjectResult, GetBodyPart,
GetObjectAttributesError, GetObjectAttributesResult, GetObjectError, GetObjectParams, GetObjectRequest,
HeadObjectError, HeadObjectParams, HeadObjectResult, ListObjectsError, ListObjectsResult, ObjectAttribute,
ObjectClient, ObjectClientResult, ObjectMetadata, PutObjectError, PutObjectParams, PutObjectResult,
PutObjectSingleParams,
};

/// A [MockClient] that rate limits overall download throughput to simulate a target network
Expand Down Expand Up @@ -148,10 +148,9 @@ impl ObjectClient for ThroughputMockClient {
&self,
bucket: &str,
key: &str,
range: Option<Range<u64>>,
if_match: Option<ETag>,
params: &GetObjectParams,
) -> ObjectClientResult<Self::GetObjectRequest, GetObjectError, Self::ClientError> {
let request = self.inner.get_object(bucket, key, range, if_match).await?;
let request = self.inner.get_object(bucket, key, params).await?;
let rate_limiter = self.rate_limiter.clone();
Ok(ThroughputGetObjectRequest { request, rate_limiter })
}
Expand Down Expand Up @@ -219,6 +218,7 @@ mod tests {
use futures::StreamExt;

use crate::mock_client::MockObject;
use crate::types::ETag;

use super::*;

Expand All @@ -245,7 +245,10 @@ mod tests {
let start = Instant::now();
let num_bytes = block_on(async move {
let mut num_bytes = 0;
let mut get = client.get_object("test_bucket", "testfile", None, None).await.unwrap();
let mut get = client
.get_object("test_bucket", "testfile", &GetObjectParams::new())
.await
.unwrap();
while let Some(part) = get.next().await {
let (_offset, part) = part.unwrap();
num_bytes += part.len();
Expand Down
Loading

0 comments on commit 36d386e

Please sign in to comment.