Skip to content

Commit

Permalink
Merge pull request #569 from robertknight/reduced-range-rng
Browse files Browse the repository at this point in the history
Expand `ReducedRangeRng` and make it reusable outside GEMM tests
  • Loading branch information
robertknight authored Feb 2, 2025
2 parents 9d4a9cf + 89c0694 commit 7147f81
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 45 deletions.
58 changes: 13 additions & 45 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,12 @@ fn gemm_block<LhsT, RhsT, OutT: GemmOutT>(
});
}

#[cfg(test)]
mod reduced_range_rng;

#[cfg(test)]
pub use reduced_range_rng::ReducedRangeRng;

#[cfg(test)]
mod tests {
use std::error::Error;
Expand All @@ -1397,7 +1403,7 @@ mod tests {

use super::{
BiasVector, ColOffsets, F32KernelType, GemmError, GemmExecutor, GemmInT, GemmInputA,
GemmInputB, GemmOutT, Im2Col, QuantParams, RowOffsets, WithKernel,
GemmInputB, GemmOutT, Im2Col, QuantParams, ReducedRangeRng, RowOffsets, WithKernel,
};

/// Scale a possibly non-float value by a float.
Expand Down Expand Up @@ -1628,44 +1634,6 @@ mod tests {
.filter_map(|kern_type| GemmExecutor::<L, R, O>::with_kernel(kern_type))
}

// Random number generator which produces values with an optionally reduced
// range.
//
// This works around an issue under AVX2 where the `vpmaddubsw` instruction
// can encounter saturation when adding two signed 16-bit values into a
// 16-bit result. Each of the two 16-bit inputs is the result of a `u8 x
// i8` multiplication. By limiting the range of either the u8 or i8 input,
// we can avoid saturation. This issue does not affect the VNNI instruction
// used on newer x64 systems.
//
// To match the workaround in ONNX Runtime's quantizer when
// `reduce_range=True` is enabled, the range of the RHS (ie. the weights)
// is limited.
//
// To avoid saturation we require `a_max * b_max * 2 <= i16::MAX`. With
// `a_max = 255` (the full u8 range) this re-arranges to
// `b_max <= (i16::MAX / 2) / 255`, ie. `b_max <= 64`.
struct ReducedRangeRng {
// When true, `next_i8` only produces values in `[0, 64]`.
reduce_range: bool,
// Underlying generator that supplies the raw random bits.
rng: XorShiftRng,
}

impl ReducedRangeRng {
    /// Create a generator with a fixed seed.
    ///
    /// `reduce_range` selects whether `next_i8` restricts its output to
    /// `[0, 64]` or produces arbitrary `i8` values.
    fn new(reduce_range: bool) -> Self {
        let rng = XorShiftRng::new(1234);
        Self { reduce_range, rng }
    }

    /// Return the next random `i8`.
    ///
    /// Exactly one value is drawn from the underlying generator per call.
    fn next_i8(&mut self) -> i8 {
        let raw = self.rng.next_u64();
        if !self.reduce_range {
            // Full range: keep the low 8 bits, reinterpreted as signed.
            return raw as i8;
        }
        // Reduced range: `raw % 65` is in `[0, 64]`, which always fits in i8.
        (raw % 65) as i8
    }
}

// Simplest possible test case for easy debugging.
#[test]
fn test_simple_gemm_f32() -> Result<(), Box<dyn Error>> {
Expand Down Expand Up @@ -1830,8 +1798,8 @@ mod tests {
#[test]
fn test_gemm_u8i8_i32() -> Result<(), Box<dyn Error>> {
for gemm in all_gemms::<u8, i8, i32>() {
let mut rng = ReducedRangeRng::new(gemm.may_saturate());
test_gemm_various_input_sizes(Some(&gemm), None, Some(&mut || rng.next_i8()))?;
let mut rng = ReducedRangeRng::new(gemm.may_saturate(), 1234);
test_gemm_various_input_sizes(Some(&gemm), None, Some(&mut || rng.next()))?;
}
Ok(())
}
Expand Down Expand Up @@ -1866,11 +1834,11 @@ mod tests {

for gemm in all_gemms::<u8, i8, i32>() {
let mut lhs_rng = XorShiftRng::new(1234);
let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate());
let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate(), 5678);

for Case { m, n, k } in cases {
let a = NdTensor::<u8, 2>::rand([m, k], &mut lhs_rng);
let b = NdTensor::<i8, 2>::from_simple_fn([k, n], || rhs_rng.next_i8());
let b = NdTensor::<i8, 2>::rand([k, n], &mut rhs_rng);

let a_zero_point: Vec<_> = (0..a.rows()).map(|x| x as u8).collect();
let b_zero_point: Vec<_> = (0..b.cols()).map(|x| x as i8).collect();
Expand Down Expand Up @@ -1962,11 +1930,11 @@ mod tests {

for gemm in all_gemms::<u8, i8, i32>() {
let mut lhs_rng = XorShiftRng::new(1234);
let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate());
let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate(), 5678);

for &Case { k, n } in &cases {
let a = NdTensor::<u8, 2>::rand([1, k], &mut lhs_rng);
let mut b = NdTensor::<i8, 2>::from_simple_fn([n, k], || rhs_rng.next_i8());
let mut b = NdTensor::<i8, 2>::rand([n, k], &mut rhs_rng);

// Transpose the input B matrix. This will alter the row and column
// strides and shapes, but not re-order the data.
Expand Down
76 changes: 76 additions & 0 deletions src/gemm/reduced_range_rng.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use rten_tensor::rng::XorShiftRng;
use rten_tensor::RandomSource;

/// Random number generator which produces values with an optionally reduced
/// range.
///
/// This works around an issue under AVX2 where the `vpmaddubsw` instruction
/// can encounter saturation when adding two signed 16-bit values into a
/// 16-bit result. Each of the two 16-bit inputs is the result of a `u8 x
/// i8` multiplication. By limiting the range of either the u8 or i8 input,
/// saturation is avoided. This issue does not affect the VNNI instruction
/// used on newer x64 systems. It also does not affect Arm.
///
/// To match the behavior in ONNX Runtime's quantizer when
/// `reduce_range=True` is enabled, the range of whichever input holds the
/// weights (usually the RHS) should be limited.
///
/// To avoid saturation we require `i16::MIN <= u8_val * i8_val * 2 <=
/// i16::MAX`. A suitable choice is to restrict one of the inputs to i7/u7
/// values, with ranges `[-64, 63]` and `[0, 127]` respectively.
///
/// See also <https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html>.
pub struct ReducedRangeRng {
/// When true, generated values are limited to the i7/u7 ranges described
/// above. When false, the full `i8`/`u8` range is produced.
reduce_range: bool,
/// Underlying generator that supplies the raw random bits.
rng: XorShiftRng,
}

impl ReducedRangeRng {
    /// Create a generator whose underlying `XorShiftRng` is seeded with `seed`.
    ///
    /// When `reduce_range` is true, the `RandomSource` impls for this type
    /// yield 7-bit values (`[-64, 63]` for `i8`, `[0, 127]` for `u8`).
    /// Otherwise they yield values spanning the full 8-bit range.
    pub fn new(reduce_range: bool, seed: u64) -> Self {
        let rng = XorShiftRng::new(seed);
        Self { reduce_range, rng }
    }
}

impl RandomSource<i8> for ReducedRangeRng {
    /// Return the next random `i8`.
    ///
    /// With range reduction enabled the result lies in `[-64, 63]` (the i7
    /// range). Otherwise any `i8` value may be returned. Exactly one value is
    /// drawn from the underlying generator per call.
    fn next(&mut self) -> i8 {
        let raw = self.rng.next_u64();
        if !self.reduce_range {
            // Full range: keep the low 8 bits, reinterpreted as signed.
            return raw as i8;
        }
        // `raw % 128` is in `[0, 127]` and fits in i8; subtracting 64 shifts
        // it to `[-64, 63]` without overflow.
        (raw % 128) as i8 - 64
    }
}

impl RandomSource<u8> for ReducedRangeRng {
    /// Return the next random `u8`.
    ///
    /// With range reduction enabled the result lies in `[0, 127]` (the u7
    /// range). Otherwise any `u8` value may be returned. Exactly one value is
    /// drawn from the underlying generator per call.
    fn next(&mut self) -> u8 {
        let raw = self.rng.next_u64();
        if !self.reduce_range {
            // Full range: keep the low 8 bits.
            return raw as u8;
        }
        // Reduced range: keep only values below 128.
        (raw % 128) as u8
    }
}

#[cfg(test)]
mod tests {
    use rten_tensor::RandomSource;

    use super::ReducedRangeRng;

    /// With range reduction enabled, generated values must stay within the
    /// i7 range for `i8` and the u7 range for `u8`.
    #[test]
    fn test_reduced_range_rng() {
        let mut rng = ReducedRangeRng::new(true, 1234);
        for _ in 0..100 {
            let signed: i8 = rng.next();
            assert!((-64..=63).contains(&signed));

            let unsigned: u8 = rng.next();
            assert!(unsigned <= 127);
        }
    }
}

0 comments on commit 7147f81

Please sign in to comment.