Skip to content

Commit

Permalink
fused adler
Browse files Browse the repository at this point in the history
  • Loading branch information
folkertdev committed Jan 15, 2024
1 parent 5deeeb0 commit 9f7d3f2
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 33 deletions.
146 changes: 127 additions & 19 deletions src/adler32.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::mem::MaybeUninit;

pub fn adler32(start_checksum: u32, data: &[u8]) -> u32 {
#[cfg(target_arch = "x86_64")]
if std::is_x86_feature_detected!("avx2") {
Expand All @@ -7,6 +9,25 @@ pub fn adler32(start_checksum: u32, data: &[u8]) -> u32 {
adler32_rust(start_checksum, data)
}

pub fn adler32_fold_copy(start_checksum: u32, dst: &mut [MaybeUninit<u8>], src: &[u8]) -> u32 {
debug_assert!(dst.len() >= src.len(), "{} < {}", dst.len(), src.len());

#[cfg(target_arch = "x86_64")]
if std::is_x86_feature_detected!("avx2") {
return avx2::adler32_fold_copy_avx2(start_checksum, dst, src);
}

let adler = adler32_rust(start_checksum, src);
dst[..src.len()].copy_from_slice(slice_to_uninit(src));
adler
}

// when stable, use MaybeUninit::write_slice
fn slice_to_uninit(slice: &[u8]) -> &[MaybeUninit<u8>] {
// safety: &[T] and &[MaybeUninit<T>] have the same layout
unsafe { &*(slice as *const [u8] as *const [MaybeUninit<u8>]) }
}

// inefficient but correct, useful for testing
#[cfg(test)]
fn naive_adler32(start_checksum: u32, data: &[u8]) -> u32 {
Expand Down Expand Up @@ -124,6 +145,25 @@ fn adler32_len_16(mut adler: u32, buf: &[u8], mut sum2: u32) -> u32 {
adler | (sum2 << 16)
}

fn adler32_copy_len_16(
mut adler: u32,
dst: &mut [MaybeUninit<u8>],
src: &[u8],
mut sum2: u32,
) -> u32 {
for (source, destination) in src.iter().zip(dst.iter_mut()) {
let v = *source;
*destination = MaybeUninit::new(v);
adler += v as u32;
sum2 += adler;
}

adler %= BASE;
sum2 %= BASE; /* only added so many BASE's */
/* return recombined sums */
adler | (sum2 << 16)
}

fn adler32_len_64(mut adler: u32, buf: &[u8], mut sum2: u32) -> u32 {
const N: usize = if UNROLL_MORE { 16 } else { 8 };
let mut it = buf.chunks_exact(N);
Expand All @@ -145,8 +185,8 @@ mod avx2 {
use std::arch::x86_64::{
__m256i, _mm256_add_epi32, _mm256_castsi256_si128, _mm256_extracti128_si256,
_mm256_loadu_si256, _mm256_madd_epi16, _mm256_maddubs_epi16, _mm256_permutevar8x32_epi32,
_mm256_sad_epu8, _mm256_slli_epi32, _mm256_zextsi128_si256, _mm_add_epi32,
_mm_cvtsi128_si32, _mm_cvtsi32_si128, _mm_shuffle_epi32, _mm_unpackhi_epi64,
_mm256_sad_epu8, _mm256_slli_epi32, _mm256_storeu_si256, _mm256_zextsi128_si256,
_mm_add_epi32, _mm_cvtsi128_si32, _mm_cvtsi32_si128, _mm_shuffle_epi32, _mm_unpackhi_epi64,
};

const fn __m256i_literal(bytes: [u8; 32]) -> __m256i {
Expand Down Expand Up @@ -207,38 +247,74 @@ mod avx2 {
(array_slice, remainder)
}

pub fn adler32_avx2(adler: u32, buf: &[u8]) -> u32 {
if buf.is_empty() {
pub fn adler32_avx2(adler: u32, src: &[u8]) -> u32 {
adler32_avx2_help::<false>(adler, &mut [], src)
}

pub fn adler32_fold_copy_avx2(adler: u32, dst: &mut [MaybeUninit<u8>], src: &[u8]) -> u32 {
adler32_avx2_help::<true>(adler, dst, src)
}

fn adler32_avx2_help<const COPY: bool>(
adler: u32,
mut dst: &mut [MaybeUninit<u8>],
src: &[u8],
) -> u32 {
if src.is_empty() {
return adler;
}

let mut adler1 = (adler >> 16) & 0xffff;
let mut adler0 = adler & 0xffff;

if buf.len() < 16 {
return adler32_len_16(adler0, buf, adler1);
} else if buf.len() < 32 {
return adler32_len_64(adler0, buf, adler1);
if src.len() < 16 {
// use COPY const generic for this branch
if COPY {
return adler32_copy_len_16(adler0, dst, src, adler1);
} else {
return adler32_len_16(adler0, src, adler1);
}
} else if src.len() < 32 {
// use COPY const generic for this branch
if COPY {
return adler32_copy_len_16(adler0, dst, src, adler1);
} else {
return adler32_len_64(adler0, src, adler1);
}
}

// use largest step possible (without causing overflow)
const N: usize = (NMAX - (NMAX % 32)) as usize;
let (chunks, remainder) = slice_as_chunks::<_, N>(buf);
let (chunks, remainder) = slice_as_chunks::<_, N>(src);
for chunk in chunks {
(adler0, adler1) = unsafe { helper_32_bytes(adler0, adler1, chunk) };
(adler0, adler1) = unsafe { helper_32_bytes::<COPY>(adler0, adler1, dst, chunk) };
if COPY {
dst = &mut dst[N..];
}
}

// then take steps of 32 bytes
let (chunks, remainder) = slice_as_chunks::<_, 32>(remainder);
for chunk in chunks {
(adler0, adler1) = unsafe { helper_32_bytes(adler0, adler1, chunk) };
(adler0, adler1) = unsafe { helper_32_bytes::<COPY>(adler0, adler1, dst, chunk) };
if COPY {
dst = &mut dst[32..];
}
}

if !remainder.is_empty() {
if remainder.len() < 16 {
return adler32_len_16(adler0, remainder, adler1);
if COPY {
return adler32_copy_len_16(adler0, dst, remainder, adler1);
} else {
return adler32_len_16(adler0, remainder, adler1);
}
} else if remainder.len() < 32 {
return adler32_len_64(adler0, remainder, adler1);
if COPY {
return adler32_copy_len_16(adler0, dst, remainder, adler1);
} else {
return adler32_len_64(adler0, remainder, adler1);
}
} else {
unreachable!()
}
Expand All @@ -248,21 +324,33 @@ mod avx2 {
}

#[inline(always)]
unsafe fn helper_32_bytes(mut adler0: u32, mut adler1: u32, buf: &[u8]) -> (u32, u32) {
debug_assert_eq!(buf.len() % 32, 0);
unsafe fn helper_32_bytes<const COPY: bool>(
mut adler0: u32,
mut adler1: u32,
dst: &mut [MaybeUninit<u8>],
src: &[u8],
) -> (u32, u32) {
debug_assert_eq!(src.len() % 32, 0);

let mut vs1 = _mm256_zextsi128_si256(_mm_cvtsi32_si128(adler0 as i32));
let mut vs2 = _mm256_zextsi128_si256(_mm_cvtsi32_si128(adler1 as i32));

let mut vs1_0 = vs1;
let mut vs3 = ZERO;

for chunk in buf.chunks_exact(32) {
let vbuf = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
let mut out_chunks = dst.chunks_exact_mut(32);

let vs1_sad = _mm256_sad_epu8(vbuf, ZERO); // Sum of abs diff, resulting in 2 x int32's
for in_chunk in src.chunks_exact(32) {
let vbuf = _mm256_loadu_si256(in_chunk.as_ptr() as *const __m256i);

// TODO copy?
if COPY {
// println!("simd copy {:?}", in_chunk);
let out_chunk = out_chunks.next().unwrap();
_mm256_storeu_si256(out_chunk.as_mut_ptr() as *mut __m256i, vbuf);
// out_chunk.copy_from_slice(slice_to_uninit(in_chunk))
}

let vs1_sad = _mm256_sad_epu8(vbuf, ZERO); // Sum of abs diff, resulting in 2 x int32's

vs1 = _mm256_add_epi32(vs1, vs1_sad);
vs3 = _mm256_add_epi32(vs3, vs1_0);
Expand Down Expand Up @@ -302,6 +390,26 @@ mod avx2 {
assert_eq!(naive_adler32(1, &vec[..i]), adler32_avx2(1, &vec[..i]));
}
}

#[cfg(test)]
// TODO: This could use `MaybeUninit::slice_assume_init` when it is stable.
unsafe fn slice_assume_init(slice: &[MaybeUninit<u8>]) -> &[u8] {
&*(slice as *const [MaybeUninit<u8>] as *const [u8])
}

#[test]
fn fold_copy_copies() {
let src: Vec<_> = (0..128).map(|x| x as u8).collect();
let mut dst = [MaybeUninit::new(0); 128];

for (i, _) in src.iter().enumerate() {
dst.fill(MaybeUninit::new(0));

adler32_fold_copy_avx2(1, &mut dst[..i], &src[..i]);

assert_eq!(&src[..i], unsafe { slice_assume_init(&dst[..i]) })
}
}
}

#[cfg(test)]
Expand Down
22 changes: 8 additions & 14 deletions src/window.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::adler32::adler32;
use crate::adler32::{adler32, adler32_fold_copy};
use std::mem::MaybeUninit;

// translation guide:
Expand Down Expand Up @@ -89,10 +89,7 @@ impl<'a> Window<'a> {

if checksum != 0 {
checksum = adler32(checksum, non_window_slice);

checksum = adler32(checksum, window_slice);
self.buf
.copy_from_slice(unsafe { slice_to_uninit(window_slice) });
checksum = adler32_fold_copy(checksum, self.buf, window_slice);
} else {
self.buf
.copy_from_slice(unsafe { slice_to_uninit(window_slice) });
Expand All @@ -107,23 +104,20 @@ impl<'a> Window<'a> {
// written to the start of the window.
let (end_part, start_part) = slice.split_at(dist);

let end_part = unsafe { slice_to_uninit(end_part) };
let start_part = unsafe { slice_to_uninit(start_part) };

if checksum != 0 {
// TODO fuse memcpy and adler
checksum = adler32(checksum, slice);
self.buf[self.next..][..end_part.len()].copy_from_slice(end_part);
let dst = &mut self.buf[self.next..][..end_part.len()];
checksum = adler32_fold_copy(checksum, dst, end_part);
} else {
let end_part = unsafe { slice_to_uninit(end_part) };
self.buf[self.next..][..end_part.len()].copy_from_slice(end_part);
}

if !start_part.is_empty() {
if checksum != 0 {
// TODO fuse memcpy and adler
checksum = adler32(checksum, slice);
self.buf[..start_part.len()].copy_from_slice(start_part);
let dst = &mut self.buf[..start_part.len()];
checksum = adler32_fold_copy(checksum, dst, start_part);
} else {
let start_part = unsafe { slice_to_uninit(start_part) };
self.buf[..start_part.len()].copy_from_slice(start_part);
}

Expand Down

0 comments on commit 9f7d3f2

Please sign in to comment.