Skip to content

Commit

Permalink
Merge pull request #559 from robertknight/faster-gemv-int8-kernel
Browse files Browse the repository at this point in the history
Optimize non-transposed int8 GEMV kernel
  • Loading branch information
robertknight authored Jan 29, 2025
2 parents ceb3f33 + 1b26bb0 commit b89f8b0
Show file tree
Hide file tree
Showing 6 changed files with 403 additions and 61 deletions.
37 changes: 37 additions & 0 deletions rten-simd/src/arch/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ use std::arch::aarch64::{
vreinterpretq_f32_s32, vshlq_n_s32, vst1q_f32, vst1q_s32, vst1q_u32, vsubq_f32, vsubq_s32,
};

use core::arch::aarch64::{
vreinterpretq_s16_s32, vreinterpretq_s32_s16, vreinterpretq_s32_s8, vreinterpretq_s8_s32,
vzip1q_s16, vzip1q_s8, vzip2q_s16, vzip2q_s8,
};

use crate::{Simd, SimdFloat, SimdInt, SimdMask};

impl SimdMask for uint32x4_t {
Expand Down Expand Up @@ -176,6 +181,38 @@ impl SimdInt for int32x4_t {
vcombine_s32(vreinterpret_s32_s16(abcd.0), vreinterpret_s32_s16(abcd.1))
}

/// Interleave the i8 lanes from the low halves of `self` and `rhs`.
#[inline]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
    // Reinterpret both operands as i8 lanes, zip their low halves, then
    // reinterpret the result back to i32 lanes.
    let lhs_bytes = vreinterpretq_s8_s32(self);
    let rhs_bytes = vreinterpretq_s8_s32(rhs);
    let zipped = vzip1q_s8(lhs_bytes, rhs_bytes);
    vreinterpretq_s32_s8(zipped)
}

/// Interleave the i8 lanes from the high halves of `self` and `rhs`.
#[inline]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
    // Reinterpret both operands as i8 lanes, zip their high halves, then
    // reinterpret the result back to i32 lanes.
    let lhs_bytes = vreinterpretq_s8_s32(self);
    let rhs_bytes = vreinterpretq_s8_s32(rhs);
    let zipped = vzip2q_s8(lhs_bytes, rhs_bytes);
    vreinterpretq_s32_s8(zipped)
}

/// Interleave the i16 lanes from the low halves of `self` and `rhs`.
#[inline]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
    // Reinterpret both operands as i16 lanes, zip their low halves, then
    // reinterpret the result back to i32 lanes.
    let lhs_halves = vreinterpretq_s16_s32(self);
    let rhs_halves = vreinterpretq_s16_s32(rhs);
    let zipped = vzip1q_s16(lhs_halves, rhs_halves);
    vreinterpretq_s32_s16(zipped)
}

/// Interleave the i16 lanes from the high halves of `self` and `rhs`.
#[inline]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
    // Reinterpret both operands as i16 lanes, zip their high halves, then
    // reinterpret the result back to i32 lanes.
    let lhs_halves = vreinterpretq_s16_s32(self);
    let rhs_halves = vreinterpretq_s16_s32(rhs);
    let zipped = vzip2q_s16(lhs_halves, rhs_halves);
    vreinterpretq_s32_s16(zipped)
}

