Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement im2col packing for int8 GEMM #570

Merged
merged 3 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rten-simd/src/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub trait Simd: Copy + Sized {
type Array: Copy
+ std::fmt::Debug
+ std::ops::Index<usize, Output = Self::Elem>
+ std::ops::IndexMut<usize, Output = Self::Elem>
+ AsRef<[Self::Elem]>;

/// Combine elements of `self` and `rhs` according to a mask.
Expand Down
92 changes: 84 additions & 8 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,14 @@ impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT> GemmExecutor<LhsT, RhsT, OutT
///
/// The number of columns in [`ColOffsets`] must be a multiple of this.
pub fn im2col_col_count_step(&self) -> usize {
self.kernel.nr()
self.kernel.im2col_col_count_step()
}

/// Return row count step for building an [`Im2Col`] input.
///
/// The number of rows in [`RowOffsets`] must be a multiple of this.
pub fn im2col_row_count_step(&self) -> usize {
self.kernel.im2col_row_count_step()
}

/// Prepack a matrix for use as the right-hand or "B" matrix input.
Expand Down Expand Up @@ -2186,21 +2193,34 @@ mod tests {
// This builds a mapping between elements of an image and a
// `[chans, height x width]` matrix where `image[c, y, x]` maps to
// `im2col_matrix[c, y / width, y % width]`.
fn build_im2col(image: NdTensorView<f32, 3>, col_count_step: usize) -> Im2Col<f32> {
fn build_im2col<T: Copy>(
image: NdTensorView<T, 3>,
col_count_step: usize,
row_count_step: usize,
) -> Im2Col<T> {
let [chans, img_h, img_w] = image.shape();
let [chan_stride, h_stride, w_stride] = image.strides();

let rows = chans;
let n_cols = img_w * img_h;
let n_cols_padded = n_cols.next_multiple_of(col_count_step);

let row_offsets = RowOffsets {
let rows = chans;
let n_rows_padded = rows.next_multiple_of(row_count_step);

let mut row_offsets = RowOffsets {
chan: (0..rows as i32)
.map(|chan| chan * chan_stride as i32)
.collect(),
y: vec![0; rows],
x: vec![0; rows],
};

for _ in rows..n_rows_padded {
row_offsets.chan.push(i32::MAX);
row_offsets.x.push(i32::MAX);
row_offsets.y.push(i32::MAX);
}

let mut col_offsets = ColOffsets {
y: (0..n_cols)
.map(|i| i as i32 / img_w as i32)
Expand All @@ -2212,8 +2232,8 @@ mod tests {
.collect(),
};
for _ in n_cols..n_cols_padded {
col_offsets.y.push(0);
col_offsets.x.push(0);
col_offsets.y.push(i32::MAX);
col_offsets.x.push(i32::MAX);
}

let max_y_offset = (img_h - 1) * h_stride;
Expand All @@ -2224,13 +2244,14 @@ mod tests {
row_offsets,
col_offsets,
n_cols,
n_rows: rows,
max_y_offset: max_y_offset as i32,
max_x_offset: max_x_offset as i32,
}
}

#[test]
fn test_gemm_im2col() -> Result<(), Box<dyn Error>> {
fn test_gemm_im2col_f32() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);
let gemm = GemmExecutor::default();

Expand All @@ -2241,7 +2262,11 @@ mod tests {
let kernel_chans = 3;

let img = NdTensor::<f32, 3>::rand([img_chans, img_h, img_w], &mut rng);
let im2col = build_im2col(img.view(), gemm.im2col_col_count_step());
let im2col = build_im2col(
img.view(),
gemm.im2col_col_count_step(),
gemm.im2col_row_count_step(),
);

let kernel_mat = NdTensor::<f32, 2>::rand([kernel_chans, img_chans], &mut rng);
let mut output_mat = NdTensor::<f32, 2>::zeros([kernel_chans, img_h * img_w]);
Expand Down Expand Up @@ -2275,6 +2300,57 @@ mod tests {
Ok(())
}

#[test]
fn test_gemm_im2col_u8i8_i32() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);

// nb. If the test fails, debug by setting dimensions to 1.
let img_h = 2;
let img_w = 2;
let img_chans = 2;
let kernel_chans = 3;

let img = NdTensor::<i8, 3>::rand([img_chans, img_h, img_w], &mut rng);

for gemm in all_gemms() {
let im2col = build_im2col(
img.view(),
gemm.im2col_col_count_step(),
gemm.im2col_row_count_step(),
);
let kernel_mat = NdTensor::<u8, 2>::rand([kernel_chans, img_chans], &mut rng);
let mut output_mat = NdTensor::<i32, 2>::zeros([kernel_chans, img_h * img_w]);
let out_row_stride = output_mat.row_stride();

gemm.gemm(
output_mat.data_mut().unwrap(),
out_row_stride,
GemmInputA::Unpacked(kernel_mat.view()),
GemmInputB::Im2Col(&im2col),
1., // alpha
0, // beta
None, // bias
None, // a_quant
None, // b_quant
)
.unwrap();

let mut expected = NdTensor::<i32, 2>::zeros([kernel_chans, im2col.cols()]);
for i in 0..expected.rows() {
for j in 0..expected.cols() {
let mut acc = 0;
for k in 0..kernel_mat.cols() {
acc += kernel_mat[[i, k]] as i32 * img[[k, j / img_w, j % img_w]] as i32;
}
expected[[i, j]] = acc;
}
}
expect_equal(&output_mat, &expected)?;
}

Ok(())
}

#[test]
fn test_gemv() -> Result<(), Box<dyn Error>> {
enum Strides {
Expand Down
152 changes: 147 additions & 5 deletions src/gemm/im2col.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ use std::mem::MaybeUninit;
use std::ops::Range;

use rten_simd::{SimdInt, SimdMask};

use rten_tensor::{NdTensorView, Storage};

use super::packing::int8::shift_cast_i8_u8;
use crate::slice_cast::cast_pod_mut_slice;

/// Maps rows of an [`Im2Col`] matrix to locations in the source image.
///
/// For efficiency when packing the image, the locations are premultiplied by
Expand Down Expand Up @@ -45,17 +47,26 @@ pub struct Im2Col<'a, T> {

/// Map of im2col row index to input image coordinate, premultiplied with
/// the corresponding stride.
///
/// The arrays may be padded to a multiple of a step size specified by the
/// GEMM kernel. `n_rows` contains the actual number of rows in the virtual
/// matrix.
pub row_offsets: RowOffsets,

/// Map of im2col column index to input image coordinate, premultiplied with
/// the corresponding stride. The length of arrays in `col_offsets` is
/// rounded up to the nearest multiple of the panel width. `n_cols` contains
/// the actual number of columns in the virtual matrix.
/// the corresponding stride.
///
/// The arrays may be padded to a multiple of a step size specified by the
/// GEMM kernel. `n_cols` contains the actual number of columns in the
/// virtual matrix.
pub col_offsets: ColOffsets,

/// Number of columns in the im2col matrix.
pub n_cols: usize,

/// Number of rows in the im2col matrix.
pub n_rows: usize,

/// Maximum valid sum of `row_offsets.y + col_offsets.y`. Values above this
/// correspond to the padding region.
pub max_y_offset: i32,
Expand All @@ -68,7 +79,7 @@ pub struct Im2Col<'a, T> {
impl<T: Copy + Default> Im2Col<'_, T> {
/// Return the number of rows in the im2col matrix.
pub fn rows(&self) -> usize {
self.row_offsets.chan.len()
self.n_rows
}

/// Return the number of columns in the im2col matrix.
Expand All @@ -78,6 +89,9 @@ impl<T: Copy + Default> Im2Col<'_, T> {

/// Pack part of an image into a packing buffer.
///
/// This method is for use by kernels using the "standard" packing buffer
/// layout for the B / RHS input.
///
/// `NR_REGS` specifies the width of each column panel as a multiple of
/// `S::LEN` elements. In other words, `panel_width` must exactly equal
/// `NR_REGS * S::LEN`.
Expand Down Expand Up @@ -188,3 +202,131 @@ impl<T: Copy + Default> Im2Col<'_, T> {
assert_eq!(out_offset, used_size);
}
}

impl Im2Col<'_, i8> {
/// Pack part of an image into a packing buffer.
///
/// This method is for use by kernels using int8 dot product instructions
/// to compute `S::LEN x i32` dot products from two input vectors each
/// containing `S::LEN x 4 x i8` (or u8) inputs.
#[inline(always)]
#[allow(unused)] // Some architectures only
pub(super) unsafe fn pack_block_i8_dot<S: SimdInt>(
&self,
out: &mut [MaybeUninit<i8>],
rows: Range<usize>,
cols: Range<usize>,
) {
self.pack_block_int8::<S, false>(out, rows, cols);
}

/// Variant of [`pack_block_i8_dot`](Self::pack_block_i8_dot) which shifts
/// i8 values to u8 by adding 128.
#[inline(always)]
#[allow(unused)] // Some architectures only
pub(super) unsafe fn pack_block_i8_dot_cast_u8<S: SimdInt>(
&self,
out: &mut [MaybeUninit<u8>],
rows: Range<usize>,
cols: Range<usize>,
) {
let out = cast_pod_mut_slice(out).unwrap();
self.pack_block_int8::<S, true>(out, rows, cols);
}

#[inline(always)]
unsafe fn pack_block_int8<S: SimdInt, const CAST_B_U8: bool>(
&self,
out: &mut [MaybeUninit<i8>],
rows: Range<usize>,
cols: Range<usize>,
) {
const K_TILE: usize = size_of::<i32>() / size_of::<i8>();

debug_assert!(rows.end <= self.rows());
debug_assert!(cols.end <= self.cols());

let max_x_offset = S::splat(self.max_x_offset);
let max_y_offset = S::splat(self.max_y_offset);

let col_x_offsets = &self.col_offsets.x;
debug_assert_eq!(col_x_offsets.len() % S::LEN, 0);

let col_y_offsets = &self.col_offsets.y;
debug_assert_eq!(col_y_offsets.len() % S::LEN, 0);

let row_x_offsets = &self.row_offsets.x;
debug_assert_eq!(row_x_offsets.len() % K_TILE, 0);

let row_y_offsets = &self.row_offsets.y;
debug_assert_eq!(row_y_offsets.len() % K_TILE, 0);

let row_chan_offsets = &self.row_offsets.chan;
debug_assert_eq!(row_chan_offsets.len() % K_TILE, 0);

let img_ptr = self.image.storage().as_ptr();
let out_ptr = out.as_mut_ptr();

let mut out_offset = 0;

for start_col in cols.step_by(S::LEN) {
let col_y_offset = S::load(col_y_offsets.get_unchecked(start_col));
let col_x_offset = S::load(col_x_offsets.get_unchecked(start_col));
let zero = S::zero();

let mut col_sums = S::zero().to_array();

for start_row in rows.clone().step_by(4) {
for i in 0..K_TILE {
let k = start_row + i;
let row_x_offset = S::splat(*row_x_offsets.get_unchecked(k));
let row_y_offset = S::splat(*row_y_offsets.get_unchecked(k));
let row_chan_offset = S::splat(*row_chan_offsets.get_unchecked(k));

let x_offsets = row_x_offset.add(col_x_offset);
let y_offsets = row_y_offset.add(col_y_offset);
let offsets = x_offsets.add(y_offsets).add(row_chan_offset);

let pad_mask = y_offsets
.ge(zero)
.and(y_offsets.le(max_y_offset))
.and(x_offsets.ge(zero))
.and(x_offsets.le(max_x_offset));
let pad_mask_array = pad_mask.to_array();

// Set offsets to zero for padding elements. We require
// this offset is always valid.
let offsets_array = zero.blend(offsets, pad_mask).to_array();

for idx in 0..S::LEN {
let out_ptr = out_ptr.add(out_offset + idx * K_TILE + i);
let src_elem = *img_ptr.add(offsets_array[idx] as usize);

// This should be compiled to a conditional move.
let elem = if pad_mask_array[idx] { src_elem } else { 0 };

if CAST_B_U8 {
let elem = shift_cast_i8_u8(elem);
col_sums[idx] += elem as i32;
out_ptr.write(MaybeUninit::new(elem as i8));
} else {
col_sums[idx] += elem as i32;
out_ptr.write(MaybeUninit::new(elem));
}
}
}
out_offset += S::LEN * K_TILE;
}

// Store column sums at end of each panel.
let col_sum_ptr = out_ptr.add(out_offset) as *mut i32;
for i in 0..S::LEN {
*col_sum_ptr.add(i) = col_sums[i];
}
out_offset += S::LEN * K_TILE;
}

// Sanity check
assert_eq!(out_offset, out.len());
}
}
22 changes: 22 additions & 0 deletions src/gemm/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,32 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
/// Return a name for this kernel for use in logging etc.
fn name(&self) -> &'static str;

/// Return true if this kernel may encounter saturation in a data type that
/// is smaller than the accumulator.
///
/// The caller will have to prepare inputs (usually the weights) to avoid
/// this. This is primarily an issue for x64 systems without VNNI.
/// See https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html.
fn may_saturate(&self) -> bool {
false
}

/// Step size used when packing an image usage [`pack_im2col`](Kernel::pack_im2col).
///
/// The length of the offset arrays in [`Im2Col::row_offsets`] must be a
/// multiple of this.
fn im2col_row_count_step(&self) -> usize {
1
}

/// Step size used when packing an image usage [`pack_im2col`](Kernel::pack_im2col).
///
/// The length of the offset arrays in [`Im2Col::col_offsets`] must be a
/// multiple of this.
fn im2col_col_count_step(&self) -> usize {
self.nr()
}

/// Return the layout of a packing buffer required to pack an A / LHS input.
fn packed_a_layout(
&self,
Expand Down
Loading