Optimize non-transposed int8 GEMV kernel #559

Merged 2 commits on Jan 29, 2025
37 changes: 37 additions & 0 deletions rten-simd/src/arch/aarch64.rs
@@ -7,6 +7,11 @@ use std::arch::aarch64::{
vreinterpretq_f32_s32, vshlq_n_s32, vst1q_f32, vst1q_s32, vst1q_u32, vsubq_f32, vsubq_s32,
};

use core::arch::aarch64::{
vreinterpretq_s16_s32, vreinterpretq_s32_s16, vreinterpretq_s32_s8, vreinterpretq_s8_s32,
vzip1q_s16, vzip1q_s8, vzip2q_s16, vzip2q_s8,
};

use crate::{Simd, SimdFloat, SimdInt, SimdMask};

impl SimdMask for uint32x4_t {
@@ -176,6 +181,38 @@ impl SimdInt for int32x4_t {
vcombine_s32(vreinterpret_s32_s16(abcd.0), vreinterpret_s32_s16(abcd.1))
}

#[inline]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
vreinterpretq_s32_s8(vzip1q_s8(
vreinterpretq_s8_s32(self),
vreinterpretq_s8_s32(rhs),
))
}

#[inline]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
vreinterpretq_s32_s8(vzip2q_s8(
vreinterpretq_s8_s32(self),
vreinterpretq_s8_s32(rhs),
))
}

#[inline]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
vreinterpretq_s32_s16(vzip1q_s16(
vreinterpretq_s16_s32(self),
vreinterpretq_s16_s32(rhs),
))
}

#[inline]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
vreinterpretq_s32_s16(vzip2q_s16(
vreinterpretq_s16_s32(self),
vreinterpretq_s16_s32(rhs),
))
}

#[inline]
unsafe fn xor(self, rhs: Self) -> Self {
veorq_s32(self, rhs)
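A note on the pattern above (not part of the diff): NEON's vzip1q/vzip2q intrinsics are typed, so each method bit-casts the i32x4 vector to byte or halfword lanes, zips, and casts back; the vreinterpretq_* calls are pure bitcasts and compile to no instructions. A portable model of the semantics all four methods are expected to implement, as a sketch (the function name is hypothetical):

fn zip_i8_model<const N: usize>(a: [i8; N], b: [i8; N], hi: bool) -> [i8; N] {
    // zip_lo draws from lanes 0..N/2 of each input, zip_hi from lanes N/2..N.
    let base = if hi { N / 2 } else { 0 };
    let mut out = [0; N];
    for i in 0..N / 2 {
        out[2 * i] = a[base + i]; // even output lanes come from `a`
        out[2 * i + 1] = b[base + i]; // odd output lanes come from `b`
    }
    out
}

The i16 variants follow the same pattern with halfword lanes instead of bytes.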
41 changes: 41 additions & 0 deletions rten-simd/src/arch/scalar.rs
@@ -1,3 +1,5 @@
use std::mem::transmute;

use crate::{Simd, SimdFloat, SimdInt, SimdMask};

impl SimdMask for bool {
@@ -157,6 +159,38 @@ impl SimdInt for i32 {
unsafe fn xor(self, rhs: Self) -> i32 {
self ^ rhs
}

#[inline]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
let self_i8 = unsafe { transmute::<i32, [i8; 4]>(self) };
let rhs_i8 = unsafe { transmute::<i32, [i8; 4]>(rhs) };
let lo_i8 = [self_i8[0], rhs_i8[0], self_i8[1], rhs_i8[1]];
unsafe { transmute::<[i8; 4], i32>(lo_i8) }
}

#[inline]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
let self_i8 = unsafe { transmute::<i32, [i8; 4]>(self) };
let rhs_i8 = unsafe { transmute::<i32, [i8; 4]>(rhs) };
let hi_i8 = [self_i8[2], rhs_i8[2], self_i8[3], rhs_i8[3]];
unsafe { transmute::<[i8; 4], i32>(hi_i8) }
}

#[inline]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
let self_i16 = unsafe { transmute::<i32, [i16; 2]>(self) };
let rhs_i16 = unsafe { transmute::<i32, [i16; 2]>(rhs) };
let lo_i16 = [self_i16[0], rhs_i16[0]];
unsafe { transmute::<[i16; 2], i32>(lo_i16) }
}

#[inline]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
let self_i16 = unsafe { transmute::<i32, [i16; 2]>(self) };
let rhs_i16 = unsafe { transmute::<i32, [i16; 2]>(rhs) };
let hi_i16 = [self_i16[1], rhs_i16[1]];
unsafe { transmute::<[i16; 2], i32>(hi_i16) }
}
}

/// Treat an `f32` as a single-lane SIMD "vector".
@@ -247,3 +281,10 @@ impl SimdFloat for f32 {
self
}
}

#[cfg(test)]
mod tests {
use crate::vec::tests::test_simdint;

test_simdint!(i32_simdint, i32);
}
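A concrete illustration of the scalar semantics (a sketch, not part of the PR; it assumes a little-endian target, since transmute reinterprets the i32 in native byte order, and that the SimdInt trait is in scope):

let a = i32::from_le_bytes([0, 1, 2, 3]);
let b = i32::from_le_bytes([10, 11, 12, 13]);
unsafe {
    assert_eq!(a.zip_lo_i8(b), i32::from_le_bytes([0, 10, 1, 11]));
    assert_eq!(a.zip_hi_i8(b), i32::from_le_bytes([2, 12, 3, 13]));
}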
62 changes: 58 additions & 4 deletions rten-simd/src/arch/wasm.rs
@@ -1,9 +1,9 @@
use std::arch::wasm32::{
f32x4_abs, f32x4_add, f32x4_div, f32x4_extract_lane, f32x4_ge, f32x4_le, f32x4_lt, f32x4_max,
f32x4_min, f32x4_mul, f32x4_nearest, f32x4_splat, f32x4_sub, i16x8_shuffle, i32x4, i32x4_add,
i32x4_eq, i32x4_ge, i32x4_gt, i32x4_le, i32x4_lt, i32x4_max, i32x4_min, i32x4_mul, i32x4_shl,
i32x4_shuffle, i32x4_splat, i32x4_sub, i32x4_trunc_sat_f32x4, i8x16_shuffle, v128, v128_and,
v128_bitselect, v128_load, v128_store, v128_xor,
};

#[cfg(target_feature = "relaxed-simd")]
@@ -181,6 +181,60 @@ impl SimdInt for v128i {
unsafe fn xor(self, rhs: Self) -> Self {
Self(v128_xor(self.0, rhs.0))
}

#[inline]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
// Shuffle indices 0-15 select byte lanes from `self.0`, 16-31 from `rhs.0`.
Self(i8x16_shuffle::<
0,
16,
1,
17,
2,
18,
3,
19,
4,
20,
5,
21,
6,
22,
7,
23,
>(self.0, rhs.0))
}

#[inline]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
Self(i8x16_shuffle::<
8,
24,
9,
25,
10,
26,
11,
27,
12,
28,
13,
29,
14,
30,
15,
31,
>(self.0, rhs.0))
}

#[inline]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
Self(i16x8_shuffle::<0, 8, 1, 9, 2, 10, 3, 11>(self.0, rhs.0))
}

#[inline]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
Self(i16x8_shuffle::<4, 12, 5, 13, 6, 14, 7, 15>(self.0, rhs.0))
}
}

impl Simd for v128f {
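The long const index lists above encode the interleave directly: indices 0-15 select byte lanes from self.0 and 16-31 from rhs.0. A hypothetical helper (not in the PR) that regenerates the i8 patterns, which may be useful for auditing the lists:

fn zip_i8_shuffle_indices(hi: bool) -> [usize; 16] {
    // zip_lo starts at byte 0 of each operand, zip_hi at byte 8.
    let base = if hi { 8 } else { 0 };
    let mut idx = [0; 16];
    for i in 0..8 {
        idx[2 * i] = base + i; // from `self.0` (indices 0-15)
        idx[2 * i + 1] = 16 + base + i; // from `rhs.0` (indices 16-31)
    }
    idx
}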
111 changes: 103 additions & 8 deletions rten-simd/src/arch/x86_64.rs
@@ -1,13 +1,16 @@
use std::arch::x86_64::{
__m128i, __m256, __m256i, _mm256_add_epi32, _mm256_add_ps, _mm256_and_si256, _mm256_andnot_ps,
_mm256_blendv_epi8, _mm256_blendv_ps, _mm256_castps256_ps128, _mm256_castsi128_si256,
_mm256_castsi256_ps, _mm256_castsi256_si128, _mm256_cmp_ps, _mm256_cmpeq_epi32,
_mm256_cmpgt_epi32, _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_div_ps,
_mm256_extractf128_ps, _mm256_extractf128_si256, _mm256_fmadd_ps, _mm256_insertf128_si256,
_mm256_loadu_ps, _mm256_loadu_si256, _mm256_max_epi32, _mm256_max_ps, _mm256_min_epi32,
_mm256_min_ps, _mm256_mul_ps, _mm256_mullo_epi32, _mm256_or_si256, _mm256_set1_epi32,
_mm256_set1_ps, _mm256_setr_epi32, _mm256_slli_epi32, _mm256_storeu_ps, _mm256_storeu_si256,
_mm256_sub_epi32, _mm256_sub_ps, _mm256_unpackhi_epi16, _mm256_unpackhi_epi8,
_mm256_unpacklo_epi16, _mm256_unpacklo_epi8, _mm256_xor_si256, _mm_add_ps, _mm_cvtss_f32,
_mm_loadl_epi64, _mm_movehl_ps, _mm_prefetch, _mm_shuffle_ps, _CMP_GE_OQ, _CMP_LE_OQ,
_CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0,
};
use std::mem::{transmute, MaybeUninit};

@@ -234,6 +237,42 @@ impl SimdInt for __m256i {
unsafe fn xor(self, other: Self) -> Self {
_mm256_xor_si256(self, other)
}

#[inline]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
// Interleave from low half of each 128-bit block.
let lo = _mm256_unpacklo_epi8(self, rhs);
// Interleave from high half of each 128-bit block.
let hi = _mm256_unpackhi_epi8(self, rhs);
// Combine elements from low and high half of first 128-bit block in
// `self` and `rhs`.
_mm256_insertf128_si256(lo, _mm256_castsi256_si128(hi), 1)
}

#[inline]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
// Same as `zip_lo_i8`, but combine elements from the second 128-bit
// block of `self` and `rhs`.
let lo = _mm256_unpacklo_epi8(self, rhs);
let hi = _mm256_unpackhi_epi8(self, rhs);
let lo_hi = _mm256_castsi128_si256(_mm256_extractf128_si256(lo, 1));
let hi_hi = _mm256_extractf128_si256(hi, 1);
_mm256_insertf128_si256(lo_hi, hi_hi, 1)
}

#[inline]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
let lo = _mm256_unpacklo_epi16(self, rhs);
let hi = _mm256_unpackhi_epi16(self, rhs);
_mm256_insertf128_si256(lo, _mm256_castsi256_si128(hi), 1)
}

#[inline]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
let lo = _mm256_unpacklo_epi16(self, rhs);
let hi = _mm256_unpackhi_epi16(self, rhs);
let lo_hi = _mm256_castsi128_si256(_mm256_extractf128_si256(lo, 1));
let hi_hi = _mm256_extractf128_si256(hi, 1);
_mm256_insertf128_si256(lo_hi, hi_hi, 1)
}
}

impl Simd for __m256 {
@@ -608,6 +647,62 @@ impl SimdInt for __m512i {
unsafe fn xor(self, other: Self) -> Self {
_mm512_xor_si512(self, other)
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
use core::arch::x86_64::{
_mm512_castsi256_si512, _mm512_castsi512_si256, _mm512_inserti64x4,
};
// Delegate to the __m256i implementations: zip the low 256 bits of
// each operand, then recombine the two 256-bit half-results.
let lo_self = _mm512_castsi512_si256(self);
let lo_rhs = _mm512_castsi512_si256(rhs);
let lo = lo_self.zip_lo_i8(lo_rhs);
let lo = _mm512_castsi256_si512(lo);
let hi = lo_self.zip_hi_i8(lo_rhs);
_mm512_inserti64x4(lo, hi, 1)
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
use core::arch::x86_64::{
_mm512_castsi256_si512, _mm512_extracti64x4_epi64, _mm512_inserti64x4,
};
let hi_self = _mm512_extracti64x4_epi64(self, 1);
let hi_rhs = _mm512_extracti64x4_epi64(rhs, 1);
let lo = hi_self.zip_lo_i8(hi_rhs);
let lo = _mm512_castsi256_si512(lo);
let hi = hi_self.zip_hi_i8(hi_rhs);
_mm512_inserti64x4(lo, hi, 1)
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
use core::arch::x86_64::{
_mm512_castsi256_si512, _mm512_castsi512_si256, _mm512_inserti64x4,
};
let lo_self = _mm512_castsi512_si256(self);
let lo_rhs = _mm512_castsi512_si256(rhs);
let lo = lo_self.zip_lo_i16(lo_rhs);
let lo = _mm512_castsi256_si512(lo);
let hi = lo_self.zip_hi_i16(lo_rhs);
_mm512_inserti64x4(lo, hi, 1)
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
use core::arch::x86_64::{
_mm512_castsi256_si512, _mm512_extracti64x4_epi64, _mm512_inserti64x4,
};
let hi_self = _mm512_extracti64x4_epi64(self, 1);
let hi_rhs = _mm512_extracti64x4_epi64(rhs, 1);
let lo = hi_self.zip_lo_i16(hi_rhs);
let lo = _mm512_castsi256_si512(lo);
let hi = hi_self.zip_hi_i16(hi_rhs);
_mm512_inserti64x4(lo, hi, 1)
}
}

#[cfg(feature = "avx512")]
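Why the x86 paths need the extract/insert recombination (a sketch of the lane bookkeeping, not part of the PR): _mm256_unpacklo_epi8 and _mm256_unpackhi_epi8 interleave within each 128-bit lane rather than across the whole register. Writing the byte lanes of self as a0..a31 and rhs as b0..b31, zip_lo_i8 for __m256i works out as:

// lo = _mm256_unpacklo_epi8(self, rhs)
//   lane 0: a0 b0 .. a7 b7      lane 1: a16 b16 .. a23 b23
// hi = _mm256_unpackhi_epi8(self, rhs)
//   lane 0: a8 b8 .. a15 b15    lane 1: a24 b24 .. a31 b31
// result = _mm256_insertf128_si256(lo, _mm256_castsi256_si128(hi), 1)
//   lane 0: a0 b0 .. a7 b7      lane 1: a8 b8 .. a15 b15
// i.e. the interleave of the low 16 bytes of `self` and `rhs`. The AVX-512
// versions apply the same idea one level up, splitting the 512-bit vector
// into two 256-bit halves and delegating to the __m256i implementations.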