#[inline]
unsafe fn xor(self, rhs: Self) -> Self {
veorq_s32(self, rhs)
Expand Down
41 changes: 41 additions & 0 deletions rten-simd/src/arch/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::mem::transmute;

use crate::{Simd, SimdFloat, SimdInt, SimdMask};

impl SimdMask for bool {
Expand Down Expand Up @@ -157,6 +159,38 @@ impl SimdInt for i32 {
/// Bitwise XOR of `self` and `rhs`.
unsafe fn xor(self, rhs: Self) -> i32 {
self ^ rhs
}

/// Interleave the two lowest i8 lanes of `self` and `rhs`.
///
/// Each `i32` is viewed as four i8 lanes in memory order; the result packs
/// `[self[0], rhs[0], self[1], rhs[1]]` back into an `i32`.
#[inline]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
    // SAFETY: `i32` and `[i8; 4]` have the same size and every bit pattern
    // is valid for both types.
    let [a0, a1, _, _] = unsafe { transmute::<i32, [i8; 4]>(self) };
    let [b0, b1, _, _] = unsafe { transmute::<i32, [i8; 4]>(rhs) };
    unsafe { transmute::<[i8; 4], i32>([a0, b0, a1, b1]) }
}

/// Interleave the two highest i8 lanes of `self` and `rhs`.
///
/// Each `i32` is viewed as four i8 lanes in memory order; the result packs
/// `[self[2], rhs[2], self[3], rhs[3]]` back into an `i32`.
#[inline]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
    // SAFETY: `i32` and `[i8; 4]` have the same size and every bit pattern
    // is valid for both types.
    let [_, _, a2, a3] = unsafe { transmute::<i32, [i8; 4]>(self) };
    let [_, _, b2, b3] = unsafe { transmute::<i32, [i8; 4]>(rhs) };
    unsafe { transmute::<[i8; 4], i32>([a2, b2, a3, b3]) }
}

/// Combine the lowest i16 lane of `self` with the lowest i16 lane of `rhs`.
///
/// Each `i32` is viewed as two i16 lanes in memory order; the result packs
/// `[self[0], rhs[0]]` back into an `i32`.
#[inline]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
    // SAFETY: `i32` and `[i16; 2]` have the same size and every bit pattern
    // is valid for both types.
    let [a0, _] = unsafe { transmute::<i32, [i16; 2]>(self) };
    let [b0, _] = unsafe { transmute::<i32, [i16; 2]>(rhs) };
    unsafe { transmute::<[i16; 2], i32>([a0, b0]) }
}

/// Combine the highest i16 lane of `self` with the highest i16 lane of `rhs`.
///
/// Each `i32` is viewed as two i16 lanes in memory order; the result packs
/// `[self[1], rhs[1]]` back into an `i32`.
#[inline]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
    // SAFETY: `i32` and `[i16; 2]` have the same size and every bit pattern
    // is valid for both types.
    let [_, a1] = unsafe { transmute::<i32, [i16; 2]>(self) };
    let [_, b1] = unsafe { transmute::<i32, [i16; 2]>(rhs) };
    unsafe { transmute::<[i16; 2], i32>([a1, b1]) }
}
}

/// Treat an `f32` as a single-lane SIMD "vector".
Expand Down Expand Up @@ -247,3 +281,10 @@ impl SimdFloat for f32 {
self
}
}

// Unit tests for the scalar (single-lane) SIMD implementations.
#[cfg(test)]
mod tests {
// Test-generating macro imported from `crate::vec::tests`.
use crate::vec::tests::test_simdint;

// Invoke the macro for the `i32` type under the name `i32_simdint`.
// NOTE(review): presumably this expands to the shared `SimdInt` test suite
// run against the scalar `i32` impl — confirm against the macro definition.
test_simdint!(i32_simdint, i32);
}
62 changes: 58 additions & 4 deletions rten-simd/src/arch/wasm.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::arch::wasm32::{
f32x4_abs, f32x4_add, f32x4_div, f32x4_extract_lane, f32x4_ge, f32x4_le, f32x4_lt, f32x4_max,
f32x4_min, f32x4_mul, f32x4_nearest, f32x4_splat, f32x4_sub, i32x4, i32x4_add, i32x4_eq,
i32x4_ge, i32x4_gt, i32x4_le, i32x4_lt, i32x4_max, i32x4_min, i32x4_mul, i32x4_shl,
i32x4_shuffle, i32x4_splat, i32x4_sub, i32x4_trunc_sat_f32x4, v128, v128_and, v128_bitselect,
v128_load, v128_store, v128_xor,
f32x4_min, f32x4_mul, f32x4_nearest, f32x4_splat, f32x4_sub, i16x8_shuffle, i32x4, i32x4_add,
i32x4_eq, i32x4_ge, i32x4_gt, i32x4_le, i32x4_lt, i32x4_max, i32x4_min, i32x4_mul, i32x4_shl,
i32x4_shuffle, i32x4_splat, i32x4_sub, i32x4_trunc_sat_f32x4, i8x16_shuffle, v128, v128_and,
v128_bitselect, v128_load, v128_store, v128_xor,
};

