diff --git a/src/adler32.rs b/src/adler32.rs index 4ea0f1d4..3b997f4f 100644 --- a/src/adler32.rs +++ b/src/adler32.rs @@ -1,3 +1,5 @@ +use std::mem::MaybeUninit; + pub fn adler32(start_checksum: u32, data: &[u8]) -> u32 { #[cfg(target_arch = "x86_64")] if std::is_x86_feature_detected!("avx2") { @@ -7,6 +9,25 @@ pub fn adler32(start_checksum: u32, data: &[u8]) -> u32 { adler32_rust(start_checksum, data) } +pub fn adler32_fold_copy(start_checksum: u32, dst: &mut [MaybeUninit], src: &[u8]) -> u32 { + debug_assert!(dst.len() >= src.len(), "{} < {}", dst.len(), src.len()); + + #[cfg(target_arch = "x86_64")] + if std::is_x86_feature_detected!("avx2") { + return avx2::adler32_fold_copy_avx2(start_checksum, dst, src); + } + + let adler = adler32_rust(start_checksum, src); + dst[..src.len()].copy_from_slice(slice_to_uninit(src)); + adler +} + +// when stable, use MaybeUninit::write_slice +fn slice_to_uninit(slice: &[u8]) -> &[MaybeUninit] { + // safety: &[T] and &[MaybeUninit] have the same layout + unsafe { &*(slice as *const [u8] as *const [MaybeUninit]) } +} + // inefficient but correct, useful for testing #[cfg(test)] fn naive_adler32(start_checksum: u32, data: &[u8]) -> u32 { @@ -124,6 +145,25 @@ fn adler32_len_16(mut adler: u32, buf: &[u8], mut sum2: u32) -> u32 { adler | (sum2 << 16) } +fn adler32_copy_len_16( + mut adler: u32, + dst: &mut [MaybeUninit], + src: &[u8], + mut sum2: u32, +) -> u32 { + for (source, destination) in src.iter().zip(dst.iter_mut()) { + let v = *source; + *destination = MaybeUninit::new(v); + adler += v as u32; + sum2 += adler; + } + + adler %= BASE; + sum2 %= BASE; /* only added so many BASE's */ + /* return recombined sums */ + adler | (sum2 << 16) +} + fn adler32_len_64(mut adler: u32, buf: &[u8], mut sum2: u32) -> u32 { const N: usize = if UNROLL_MORE { 16 } else { 8 }; let mut it = buf.chunks_exact(N); @@ -145,8 +185,8 @@ mod avx2 { use std::arch::x86_64::{ __m256i, _mm256_add_epi32, _mm256_castsi256_si128, _mm256_extracti128_si256, _mm256_loadu_si256, _mm256_madd_epi16, _mm256_maddubs_epi16, _mm256_permutevar8x32_epi32, - _mm256_sad_epu8, _mm256_slli_epi32, _mm256_zextsi128_si256, _mm_add_epi32, - _mm_cvtsi128_si32, _mm_cvtsi32_si128, _mm_shuffle_epi32, _mm_unpackhi_epi64, + _mm256_sad_epu8, _mm256_slli_epi32, _mm256_storeu_si256, _mm256_zextsi128_si256, + _mm_add_epi32, _mm_cvtsi128_si32, _mm_cvtsi32_si128, _mm_shuffle_epi32, _mm_unpackhi_epi64, }; const fn __m256i_literal(bytes: [u8; 32]) -> __m256i { @@ -207,38 +247,74 @@ mod avx2 { (array_slice, remainder) } - pub fn adler32_avx2(adler: u32, buf: &[u8]) -> u32 { - if buf.is_empty() { + pub fn adler32_avx2(adler: u32, src: &[u8]) -> u32 { + adler32_avx2_help::(adler, &mut [], src) + } + + pub fn adler32_fold_copy_avx2(adler: u32, dst: &mut [MaybeUninit], src: &[u8]) -> u32 { + adler32_avx2_help::(adler, dst, src) + } + + fn adler32_avx2_help( + adler: u32, + mut dst: &mut [MaybeUninit], + src: &[u8], + ) -> u32 { + if src.is_empty() { return adler; } let mut adler1 = (adler >> 16) & 0xffff; let mut adler0 = adler & 0xffff; - if buf.len() < 16 { - return adler32_len_16(adler0, buf, adler1); - } else if buf.len() < 32 { - return adler32_len_64(adler0, buf, adler1); + if src.len() < 16 { + // use COPY const generic for this branch + if COPY { + return adler32_copy_len_16(adler0, dst, src, adler1); + } else { + return adler32_len_16(adler0, src, adler1); + } + } else if src.len() < 32 { + // use COPY const generic for this branch + if COPY { + return adler32_copy_len_16(adler0, dst, src, adler1); + } else { + return adler32_len_64(adler0, src, adler1); + } } // use largest step possible (without causing overflow) const N: usize = (NMAX - (NMAX % 32)) as usize; - let (chunks, remainder) = slice_as_chunks::<_, N>(buf); + let (chunks, remainder) = slice_as_chunks::<_, N>(src); for chunk in chunks { - (adler0, adler1) = unsafe { helper_32_bytes(adler0, adler1, chunk) }; + (adler0, adler1) = unsafe { helper_32_bytes::(adler0, adler1, dst, chunk) }; + if COPY { + dst = &mut dst[N..]; + } } // then take steps of 32 bytes let (chunks, remainder) = slice_as_chunks::<_, 32>(remainder); for chunk in chunks { - (adler0, adler1) = unsafe { helper_32_bytes(adler0, adler1, chunk) }; + (adler0, adler1) = unsafe { helper_32_bytes::(adler0, adler1, dst, chunk) }; + if COPY { + dst = &mut dst[32..]; + } } if !remainder.is_empty() { if remainder.len() < 16 { - return adler32_len_16(adler0, remainder, adler1); + if COPY { + return adler32_copy_len_16(adler0, dst, remainder, adler1); + } else { + return adler32_len_16(adler0, remainder, adler1); + } } else if remainder.len() < 32 { - return adler32_len_64(adler0, remainder, adler1); + if COPY { + return adler32_copy_len_16(adler0, dst, remainder, adler1); + } else { + return adler32_len_64(adler0, remainder, adler1); + } } else { unreachable!() } @@ -248,8 +324,13 @@ mod avx2 { } #[inline(always)] - unsafe fn helper_32_bytes(mut adler0: u32, mut adler1: u32, buf: &[u8]) -> (u32, u32) { - debug_assert_eq!(buf.len() % 32, 0); + unsafe fn helper_32_bytes( + mut adler0: u32, + mut adler1: u32, + dst: &mut [MaybeUninit], + src: &[u8], + ) -> (u32, u32) { + debug_assert_eq!(src.len() % 32, 0); let mut vs1 = _mm256_zextsi128_si256(_mm_cvtsi32_si128(adler0 as i32)); let mut vs2 = _mm256_zextsi128_si256(_mm_cvtsi32_si128(adler1 as i32)); @@ -257,12 +338,19 @@ mod avx2 { let mut vs1_0 = vs1; let mut vs3 = ZERO; - for chunk in buf.chunks_exact(32) { - let vbuf = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i); + let mut out_chunks = dst.chunks_exact_mut(32); - let vs1_sad = _mm256_sad_epu8(vbuf, ZERO); // Sum of abs diff, resulting in 2 x int32's + for in_chunk in src.chunks_exact(32) { + let vbuf = _mm256_loadu_si256(in_chunk.as_ptr() as *const __m256i); - // TODO copy? + if COPY { + // println!("simd copy {:?}", in_chunk); + let out_chunk = out_chunks.next().unwrap(); + _mm256_storeu_si256(out_chunk.as_mut_ptr() as *mut __m256i, vbuf); + // out_chunk.copy_from_slice(slice_to_uninit(in_chunk)) + } + + let vs1_sad = _mm256_sad_epu8(vbuf, ZERO); // Sum of abs diff, resulting in 2 x int32's vs1 = _mm256_add_epi32(vs1, vs1_sad); vs3 = _mm256_add_epi32(vs3, vs1_0); @@ -302,6 +390,26 @@ mod avx2 { assert_eq!(naive_adler32(1, &vec[..i]), adler32_avx2(1, &vec[..i])); } } + + #[cfg(test)] + // TODO: This could use `MaybeUninit::slice_assume_init` when it is stable. + unsafe fn slice_assume_init(slice: &[MaybeUninit]) -> &[u8] { + &*(slice as *const [MaybeUninit] as *const [u8]) + } + + #[test] + fn fold_copy_copies() { + let src: Vec<_> = (0..128).map(|x| x as u8).collect(); + let mut dst = [MaybeUninit::new(0); 128]; + + for (i, _) in src.iter().enumerate() { + dst.fill(MaybeUninit::new(0)); + + adler32_fold_copy_avx2(1, &mut dst[..i], &src[..i]); + + assert_eq!(&src[..i], unsafe { slice_assume_init(&dst[..i]) }) + } + } } #[cfg(test)] diff --git a/src/window.rs b/src/window.rs index 9de5a397..0509901d 100644 --- a/src/window.rs +++ b/src/window.rs @@ -1,4 +1,4 @@ -use crate::adler32::adler32; +use crate::adler32::{adler32, adler32_fold_copy}; use std::mem::MaybeUninit; // translation guide: @@ -89,10 +89,7 @@ impl<'a> Window<'a> { if checksum != 0 { checksum = adler32(checksum, non_window_slice); - - checksum = adler32(checksum, window_slice); - self.buf - .copy_from_slice(unsafe { slice_to_uninit(window_slice) }); + checksum = adler32_fold_copy(checksum, self.buf, window_slice); } else { self.buf .copy_from_slice(unsafe { slice_to_uninit(window_slice) }); @@ -107,23 +104,20 @@ impl<'a> Window<'a> { // written to the start of the window. let (end_part, start_part) = slice.split_at(dist); - let end_part = unsafe { slice_to_uninit(end_part) }; - let start_part = unsafe { slice_to_uninit(start_part) }; - if checksum != 0 { - // TODO fuse memcpy and adler - checksum = adler32(checksum, slice); - self.buf[self.next..][..end_part.len()].copy_from_slice(end_part); + let dst = &mut self.buf[self.next..][..end_part.len()]; + checksum = adler32_fold_copy(checksum, dst, end_part); } else { + let end_part = unsafe { slice_to_uninit(end_part) }; self.buf[self.next..][..end_part.len()].copy_from_slice(end_part); } if !start_part.is_empty() { if checksum != 0 { - // TODO fuse memcpy and adler - checksum = adler32(checksum, slice); - self.buf[..start_part.len()].copy_from_slice(start_part); + let dst = &mut self.buf[..start_part.len()]; + checksum = adler32_fold_copy(checksum, dst, start_part); } else { + let start_part = unsafe { slice_to_uninit(start_part) }; self.buf[..start_part.len()].copy_from_slice(start_part); }