From 1ea6143519bed716303b2f98c94125ae8b2e57d1 Mon Sep 17 00:00:00 2001
From: Alessandro Passaro <alexpax@amazon.co.uk>
Date: Thu, 5 Oct 2023 07:47:26 +0000
Subject: [PATCH 1/4] Move ChecksummedBytes into a new module

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>
---
 .../checksummed_bytes.rs => checksums.rs}     | 49 ++++++++++++++++---
 mountpoint-s3/src/data_cache.rs               |  2 +-
 .../src/data_cache/in_memory_data_cache.rs    | 16 +++---
 mountpoint-s3/src/lib.rs                      |  1 +
 mountpoint-s3/src/prefetch.rs                 |  3 +-
 mountpoint-s3/src/prefetch/feed.rs            |  3 +-
 mountpoint-s3/src/prefetch/part.rs            |  2 +-
 mountpoint-s3/src/prefetch/part_queue.rs      |  2 +-
 mountpoint-s3/src/upload.rs                   | 21 +-------
 9 files changed, 58 insertions(+), 41 deletions(-)
 rename mountpoint-s3/src/{prefetch/checksummed_bytes.rs => checksums.rs} (87%)

diff --git a/mountpoint-s3/src/prefetch/checksummed_bytes.rs b/mountpoint-s3/src/checksums.rs
similarity index 87%
rename from mountpoint-s3/src/prefetch/checksummed_bytes.rs
rename to mountpoint-s3/src/checksums.rs
index 4f9592326..4c262b6a4 100644
--- a/mountpoint-s3/src/prefetch/checksummed_bytes.rs
+++ b/mountpoint-s3/src/checksums.rs
@@ -1,9 +1,10 @@
+use std::ops::RangeBounds;
+
 use bytes::{Bytes, BytesMut};
 use mountpoint_s3_crt::checksums::crc32c::{self, Crc32c};
 use thiserror::Error;
 
 /// A `ChecksummedBytes` is a bytes buffer that carries its checksum.
-/// The implementation guarantees that its integrity will be validated when data transformation occurs.
 #[derive(Clone, Debug)]
 pub struct ChecksummedBytes {
     orig_bytes: Bytes,
@@ -22,6 +23,12 @@ impl ChecksummedBytes {
         }
     }
 
+    /// Create [ChecksummedBytes] from [Bytes], calculating its checksum.
+    pub fn from_bytes(bytes: Bytes) -> Self {
+        let checksum = crc32c::checksum(&bytes);
+        Self::new(bytes, checksum)
+    }
+
     /// Convert the `ChecksummedBytes` into `Bytes`, data integrity will be validated before converting.
     ///
     /// Return `IntegrityError` on data corruption.
@@ -119,6 +126,27 @@ impl Default for ChecksummedBytes {
     }
 }
 