#[cfg(target_feature = "relaxed-simd")]
Expand Down Expand Up @@ -181,6 +181,60 @@ impl SimdInt for v128i {
/// Bitwise XOR of the 128-bit vectors wrapped by `self` and `rhs`.
unsafe fn xor(self, rhs: Self) -> Self {
Self(v128_xor(self.0, rhs.0))
}

/// Interleave the i8 lanes from the low halves of `self` and `rhs`.
#[inline]
#[rustfmt::skip] // Keep the shuffle indices grouped in (self, rhs) pairs.
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
    let (a, b) = (self.0, rhs.0);
    // Indices 0..=15 pick lanes from `a`; 16..=31 pick lanes from `b`.
    let zipped = i8x16_shuffle::<
        0, 16, 1, 17, 2, 18, 3, 19,
        4, 20, 5, 21, 6, 22, 7, 23,
    >(a, b);
    Self(zipped)
}

/// Interleave the i8 lanes from the high halves of `self` and `rhs`.
#[inline]
#[rustfmt::skip] // Keep the shuffle indices grouped in (self, rhs) pairs.
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
    let (a, b) = (self.0, rhs.0);
    // Indices 0..=15 pick lanes from `a`; 16..=31 pick lanes from `b`.
    let zipped = i8x16_shuffle::<
        8, 24, 9, 25, 10, 26, 11, 27,
        12, 28, 13, 29, 14, 30, 15, 31,
    >(a, b);
    Self(zipped)
}

/// Interleave the i16 lanes from the low halves of `self` and `rhs`.
#[inline]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
    let (a, b) = (self.0, rhs.0);
    // Indices 0..=7 pick lanes from `a`; 8..=15 pick lanes from `b`.
    let zipped = i16x8_shuffle::<0, 8, 1, 9, 2, 10, 3, 11>(a, b);
    Self(zipped)
}

/// Interleave the i16 lanes from the high halves of `self` and `rhs`.
#[inline]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
    let (a, b) = (self.0, rhs.0);
    // Indices 0..=7 pick lanes from `a`; 8..=15 pick lanes from `b`.
    let zipped = i16x8_shuffle::<4, 12, 5, 13, 6, 14, 7, 15>(a, b);
    Self(zipped)
}
}

