diff --git a/src/gemm.rs b/src/gemm.rs
index 89ffedbc..2f8c7912 100644
--- a/src/gemm.rs
+++ b/src/gemm.rs
@@ -1384,6 +1384,12 @@ fn gemm_block(
     });
 }
 
+#[cfg(test)]
+mod reduced_range_rng;
+
+#[cfg(test)]
+pub use reduced_range_rng::ReducedRangeRng;
+
 #[cfg(test)]
 mod tests {
     use std::error::Error;
@@ -1397,7 +1403,7 @@ mod tests {
 
     use super::{
         BiasVector, ColOffsets, F32KernelType, GemmError, GemmExecutor, GemmInT, GemmInputA,
-        GemmInputB, GemmOutT, Im2Col, QuantParams, RowOffsets, WithKernel,
+        GemmInputB, GemmOutT, Im2Col, QuantParams, ReducedRangeRng, RowOffsets, WithKernel,
     };
 
     /// Scale a possibly non-float value by a float.
@@ -1628,44 +1634,6 @@ mod tests {
             .filter_map(|kern_type| GemmExecutor::<LhsT, RhsT, OutT>::with_kernel(kern_type))
     }
 
-    // Random number generator which produces values with an optionally reduced
-    // range.
-    //
-    // This works around an issue under AVX2 where the `vpmaddubsw` instruction
-    // can encounter saturation when adding two signed 16-bit values into a
-    // 16-bit result. Each of the two 16-bit inputs are the result of a `u8 x
-    // i8` multiplication. By limiting the range of either the u8 or i8 input,
-    // we can avoid saturation. This issue does not affect the VNNI instruction
-    // used on newer x64 systems.
-    //
-    // To match the workaround in ONNX Runtime's quantizer when
-    // `reduce_range=True` is enabled, the range of the RHS (ie. the weights)
-    // is limited.
-    //
-    // To avoid saturation we require `a_max * b_max * 2 <= i16::MAX`. This
-    // re-arranges to `b_max <= (i16::MAX / 2) / 255 <= 64`.
-    struct ReducedRangeRng {
-        reduce_range: bool,
-        rng: XorShiftRng,
-    }
-
-    impl ReducedRangeRng {
-        fn new(reduce_range: bool) -> Self {
-            Self {
-                rng: XorShiftRng::new(1234),
-                reduce_range,
-            }
-        }
-
-        fn next_i8(&mut self) -> i8 {
-            if self.reduce_range {
-                (self.rng.next_u64() % 65) as i8
-            } else {
-                self.rng.next_u64() as i8
-            }
-        }
-    }
-
     // Simplest possible test case for easy debugging.
     #[test]
     fn test_simple_gemm_f32() -> Result<(), Box<dyn Error>> {
@@ -1830,8 +1798,8 @@ mod tests {
     #[test]
     fn test_gemm_u8i8_i32() -> Result<(), Box<dyn Error>> {
         for gemm in all_gemms::<u8, i8, i32>() {
-            let mut rng = ReducedRangeRng::new(gemm.may_saturate());
-            test_gemm_various_input_sizes(Some(&gemm), None, Some(&mut || rng.next_i8()))?;
+            let mut rng = ReducedRangeRng::new(gemm.may_saturate(), 1234);
+            test_gemm_various_input_sizes(Some(&gemm), None, Some(&mut || rng.next()))?;
         }
         Ok(())
     }
@@ -1866,11 +1834,11 @@ mod tests {
 
         for gemm in all_gemms::<u8, i8, i32>() {
             let mut lhs_rng = XorShiftRng::new(1234);
-            let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate());
+            let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate(), 5678);
 
             for Case { m, n, k } in cases {
                 let a = NdTensor::<u8, 2>::rand([m, k], &mut lhs_rng);
-                let b = NdTensor::<i8, 2>::from_simple_fn([k, n], || rhs_rng.next_i8());
+                let b = NdTensor::<i8, 2>::rand([k, n], &mut rhs_rng);
 
                 let a_zero_point: Vec<_> = (0..a.rows()).map(|x| x as u8).collect();
                 let b_zero_point: Vec<_> = (0..b.cols()).map(|x| x as i8).collect();
@@ -1962,11 +1930,11 @@ mod tests {
 
         for gemm in all_gemms::<u8, i8, i32>() {
             let mut lhs_rng = XorShiftRng::new(1234);
-            let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate());
+            let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate(), 5678);
 
             for &Case { k, n } in &cases {
                 let a = NdTensor::<u8, 2>::rand([1, k], &mut lhs_rng);
-                let mut b = NdTensor::<i8, 2>::from_simple_fn([n, k], || rhs_rng.next_i8());
+                let mut b = NdTensor::<i8, 2>::rand([n, k], &mut rhs_rng);
 
                 // Transpose the input B matrix. This will alter the row and column
                 // strides and shapes, but not re-order the data.
diff --git a/src/gemm/reduced_range_rng.rs b/src/gemm/reduced_range_rng.rs
new file mode 100644
index 00000000..bda0c20f
--- /dev/null
+++ b/src/gemm/reduced_range_rng.rs
@@ -0,0 +1,76 @@
+use rten_tensor::rng::XorShiftRng;
+use rten_tensor::RandomSource;
+
+/// Random number generator which produces values with an optionally reduced
+/// range.
+///
+/// This works around an issue under AVX2 where the `vpmaddubsw` instruction
+/// can encounter saturation when adding two signed 16-bit values into a
+/// 16-bit result. Each of the two 16-bit inputs is the result of a `u8 x
+/// i8` multiplication. By limiting the range of either the u8 or i8 input,
+/// saturation is avoided. This issue does not affect the VNNI instruction
+/// used on newer x64 systems. It also does not affect Arm.
+///
+/// To match the behavior of ONNX Runtime's quantizer when
+/// `reduce_range=True` is enabled, the range of whichever input holds the
+/// weights (usually the RHS) should be limited.
+///
+/// To avoid saturation we require `i16::MIN <= u8_val * i8_val * 2 <=
+/// i16::MAX`. A suitable choice is to use i7/u7 values, with ranges
+/// [-64, 63] and [0, 127] respectively.
+///
+/// See also <https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html>.
+pub struct ReducedRangeRng {
+    reduce_range: bool,
+    rng: XorShiftRng,
+}
+
+impl ReducedRangeRng {
+    pub fn new(reduce_range: bool, seed: u64) -> Self {
+        Self {
+            rng: XorShiftRng::new(seed),
+            reduce_range,
+        }
+    }
+}
+
+impl RandomSource<i8> for ReducedRangeRng {
+    /// Return a random value in `[-64, 63]` (the i7 range) when reducing the range.
+    fn next(&mut self) -> i8 {
+        if self.reduce_range {
+            ((self.rng.next_u64() % 128) as i16 - 64i16) as i8
+        } else {
+            self.rng.next_u64() as i8
+        }
+    }
+}
+
+impl RandomSource<u8> for ReducedRangeRng {
+    /// Return a random value in `[0, 127]` (the u7 range) when reducing the range.
+    fn next(&mut self) -> u8 {
+        if self.reduce_range {
+            (self.rng.next_u64() % 128) as u8
+        } else {
+            self.rng.next_u64() as u8
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use rten_tensor::RandomSource;
+
+    use super::ReducedRangeRng;
+
+    #[test]
+    fn test_reduced_range_rng() {
+        let mut rng = ReducedRangeRng::new(true, 1234);
+        for _ in 0..100 {
+            let x: i8 = rng.next();
+            assert!(x >= -64 && x <= 63);
+
+            let x: u8 = rng.next();
+            assert!(x <= 127);
+        }
+    }
+}
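
For reviewers, a minimal standalone sketch (not part of the diff) that checks the arithmetic behind the i7/u7 choice: `vpmaddubsw` sums two adjacent `u8 x i8` products into a signed 16-bit lane, so the worst case is `2 * u8_max * |i8_min|`.

```rust
// Sanity-check the saturation bound that motivates ReducedRangeRng.
fn main() {
    // Full-range inputs: 2 * 255 * 128 = 65280 exceeds i16::MAX (32767),
    // so `vpmaddubsw` can saturate under AVX2.
    let full_range_worst = 2 * 255 * 128;
    assert!(full_range_worst > i16::MAX as i32);

    // Reduced-range inputs (u7 values in [0, 127], i7 weights in [-64, 63]):
    // the most extreme pair sums fit within the i16 range.
    let reduced_max = 2 * 127 * 63; // most positive pair sum: 16002
    let reduced_min = 2 * 127 * -64; // most negative pair sum: -16256
    assert!(reduced_max <= i16::MAX as i32);
    assert!(reduced_min >= i16::MIN as i32);
}
```

This matches the doc comment's requirement `i16::MIN <= u8_val * i8_val * 2 <= i16::MAX`, with the range reduction applied to the weights to follow ONNX Runtime's `reduce_range=True` convention.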