+impl From<Bytes> for ChecksummedBytes {
+    fn from(value: Bytes) -> Self {
+        Self::from_bytes(value)
+    }
+}
+
+impl TryFrom<ChecksummedBytes> for Bytes {
+    type Error = IntegrityError;
+
+    fn try_from(value: ChecksummedBytes) -> Result<Self, Self::Error> {
+        value.into_bytes()
+    }
+}
+
+/// Calculates the combined checksum for `AB` where `prefix_crc` is the checksum for `A`,
+/// `suffix_crc` is the checksum for `B`, and `suffix_len` is the length of `B`.
+pub fn combine_checksums(prefix_crc: Crc32c, suffix_crc: Crc32c, suffix_len: usize) -> Crc32c {
+    let combined = ::crc32c::crc32c_combine(prefix_crc.value(), suffix_crc.value(), suffix_len);
+    Crc32c::new(combined)
+}
+
 #[derive(Debug, Error)]
 pub enum IntegrityError {
     #[error("Checksum mismatch. expected: {0:?}, actual: {1:?}")]
@@ -146,19 +174,15 @@ impl PartialEq for ChecksummedBytes {
 
 #[cfg(test)]
 mod tests {
-    use bytes::Bytes;
     use mountpoint_s3_crt::checksums::crc32c;
 
-    use crate::prefetch::checksummed_bytes::IntegrityError;
-
-    use super::ChecksummedBytes;
+    use super::*;
 
     #[test]
     fn test_into_bytes() {
         let bytes = Bytes::from_static(b"some bytes");
         let expected = bytes.clone();
-        let checksum = crc32c::checksum(&bytes);
-        let checksummed_bytes = ChecksummedBytes::new(bytes, checksum);
+        let checksummed_bytes = ChecksummedBytes::from_bytes(bytes);
 
         let actual = checksummed_bytes.into_bytes().unwrap();
         assert_eq!(expected, actual);
@@ -262,4 +286,15 @@ mod tests {
         let result = checksummed_bytes.extend(extend);
         assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
     }
+
+    #[test]
+    fn test_combine_checksums() {
+        let buf: &[u8] = b"123456789";
+        let (buf1, buf2) = buf.split_at(4);
+        let crc = crc32c::checksum(buf);
+        let crc1 = crc32c::checksum(buf1);
+        let crc2 = crc32c::checksum(buf2);
+        let combined = combine_checksums(crc1, crc2, buf2.len());
+        assert_eq!(combined, crc);
+    }
 }
diff --git a/mountpoint-s3/src/data_cache.rs b/mountpoint-s3/src/data_cache.rs
index 8ea85a00a..eaa52f88a 100644
--- a/mountpoint-s3/src/data_cache.rs
+++ b/mountpoint-s3/src/data_cache.rs
@@ -10,7 +10,7 @@ use std::ops::Range;
 
 use thiserror::Error;
 
-pub use crate::prefetch::checksummed_bytes::ChecksummedBytes;
+pub use crate::checksums::ChecksummedBytes;
 
 /// Indexes blocks within a given object.
 pub type BlockIndex = u64;
diff --git a/mountpoint-s3/src/data_cache/in_memory_data_cache.rs b/mountpoint-s3/src/data_cache/in_memory_data_cache.rs
index 6843b44e2..0cb275fde 100644
--- a/mountpoint-s3/src/data_cache/in_memory_data_cache.rs
+++ b/mountpoint-s3/src/data_cache/in_memory_data_cache.rs
@@ -59,20 +59,18 @@ mod tests {
 
     use bytes::Bytes;
 
-    use mountpoint_s3_crt::checksums::crc32c;
-
     type TestCacheKey = String;
 
     #[test]
     fn test_put_get() {
         let data_1 = Bytes::from_static(b"Hello world");
-        let data_1 = ChecksummedBytes::new(data_1.clone(), crc32c::checksum(&data_1));
+        let data_1 = ChecksummedBytes::from_bytes(data_1.clone());
         let data_2 = Bytes::from_static(b"Foo bar");
-        let data_2 = ChecksummedBytes::new(data_2.clone(), crc32c::checksum(&data_2));
+        let data_2 = ChecksummedBytes::from_bytes(data_2.clone());
         let data_3 = Bytes::from_static(b"Baz");
-        let data_3 = ChecksummedBytes::new(data_3.clone(), crc32c::checksum(&data_3));
+        let data_3 = ChecksummedBytes::from_bytes(data_3.clone());
 
-        let mut cache = InMemoryDataCache::new(8 * 1024 * 1024);
+        let cache = InMemoryDataCache::new(8 * 1024 * 1024);
         let cache_key_1: TestCacheKey = String::from("a");
         let cache_key_2: TestCacheKey = String::from("b");
 
@@ -136,11 +134,11 @@ mod tests {
     #[test]
     fn test_cached_indices() {
         let data_1 = Bytes::from_static(b"Hello world");
-        let data_1 = ChecksummedBytes::new(data_1.clone(), crc32c::checksum(&data_1));
+        let data_1 = ChecksummedBytes::from_bytes(data_1.clone());
         let data_2 = Bytes::from_static(b"Foo bar");
-        let data_2 = ChecksummedBytes::new(data_2.clone(), crc32c::checksum(&data_2));
+        let data_2 = ChecksummedBytes::from_bytes(data_2.clone());
 
-        let mut cache = InMemoryDataCache::new(8 * 1024 * 1024);
+        let cache = InMemoryDataCache::new(8 * 1024 * 1024);
         let cache_key_1: TestCacheKey = String::from("a");
         let cache_key_2: TestCacheKey = String::from("b");
 
diff --git a/mountpoint-s3/src/lib.rs b/mountpoint-s3/src/lib.rs
index d4550f489..a17cb14b3 100644
--- a/mountpoint-s3/src/lib.rs
+++ b/mountpoint-s3/src/lib.rs
@@ -1,3 +1,4 @@
+mod checksums;
 mod data_cache;
 pub mod fs;
 pub mod fuse;
diff --git a/mountpoint-s3/src/prefetch.rs b/mountpoint-s3/src/prefetch.rs
index b4717bc0b..550aa229b 100644
--- a/mountpoint-s3/src/prefetch.rs
+++ b/mountpoint-s3/src/prefetch.rs
@@ -7,7 +7,6 @@
 //! we increase the size of the GetObject requests up to some maximum. If the reader ever makes a
 //! non-sequential read, we abandon the prefetching and start again with the minimum request size.
 
-pub mod checksummed_bytes;
 mod feed;
 mod part;
 mod part_queue;
@@ -26,7 +25,7 @@ use mountpoint_s3_client::ObjectClient;
 use thiserror::Error;
 use tracing::{debug_span, error, trace, Instrument};
 
-use crate::prefetch::checksummed_bytes::{ChecksummedBytes, IntegrityError};
+use crate::checksums::{ChecksummedBytes, IntegrityError};
 use crate::prefetch::feed::{ClientPartFeed, ObjectPartFeed};
 use crate::prefetch::part::Part;
 use crate::prefetch::part_queue::{unbounded_part_queue, PartQueue};
diff --git a/mountpoint-s3/src/prefetch/feed.rs b/mountpoint-s3/src/prefetch/feed.rs
index 0a7cf2819..b6067cafe 100644
--- a/mountpoint-s3/src/prefetch/feed.rs
+++ b/mountpoint-s3/src/prefetch/feed.rs
@@ -11,7 +11,8 @@ use mountpoint_s3_client::{
 use mountpoint_s3_crt::checksums::crc32c;
 use tracing::{error, trace};
 
-use crate::prefetch::{checksummed_bytes::ChecksummedBytes, part::Part, part_queue::PartQueueProducer};
+use crate::checksums::ChecksummedBytes;
+use crate::prefetch::{part::Part, part_queue::PartQueueProducer};
 
 /// A generic interface to retrieve data from objects in a S3-like store.
 #[async_trait]
diff --git a/mountpoint-s3/src/prefetch/part.rs b/mountpoint-s3/src/prefetch/part.rs
index 8a8af6b25..c3a7b34c0 100644
--- a/mountpoint-s3/src/prefetch/part.rs
+++ b/mountpoint-s3/src/prefetch/part.rs
@@ -1,6 +1,6 @@
 use thiserror::Error;
 
-use super::checksummed_bytes::ChecksummedBytes;
+use crate::checksums::ChecksummedBytes;
 
 /// A self-identifying part of an S3 object. Users can only retrieve the bytes from this part if
 /// they can prove they have the correct offset and key.
diff --git a/mountpoint-s3/src/prefetch/part_queue.rs b/mountpoint-s3/src/prefetch/part_queue.rs
index 42355eff5..07a353bab 100644
--- a/mountpoint-s3/src/prefetch/part_queue.rs
+++ b/mountpoint-s3/src/prefetch/part_queue.rs
@@ -99,7 +99,7 @@ impl<E: std::error::Error + Send + Sync> PartQueueProducer<E> {
 
 #[cfg(test)]
 mod tests {
-    use crate::prefetch::checksummed_bytes::ChecksummedBytes;
+    use crate::checksums::ChecksummedBytes;
 
     use super::*;
 
diff --git a/mountpoint-s3/src/upload.rs b/mountpoint-s3/src/upload.rs
index 10cb14b10..c12a1638f 100644
--- a/mountpoint-s3/src/upload.rs
+++ b/mountpoint-s3/src/upload.rs
@@ -9,6 +9,8 @@ use mountpoint_s3_crt::checksums::crc32c::{Crc32c, Hasher};
 use thiserror::Error;
 use tracing::error;
 
+use crate::checksums::combine_checksums;
+
 type PutRequestError<Client> = ObjectClientError<PutObjectError, <Client as ObjectClient>::ClientError>;
 
 const MAX_S3_MULTIPART_UPLOAD_PARTS: usize = 10000;
@@ -180,13 +182,6 @@ fn verify_checksums(review: UploadReview, expected_size: u64, expected_checksum:
     true
 }
 
-/// Calculates the combined checksum for `AB` where `prefix_crc` is the checksum for `A`,
-/// `suffix_crc` is the checksum for `B`, and `suffic_len` is the length of `B`.
-fn combine_checksums(prefix_crc: Crc32c, suffix_crc: Crc32c, suffix_len: usize) -> Crc32c {
-    let combined = ::crc32c::crc32c_combine(prefix_crc.value(), suffix_crc.value(), suffix_len);
-    Crc32c::new(combined)
-}
-
 #[cfg(test)]
 mod tests {
     use std::collections::HashMap;
@@ -196,7 +191,6 @@ mod tests {
         failure_client::countdown_failure_client,
         mock_client::{MockClient, MockClientConfig, MockClientError},
     };
-    use mountpoint_s3_crt::checksums::crc32c;
     use test_case::test_case;
 
     #[tokio::test]
@@ -341,15 +335,4 @@ mod tests {
         assert!(!client.contains_key(key));
         assert!(!client.is_upload_in_progress(key));
     }
-
-    #[test]
-    fn test_combine_checksums() {
-        let buf: &[u8] = b"123456789";
-        let (buf1, buf2) = buf.split_at(4);
-        let crc = crc32c::checksum(buf);
-        let crc1 = crc32c::checksum(buf1);
-        let crc2 = crc32c::checksum(buf2);
-        let combined = combine_checksums(crc1, crc2, buf2.len());
-        assert_eq!(combined, crc);
-    }
 }

From 711d37ddda701ecd181b3a9586ff0c80dfd6e7f6 Mon Sep 17 00:00:00 2001
From: Alessandro Passaro <alexpax@amazon.co.uk>
Date: Wed, 25 Oct 2023 15:58:47 +0100
Subject: [PATCH 2/4] Improve ChecksummedBytes::extend and clarify data
 integrity guarantee

ChecksummedBytes maintains a data buffer and a checksum and guarantees that only validated data can be accessed. Transformations such as `split_off`, `extend`, or `slice` (introduced in this change) may trigger a validation (and return an IntegrityError on failure), or propagate existing checksum(s) if possible, allowing for later validation.

This change clarifies the data integrity guarantee in the docs and optimizes the extend method to avoid re-validation when the checksums for both slices can be combined. It also avoids a redundant buffer allocation.

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>
---
 mountpoint-s3/src/checksums.rs | 265 +++++++++++++++++++++++----------
 1 file changed, 189 insertions(+), 76 deletions(-)

diff --git a/mountpoint-s3/src/checksums.rs b/mountpoint-s3/src/checksums.rs
index 4c262b6a4..98067a98c 100644
--- a/mountpoint-s3/src/checksums.rs
+++ b/mountpoint-s3/src/checksums.rs
@@ -2,9 +2,13 @@ use std::ops::RangeBounds;
 
 use bytes::{Bytes, BytesMut};
 use mountpoint_s3_crt::checksums::crc32c::{self, Crc32c};
+
 use thiserror::Error;
 
 /// A `ChecksummedBytes` is a bytes buffer that carries its checksum.
+/// The implementation guarantees that integrity will be validated before the data can be accessed.
+/// Data transformations will either fail returning an [IntegrityError], or propagate the checksum
+/// so that it can be validated on access.
 #[derive(Clone, Debug)]
 pub struct ChecksummedBytes {
     orig_bytes: Bytes,
@@ -29,21 +33,21 @@ impl ChecksummedBytes {
         Self::new(bytes, checksum)
     }
 
-    /// Convert the `ChecksummedBytes` into `Bytes`, data integrity will be validated before converting.
+    /// Convert the [ChecksummedBytes] into [Bytes], data integrity will be validated before converting.
     ///
-    /// Return `IntegrityError` on data corruption.
+    /// Return [IntegrityError] on data corruption.
     pub fn into_bytes(self) -> Result<Bytes, IntegrityError> {
         self.validate()?;
 
         Ok(self.curr_slice)
     }
 
-    /// Returns the number of bytes contained in this `ChecksummedBytes`.
+    /// Returns the number of bytes contained in this [ChecksummedBytes].
     pub fn len(&self) -> usize {
         self.curr_slice.len()
     }
 
-    /// Returns true if the `ChecksummedBytes` has a length of 0.
+    /// Returns true if the [ChecksummedBytes] has a length of 0.
     pub fn is_empty(&self) -> bool {
         self.curr_slice.is_empty()
     }
@@ -63,47 +67,77 @@ impl ChecksummedBytes {
         }
     }
 
-    /// Append the given checksummed bytes to current `ChecksummedBytes`, ensure that data integrity will
-    /// be validated.
+    /// Returns a slice of self for the provided range.
+    ///
+    /// This operation just increases the reference count and sets a few indices,
+    /// so there will be no validation and the checksum will not be recomputed.
+    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
+        Self {
+            orig_bytes: self.orig_bytes.clone(),
+            curr_slice: self.curr_slice.slice(range),
+            checksum: self.checksum,
+        }
+    }
+
+    /// Returns a copy of this slice, with the guarantee that the checksum is computed exactly
+    /// on the slice, rather than on a larger containing buffer.
+    ///
+    /// Return [IntegrityError] if data corruption is detected.
+    pub fn shrink_to_fit(&self) -> Result<Self, IntegrityError> {
+        if self.curr_slice.len() == self.orig_bytes.len() {
+            return Ok(self.clone());
+        }
+
+        let result = Self::from_bytes(self.curr_slice.clone());
+        self.validate()?;
+        Ok(result)
+    }
+
+    /// Append the given checksummed bytes to current [ChecksummedBytes]. Will combine the
+    /// existing checksums if possible, or compute a new one and validate data integrity.
     ///
-    /// Return `IntegrityError` on data corruption.
+    /// Return [IntegrityError] if data corruption is detected.
     pub fn extend(&mut self, extend: ChecksummedBytes) -> Result<(), IntegrityError> {
-        let curr_len = self.curr_slice.len();
-        let total_len = curr_len + extend.len();
-
-        let mut bytes_mut = BytesMut::with_capacity(total_len);
-        bytes_mut.extend_from_slice(&self.curr_slice);
-        bytes_mut.extend_from_slice(&extend.curr_slice);
-        let new_bytes = bytes_mut.freeze();
-        let new_checksum = crc32c::checksum(&new_bytes);
-        let new_checksummed_bytes = ChecksummedBytes::new(new_bytes, new_checksum);
-
-        // Validate data integrity with checksum bracketing.
-        {
-            // 1. repeat the operation, which means copying into a new buffer in this case.
-            let mut bytes_mut_dup = BytesMut::with_capacity(total_len);
-            bytes_mut_dup.extend_from_slice(&self.curr_slice);
-            bytes_mut_dup.extend_from_slice(&extend.curr_slice);
-            let new_bytes_dup = bytes_mut_dup.freeze();
-            let new_checksum_dup = crc32c::checksum(&new_bytes_dup);
-
-            // 2. compare the checksum between the two transformations.
-            if new_checksum != new_checksum_dup {
-                return Err(IntegrityError::ChecksumMismatch(new_checksum, new_checksum_dup));
-            }
-
-            // 3. validate original buffers to make sure that the data we have copied are still valid.
-            self.validate()?;
+        if extend.is_empty() {
+            // No op, but check that `extend` was not corrupted
             extend.validate()?;
+            return Ok(());
+        }
+
+        if self.is_empty() {
+            // Replace with `extend`, but check that `self` was not corrupted
+            self.validate()?;
+            *self = extend;
+            return Ok(());
         }
 
-        *self = new_checksummed_bytes;
+        // When appending two slices, we can combine their checksums and obtain the new checksum
+        // without having to recompute it from the data.
+        // However, since a `ChecksummedBytes` potentially holds the checksum of some larger buffer,
+        // rather than the exact one for the slice, we need to first invoke `shrink_to_fit` on each
+        // slice and use the resulting exact checksums.
+        let prefix = self.shrink_to_fit()?;
+        assert_eq!(prefix.orig_bytes.len(), prefix.curr_slice.len());
+        let suffix = extend.shrink_to_fit()?;
+        assert_eq!(suffix.orig_bytes.len(), suffix.curr_slice.len());
+
+        // Combine the checksums.
+        let new_checksum = combine_checksums(prefix.checksum, suffix.checksum, suffix.len());
+
+        // Combine the slices.
+        let new_bytes = {
+            let mut bytes_mut = BytesMut::with_capacity(prefix.len() + suffix.len());
+            bytes_mut.extend_from_slice(&prefix.curr_slice);
+            bytes_mut.extend_from_slice(&suffix.curr_slice);
+            bytes_mut.freeze()
+        };
+        *self = ChecksummedBytes::new(new_bytes, new_checksum);
         Ok(())
     }
 
-    /// Validate data integrity in this `ChecksummedBytes`.
+    /// Validate data integrity in this [ChecksummedBytes].
     ///
-    /// Return `IntegrityError` on data corruption.
+    /// Return [IntegrityError] on data corruption.
     pub fn validate(&self) -> Result<(), IntegrityError> {
         let checksum = crc32c::checksum(&self.orig_bytes);
         if self.checksum != checksum {
@@ -161,14 +195,10 @@ impl PartialEq for ChecksummedBytes {
             return false;
         }
 
-        if self.orig_bytes == other.orig_bytes && self.checksum == other.checksum {
-            return true;
-        }
-
+        let result = self.orig_bytes == other.orig_bytes && self.checksum == other.checksum;
         self.validate().expect("should be valid");
         other.validate().expect("should be valid");
-
-        true
+        result
     }
 }
 
@@ -192,8 +222,7 @@ mod tests {
     fn test_into_bytes_integrity_error() {
         let bytes = Bytes::from_static(b"some bytes");
         let checksum = crc32c::checksum(&bytes);
-        let mut checksummed_bytes = ChecksummedBytes::new(bytes, checksum);
-        checksummed_bytes.orig_bytes = Bytes::from_static(b"new bytes");
+        let checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"new bytes"), checksum);
 
         let actual = checksummed_bytes.into_bytes();
         assert!(matches!(actual, Err(IntegrityError::ChecksumMismatch(_, _))));
@@ -220,16 +249,69 @@ mod tests {
     }
 
     #[test]
-    fn test_extend() {
+    fn test_slice() {
+        let range = 3..7;
         let bytes = Bytes::from_static(b"some bytes");
-        let expected = Bytes::from_static(b"some bytes extended");
+        let expected = bytes.clone();
+        let expected_slice = bytes.slice(range.clone());
         let checksum = crc32c::checksum(&bytes);
-        let mut checksummed_bytes = ChecksummedBytes::new(bytes, checksum);
+        let original = ChecksummedBytes::new(bytes, checksum);
+        let slice = original.slice(range);
+
+        assert_eq!(expected, original.orig_bytes);
+        assert_eq!(expected, original.curr_slice);
+        assert_eq!(expected, slice.orig_bytes);
+        assert_eq!(expected_slice, slice.curr_slice);
+        assert_eq!(checksum, original.checksum);
+        assert_eq!(checksum, slice.checksum);
+    }
 
-        let extend = Bytes::from_static(b" extended");
-        let extend_checksum = crc32c::checksum(&extend);
-        let extend = ChecksummedBytes::new(extend, extend_checksum);
-        checksummed_bytes.extend(extend).unwrap();
+    #[test]
+    fn test_shrink_to_fit() {
+        let original = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
+        let unchanged = original.shrink_to_fit().unwrap();
+        assert_eq!(original.curr_slice, unchanged.curr_slice);
+        assert_eq!(original.orig_bytes, unchanged.orig_bytes);
+        assert_eq!(original.checksum, unchanged.checksum);
+
+        let slice = original.clone().split_off(5);
+        let shrunken = slice.shrink_to_fit().unwrap();
+        assert_eq!(slice.curr_slice, shrunken.curr_slice);
+        assert_ne!(slice.orig_bytes, shrunken.orig_bytes);
+        assert_ne!(slice.checksum, shrunken.checksum);
+    }
+
+    #[test]
+    fn test_shrink_to_fit_corrupted() {
+        let checksum = crc32c::checksum(b"some bytes");
+        let original = ChecksummedBytes::new(Bytes::from_static(b"other bytes"), checksum);
+        assert!(matches!(
+            original.validate(),
+            Err(IntegrityError::ChecksumMismatch(_, _))
+        ));
+
+        let unchanged = original.shrink_to_fit().unwrap();
+        assert_eq!(original.curr_slice, unchanged.curr_slice);
+        assert_eq!(original.orig_bytes, unchanged.orig_bytes);
+        assert_eq!(original.checksum, unchanged.checksum);
+        assert!(matches!(
+            unchanged.validate(),
+            Err(IntegrityError::ChecksumMismatch(_, _))
+        ));
+
+        let slice = original.clone().split_off(5);
+        assert!(matches!(slice.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));
+
+        let result = slice.shrink_to_fit();
+        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
+    }
+
+    #[test]
+    fn test_extend() {
+        let expected = Bytes::from_static(b"some bytes extended");
+        let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
+        let extend_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended"));
+        checksummed_bytes.extend(extend_bytes).unwrap();
         let actual = checksummed_bytes.curr_slice;
         assert_eq!(expected, actual);
     }
@@ -237,14 +319,10 @@ mod tests {
     #[test]
     fn test_extend_after_split() {
         let split_off_at = 4;
-        let bytes = Bytes::from_static(b"some bytes");
-        let expected = Bytes::from_static(b"some ext");
-        let checksum = crc32c::checksum(&bytes);
-        let mut checksummed_bytes = ChecksummedBytes::new(bytes, checksum);
 
-        let extend = Bytes::from_static(b" extended");
-        let extend_checksum = crc32c::checksum(&extend);
-        let mut extend = ChecksummedBytes::new(extend, extend_checksum);
+        let expected = Bytes::from_static(b"some ext");
+        let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
+        let mut extend = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended"));
         checksummed_bytes.split_off(split_off_at);
         extend.split_off(split_off_at);
         checksummed_bytes.extend(extend).unwrap();
@@ -254,34 +332,69 @@ mod tests {
 
     #[test]
     fn test_extend_self_corrupted() {
-        let bytes = Bytes::from_static(b"some bytes");
-        let checksum = crc32c::checksum(&bytes);
-        let mut checksummed_bytes = ChecksummedBytes::new(bytes, checksum);
+        let corrupted_bytes = Bytes::from_static(b"corrupted data");
+        let checksum = crc32c::checksum(b"some bytes");
+        let mut checksummed_bytes = ChecksummedBytes::new(corrupted_bytes, checksum);
+        assert!(matches!(
+            checksummed_bytes.validate(),
+            Err(IntegrityError::ChecksumMismatch(_, _))
+        ));
+
+        let extend = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended"));
+        assert!(matches!(extend.validate(), Ok(())));
 
-        let currupted_bytes = Bytes::from_static(b"corrupted data");
-        checksummed_bytes.orig_bytes = currupted_bytes.clone();
-        checksummed_bytes.curr_slice = currupted_bytes;
+        checksummed_bytes.extend(extend).unwrap();
+        assert!(matches!(
+            checksummed_bytes.validate(),
+            Err(IntegrityError::ChecksumMismatch(_, _))
+        ));
+    }
+
+    #[test]
+    fn test_extend_after_split_self_corrupted() {
+        let corrupted_bytes = Bytes::from_static(b"corrupted data");
+        let checksum = crc32c::checksum(b"some bytes");
+        let mut checksummed_bytes = ChecksummedBytes::new(corrupted_bytes, checksum);
+        assert!(matches!(
+            checksummed_bytes.validate(),
+            Err(IntegrityError::ChecksumMismatch(_, _))
+        ));
+        checksummed_bytes.split_off(4);
+
+        let extend = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended"));
+        assert!(matches!(extend.validate(), Ok(())));
 
-        let extend = Bytes::from_static(b" extended");
-        let extend_checksum = crc32c::checksum(&extend);
-        let extend = ChecksummedBytes::new(extend, extend_checksum);
         let result = checksummed_bytes.extend(extend);
         assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
     }
 
     #[test]
     fn test_extend_other_corrupted() {
-        let bytes = Bytes::from_static(b"some bytes");
-        let checksum = crc32c::checksum(&bytes);
-        let mut checksummed_bytes = ChecksummedBytes::new(bytes, checksum);
+        let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
+        assert!(matches!(checksummed_bytes.validate(), Ok(())));
 
-        let extend = Bytes::from_static(b" extended");
-        let extend_checksum = crc32c::checksum(&extend);
-        let mut extend = ChecksummedBytes::new(extend, extend_checksum);
+        let corrupted_bytes = Bytes::from_static(b"corrupted data");
+        let extend_checksum = crc32c::checksum(b" extended");
+        let extend = ChecksummedBytes::new(corrupted_bytes, extend_checksum);
+        assert!(matches!(extend.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));
 
-        let currupted_bytes = Bytes::from_static(b"corrupted data");
-        extend.orig_bytes = currupted_bytes.clone();
-        extend.curr_slice = currupted_bytes;
+        checksummed_bytes.extend(extend).unwrap();
+        assert!(matches!(
+            checksummed_bytes.validate(),
+            Err(IntegrityError::ChecksumMismatch(_, _))
+        ));
+    }
+
+    #[test]
+    fn test_extend_after_split_other_corrupted() {
+        let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
+        assert!(matches!(checksummed_bytes.validate(), Ok(())));
+
+        let corrupted_bytes = Bytes::from_static(b"corrupted data");
+        let extend_checksum = crc32c::checksum(b" extended");
+        let mut extend = ChecksummedBytes::new(corrupted_bytes, extend_checksum);
+        extend.split_off(4);
+        assert!(matches!(extend.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));
 
         let result = checksummed_bytes.extend(extend);
         assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));

From 02ce1c2bac879eca3204ecc0ab7b59105a58a93d Mon Sep 17 00:00:00 2001
From: Alessandro Passaro <alexpax@amazon.co.uk>
Date: Thu, 26 Oct 2023 18:42:37 +0100
Subject: [PATCH 3/4] Add split_off tests

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>
---
 mountpoint-s3/src/checksums.rs | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/mountpoint-s3/src/checksums.rs b/mountpoint-s3/src/checksums.rs
index 98067a98c..6e24019c0 100644
--- a/mountpoint-s3/src/checksums.rs
+++ b/mountpoint-s3/src/checksums.rs
@@ -368,6 +368,23 @@ mod tests {
         assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
     }
 
+    #[test]
+    fn test_extend_split_off_self_corrupted() {
+        let corrupted_bytes = Bytes::from_static(b"corrupted data");
+        let checksum = crc32c::checksum(b"some bytes");
+        let mut split_off = ChecksummedBytes::new(corrupted_bytes, checksum).split_off(4);
+        assert!(matches!(
+            split_off.validate(),
+            Err(IntegrityError::ChecksumMismatch(_, _))
+        ));
+
+        let extend = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended"));
+        assert!(matches!(extend.validate(), Ok(())));
+
+        let result = split_off.extend(extend);
+        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
+    }
+
     #[test]
     fn test_extend_other_corrupted() {
         let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
@@ -400,6 +417,23 @@ mod tests {
         assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
     }
 
+    #[test]
+    fn test_extend_split_off_other_corrupted() {
+        let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
+        assert!(matches!(checksummed_bytes.validate(), Ok(())));
+
+        let corrupted_bytes = Bytes::from_static(b"corrupted data");
+        let extend_checksum = crc32c::checksum(b" extended");
+        let split_off = ChecksummedBytes::new(corrupted_bytes, extend_checksum).split_off(4);
+        assert!(matches!(
+            split_off.validate(),
+            Err(IntegrityError::ChecksumMismatch(_, _))
+        ));
+
+        let result = checksummed_bytes.extend(split_off);
+        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
+    }
+
     #[test]
     fn test_combine_checksums() {
         let buf: &[u8] = b"123456789";

From 503a4ac13ee903365d0362a46a630e46f0ad9c47 Mon Sep 17 00:00:00 2001
From: Alessandro Passaro <alexpax@amazon.co.uk>
Date: Thu, 26 Oct 2023 18:48:14 +0100
Subject: [PATCH 4/4] Make clearer that shrink_to_fit does not copy data

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>
---
 mountpoint-s3/src/checksums.rs | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/mountpoint-s3/src/checksums.rs b/mountpoint-s3/src/checksums.rs
index 6e24019c0..cdc6310b4 100644
--- a/mountpoint-s3/src/checksums.rs
+++ b/mountpoint-s3/src/checksums.rs
@@ -12,6 +12,7 @@ use thiserror::Error;
 #[derive(Clone, Debug)]
 pub struct ChecksummedBytes {
     orig_bytes: Bytes,
+    /// Always a subslice of `orig_bytes`
     curr_slice: Bytes,
     /// Checksum for `orig_bytes`
     checksum: Crc32c,
@@ -88,7 +89,12 @@ impl ChecksummedBytes {
             return Ok(self.clone());
         }
 
-        let result = Self::from_bytes(self.curr_slice.clone());
+        // Note that no data is copied: `bytes` still points to a subslice of `orig_bytes`.
+        let bytes = self.curr_slice.clone();
+        let checksum = crc32c::checksum(&bytes);
+        let result = Self::new(bytes, checksum);
+
+        // Check the integrity of the whole buffer.
         self.validate()?;
         Ok(result)
     }