impl Simd for v128f {
Expand Down
111 changes: 103 additions & 8 deletions rten-simd/src/arch/x86_64.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use std::arch::x86_64::{
__m128i, __m256, __m256i, _mm256_add_epi32, _mm256_add_ps, _mm256_and_si256, _mm256_andnot_ps,
_mm256_blendv_epi8, _mm256_blendv_ps, _mm256_castps256_ps128, _mm256_castsi256_ps,
_mm256_cmp_ps, _mm256_cmpeq_epi32, _mm256_cmpgt_epi32, _mm256_cvtps_epi32, _mm256_cvttps_epi32,
_mm256_div_ps, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_loadu_si256,
_mm256_max_epi32, _mm256_max_ps, _mm256_min_epi32, _mm256_min_ps, _mm256_mul_ps,
_mm256_mullo_epi32, _mm256_or_si256, _mm256_set1_epi32, _mm256_set1_ps, _mm256_setr_epi32,
_mm256_slli_epi32, _mm256_storeu_ps, _mm256_storeu_si256, _mm256_sub_epi32, _mm256_sub_ps,
_mm256_xor_si256, _mm_add_ps, _mm_cvtss_f32, _mm_loadl_epi64, _mm_movehl_ps, _mm_prefetch,
_mm_shuffle_ps, _CMP_GE_OQ, _CMP_LE_OQ, _CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0,
_mm256_blendv_epi8, _mm256_blendv_ps, _mm256_castps256_ps128, _mm256_castsi128_si256,
_mm256_castsi256_ps, _mm256_castsi256_si128, _mm256_cmp_ps, _mm256_cmpeq_epi32,
_mm256_cmpgt_epi32, _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_div_ps,
_mm256_extractf128_ps, _mm256_extractf128_si256, _mm256_fmadd_ps, _mm256_insertf128_si256,
_mm256_loadu_ps, _mm256_loadu_si256, _mm256_max_epi32, _mm256_max_ps, _mm256_min_epi32,
_mm256_min_ps, _mm256_mul_ps, _mm256_mullo_epi32, _mm256_or_si256, _mm256_set1_epi32,
_mm256_set1_ps, _mm256_setr_epi32, _mm256_slli_epi32, _mm256_storeu_ps, _mm256_storeu_si256,
_mm256_sub_epi32, _mm256_sub_ps, _mm256_unpackhi_epi16, _mm256_unpackhi_epi8,
_mm256_unpacklo_epi16, _mm256_unpacklo_epi8, _mm256_xor_si256, _mm_add_ps, _mm_cvtss_f32,
_mm_loadl_epi64, _mm_movehl_ps, _mm_prefetch, _mm_shuffle_ps, _CMP_GE_OQ, _CMP_LE_OQ,
_CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0,
};
use std::mem::{transmute, MaybeUninit};

Expand Down Expand Up @@ -234,6 +237,42 @@ impl SimdInt for __m256i {
/// Bitwise XOR of the 256-bit vectors `self` and `other`.
unsafe fn xor(self, other: Self) -> Self {
_mm256_xor_si256(self, other)
}

/// Interleave the i8 lanes from the low 128-bit blocks of `self` and `rhs`.
#[inline]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
    // The unpack instructions work independently within each 128-bit block:
    // `unpacked_lo` interleaves the low 8 bytes of every block and
    // `unpacked_hi` interleaves the high 8 bytes of every block.
    let unpacked_lo = _mm256_unpacklo_epi8(self, rhs);
    let unpacked_hi = _mm256_unpackhi_epi8(self, rhs);

    // The result is the first block of `unpacked_lo` followed by the first
    // block of `unpacked_hi`, which together cover the whole first 128-bit
    // block of both inputs.
    let first_block_hi = _mm256_castsi256_si128(unpacked_hi);
    _mm256_insertf128_si256(unpacked_lo, first_block_hi, 1)
}

/// Interleave the i8 lanes from the high 128-bit blocks of `self` and `rhs`.
#[inline]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
    // Per-128-bit-block interleave of the low and high 8 bytes respectively.
    let unpacked_lo = _mm256_unpacklo_epi8(self, rhs);
    let unpacked_hi = _mm256_unpackhi_epi8(self, rhs);

    // Take the second 128-bit block of each unpack result: the one from
    // `unpacked_lo` forms the lower half of the output and the one from
    // `unpacked_hi` forms the upper half.
    let second_block_lo = _mm256_extractf128_si256(unpacked_lo, 1);
    let second_block_hi = _mm256_extractf128_si256(unpacked_hi, 1);
    let widened = _mm256_castsi128_si256(second_block_lo);
    _mm256_insertf128_si256(widened, second_block_hi, 1)
}

/// Interleave the i16 lanes from the low 128-bit blocks of `self` and `rhs`.
#[inline]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
    // Per-128-bit-block interleave of the low and high four i16 lanes.
    let unpacked_lo = _mm256_unpacklo_epi16(self, rhs);
    let unpacked_hi = _mm256_unpackhi_epi16(self, rhs);

    // Keep the first block of `unpacked_lo` and place the first block of
    // `unpacked_hi` after it.
    let first_block_hi = _mm256_castsi256_si128(unpacked_hi);
    _mm256_insertf128_si256(unpacked_lo, first_block_hi, 1)
}

/// Interleave the i16 lanes from the high 128-bit blocks of `self` and `rhs`.
#[inline]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
    // Per-128-bit-block interleave of the low and high four i16 lanes.
    let unpacked_lo = _mm256_unpacklo_epi16(self, rhs);
    let unpacked_hi = _mm256_unpackhi_epi16(self, rhs);

    // Take the second 128-bit block of each unpack result: the one from
    // `unpacked_lo` forms the lower half of the output and the one from
    // `unpacked_hi` forms the upper half.
    let second_block_lo = _mm256_extractf128_si256(unpacked_lo, 1);
    let second_block_hi = _mm256_extractf128_si256(unpacked_hi, 1);
    let widened = _mm256_castsi128_si256(second_block_lo);
    _mm256_insertf128_si256(widened, second_block_hi, 1)
}
}

