Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand ReducedRangeRng and make it reusable outside GEMM tests #569

Merged
merged 1 commit into from
Feb 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 13 additions & 45 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,12 @@ fn gemm_block<LhsT, RhsT, OutT: GemmOutT>(
});
}

#[cfg(test)]
mod reduced_range_rng;

#[cfg(test)]
pub use reduced_range_rng::ReducedRangeRng;

#[cfg(test)]
mod tests {
use std::error::Error;
Expand All @@ -1397,7 +1403,7 @@ mod tests {

use super::{
BiasVector, ColOffsets, F32KernelType, GemmError, GemmExecutor, GemmInT, GemmInputA,
GemmInputB, GemmOutT, Im2Col, QuantParams, RowOffsets, WithKernel,
GemmInputB, GemmOutT, Im2Col, QuantParams, ReducedRangeRng, RowOffsets, WithKernel,
};

/// Scale a possibly non-float value by a float.
Expand Down Expand Up @@ -1628,44 +1634,6 @@ mod tests {
.filter_map(|kern_type| GemmExecutor::<L, R, O>::with_kernel(kern_type))
}

// Random number generator which produces values with an optionally reduced
// range.
//
// This works around an issue under AVX2 where the `vpmaddubsw` instruction
// can encounter saturation when adding two signed 16-bit values into a
// 16-bit result. Each of the two 16-bit inputs are the result of a `u8 x
// i8` multiplication. By limiting the range of either the u8 or i8 input,
// we can avoid saturation. This issue does not affect the VNNI instruction
// used on newer x64 systems.
//
// To match the workaround in ONNX Runtime's quantizer when
// `reduce_range=True` is enabled, the range of the RHS (ie. the weights)
// is limited.
//
// To avoid saturation we require `a_max * b_max * 2 <= i16::MAX`. This
// re-arranges to `b_max <= (i16::MAX / 2) / 255 <= 64`.
struct ReducedRangeRng {
// When true, `next_i8` reduces each sample modulo 65, so outputs fall in
// `[0, 64]`. When false, the raw sample is cast to `i8` unchanged.
reduce_range: bool,
// Underlying generator; `new` always seeds it with 1234.
rng: XorShiftRng,
}

impl ReducedRangeRng {
    // Create a generator backed by an `XorShiftRng` with the fixed seed 1234.
    //
    // `reduce_range` selects whether `next_i8` limits its output range.
    fn new(reduce_range: bool) -> Self {
        let rng = XorShiftRng::new(1234);
        Self { reduce_range, rng }
    }

    // Draw the next `i8`.
    //
    // With range reduction enabled the sample is taken modulo 65, giving a
    // value in `[0, 64]`. Otherwise the low 8 bits of the sample are
    // reinterpreted as an `i8`.
    fn next_i8(&mut self) -> i8 {
        // One sample is consumed per call regardless of mode.
        let sample = self.rng.next_u64();
        if !self.reduce_range {
            return sample as i8;
        }
        (sample % 65) as i8
    }
}

// Simplest possible test case for easy debugging.
#[test]
fn test_simple_gemm_f32() -> Result<(), Box<dyn Error>> {
Expand Down Expand Up @@ -1830,8 +1798,8 @@ mod tests {
#[test]
fn test_gemm_u8i8_i32() -> Result<(), Box<dyn Error>> {
for gemm in all_gemms::<u8, i8, i32>() {
let mut rng = ReducedRangeRng::new(gemm.may_saturate());
test_gemm_various_input_sizes(Some(&gemm), None, Some(&mut || rng.next_i8()))?;
let mut rng = ReducedRangeRng::new(gemm.may_saturate(), 1234);
test_gemm_various_input_sizes(Some(&gemm), None, Some(&mut || rng.next()))?;
}
Ok(())
}
Expand Down Expand Up @@ -1866,11 +1834,11 @@ mod tests {

for gemm in all_gemms::<u8, i8, i32>() {
let mut lhs_rng = XorShiftRng::new(1234);
let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate());
let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate(), 5678);

for Case { m, n, k } in cases {
let a = NdTensor::<u8, 2>::rand([m, k], &mut lhs_rng);
let b = NdTensor::<i8, 2>::from_simple_fn([k, n], || rhs_rng.next_i8());
let b = NdTensor::<i8, 2>::rand([k, n], &mut rhs_rng);

let a_zero_point: Vec<_> = (0..a.rows()).map(|x| x as u8).collect();
let b_zero_point: Vec<_> = (0..b.cols()).map(|x| x as i8).collect();
Expand Down Expand Up @@ -1962,11 +1930,11 @@ mod tests {

for gemm in all_gemms::<u8, i8, i32>() {
let mut lhs_rng = XorShiftRng::new(1234);
let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate());
let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate(), 5678);

for &Case { k, n } in &cases {
let a = NdTensor::<u8, 2>::rand([1, k], &mut lhs_rng);
let mut b = NdTensor::<i8, 2>::from_simple_fn([n, k], || rhs_rng.next_i8());
let mut b = NdTensor::<i8, 2>::rand([n, k], &mut rhs_rng);

// Transpose the input B matrix. This will alter the row and column
// strides and shapes, but not re-order the data.
Expand Down
76 changes: 76 additions & 0 deletions src/gemm/reduced_range_rng.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use rten_tensor::rng::XorShiftRng;
use rten_tensor::RandomSource;

/// Random number generator which produces values with an optionally reduced
/// range.
///
/// This works around an issue under AVX2 where the `vpmaddubsw` instruction
/// can encounter saturation when adding two signed 16-bit values into a
/// 16-bit result. Each of the two 16-bit inputs is the result of a `u8 x
/// i8` multiplication. By limiting the range of either the u8 or i8 input,
/// saturation is avoided. This issue does not affect the VNNI instruction
/// used on newer x64 systems. It also does not affect Arm.
///
/// To match the behavior in ONNX Runtime's quantizer when
/// `reduce_range=True` is enabled, the range of whichever input holds the
/// weights (usually the RHS) should be limited.
///
/// To avoid saturation we require
/// `i16::MIN <= u8_val * i8_val * 2 <= i16::MAX`. A suitable choice is to
/// use i7/u7 values with ranges `[-64, 63]` and `[0, 127]`.
///
/// See also <https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html>.
pub struct ReducedRangeRng {
/// When true, generated `i8` values are limited to `[-64, 63]` and `u8`
/// values to `[0, 127]`. When false, the full range of the type is used.
reduce_range: bool,
/// Underlying generator that supplies the raw `u64` samples.
rng: XorShiftRng,
}

impl ReducedRangeRng {
    /// Create a generator whose underlying `XorShiftRng` is seeded with `seed`.
    ///
    /// If `reduce_range` is true, values produced via `RandomSource` are
    /// limited to the 7-bit range of the requested type.
    pub fn new(reduce_range: bool, seed: u64) -> Self {
        let rng = XorShiftRng::new(seed);
        Self { reduce_range, rng }
    }
}

impl RandomSource<i8> for ReducedRangeRng {
    /// Return a random `i8`.
    ///
    /// When range reduction is enabled the value lies in `[-64, 63]` (the i7
    /// range). Otherwise the low 8 bits of the sample are reinterpreted as an
    /// `i8`, covering the full range of the type.
    fn next(&mut self) -> i8 {
        // One sample is consumed per call regardless of mode.
        let sample = self.rng.next_u64();
        if !self.reduce_range {
            return sample as i8;
        }

        // Map `[0, 127]` onto `[-64, 63]`. The subtraction is done in `i16`
        // and the result always fits in an `i8`.
        let unsigned7 = (sample % 128) as i16;
        (unsigned7 - 64i16) as i8
    }
}

impl RandomSource<u8> for ReducedRangeRng {
    /// Return a random `u8`.
    ///
    /// When range reduction is enabled the value lies in `[0, 127]` (the u7
    /// range). Otherwise the low 8 bits of the sample are returned, covering
    /// the full range of the type.
    fn next(&mut self) -> u8 {
        // One sample is consumed per call regardless of mode.
        let sample = self.rng.next_u64();
        let limited = if self.reduce_range {
            sample % 128
        } else {
            sample
        };
        limited as u8
    }
}

#[cfg(test)]
mod tests {
    use rten_tensor::RandomSource;

    use super::ReducedRangeRng;

    /// With range reduction enabled, `i8` samples must stay within the i7
    /// range and `u8` samples within the u7 range.
    #[test]
    fn test_reduced_range_rng() {
        let mut rng = ReducedRangeRng::new(true, 1234);

        // Alternate between the two element types so both impls are
        // exercised against the same generator state.
        for _ in 0..100 {
            let signed: i8 = rng.next();
            assert!((-64..=63).contains(&signed));

            let unsigned: u8 = rng.next();
            assert!((0..=127).contains(&unsigned));
        }
    }
}