impl Simd for __m256 {
Expand Down Expand Up @@ -608,6 +647,62 @@ impl SimdInt for __m512i {
/// Bitwise XOR of the 512-bit vectors `self` and `other`.
unsafe fn xor(self, other: Self) -> Self {
_mm512_xor_si512(self, other)
}

/// Interleave the i8 lanes from the low 256-bit halves of `self` and `rhs`.
///
/// Implemented by delegating to the 256-bit `zip_lo_i8` / `zip_hi_i8`.
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
    use core::arch::x86_64::{
        _mm512_castsi256_si512, _mm512_castsi512_si256, _mm512_inserti64x4,
    };
    // Only the lower 256 bits of each operand contribute to the result.
    let a = _mm512_castsi512_si256(self);
    let b = _mm512_castsi512_si256(rhs);

    // Lower 256 bits of the output: zip of the low halves of `a` and `b`.
    let lower = _mm512_castsi256_si512(a.zip_lo_i8(b));
    // Upper 256 bits of the output: zip of the high halves of `a` and `b`.
    let upper = a.zip_hi_i8(b);
    _mm512_inserti64x4(lower, upper, 1)
}

/// Interleave the i8 lanes from the high 256-bit halves of `self` and `rhs`.
///
/// Implemented by delegating to the 256-bit `zip_lo_i8` / `zip_hi_i8`.
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
    use core::arch::x86_64::{
        _mm512_castsi256_si512, _mm512_extracti64x4_epi64, _mm512_inserti64x4,
    };
    // Only the upper 256 bits of each operand contribute to the result.
    let a = _mm512_extracti64x4_epi64(self, 1);
    let b = _mm512_extracti64x4_epi64(rhs, 1);

    // Lower 256 bits of the output: zip of the low halves of `a` and `b`.
    let lower = _mm512_castsi256_si512(a.zip_lo_i8(b));
    // Upper 256 bits of the output: zip of the high halves of `a` and `b`.
    let upper = a.zip_hi_i8(b);
    _mm512_inserti64x4(lower, upper, 1)
}

/// Interleave the i16 lanes from the low 256-bit halves of `self` and `rhs`.
///
/// Implemented by delegating to the 256-bit `zip_lo_i16` / `zip_hi_i16`.
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
    use core::arch::x86_64::{
        _mm512_castsi256_si512, _mm512_castsi512_si256, _mm512_inserti64x4,
    };
    // Only the lower 256 bits of each operand contribute to the result.
    let a = _mm512_castsi512_si256(self);
    let b = _mm512_castsi512_si256(rhs);

    // Lower 256 bits of the output: zip of the low halves of `a` and `b`.
    let lower = _mm512_castsi256_si512(a.zip_lo_i16(b));
    // Upper 256 bits of the output: zip of the high halves of `a` and `b`.
    let upper = a.zip_hi_i16(b);
    _mm512_inserti64x4(lower, upper, 1)
}

/// Interleave the i16 lanes from the high 256-bit halves of `self` and `rhs`.
///
/// Implemented by delegating to the 256-bit `zip_lo_i16` / `zip_hi_i16`.
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
    use core::arch::x86_64::{
        _mm512_castsi256_si512, _mm512_extracti64x4_epi64, _mm512_inserti64x4,
    };
    // Only the upper 256 bits of each operand contribute to the result.
    let a = _mm512_extracti64x4_epi64(self, 1);
    let b = _mm512_extracti64x4_epi64(rhs, 1);

    // Lower 256 bits of the output: zip of the low halves of `a` and `b`.
    let lower = _mm512_castsi256_si512(a.zip_lo_i16(b));
    // Upper 256 bits of the output: zip of the high halves of `a` and `b`.
    let upper = a.zip_hi_i16(b);
    _mm512_inserti64x4(lower, upper, 1)
}
}

#[cfg(feature = "avx512")]
Expand Down
Loading

0 comments on commit b89f8b0

Please sign in to comment.