Merge pull request #469 from robertknight/vec-instance-norm
Vectorize InstanceNormalization and BatchNormalization
robertknight authored Dec 19, 2024
2 parents 82d8d7f + 2c21c45 commit 1fd414f
Showing 5 changed files with 226 additions and 38 deletions.
6 changes: 5 additions & 1 deletion rten-simd/src/vec.rs
@@ -64,7 +64,11 @@ pub trait Simd: Copy + Sized {
type Mask: SimdMask;

/// The contents of a vector as an array.
- type Array: std::ops::Index<usize, Output = Self::Elem>;
+ ///
+ /// This type should always be `[Self::Elem; Self::LEN]`. The `to_array`
+ /// method returns this associated type rather than a concrete array due to
+ /// const generics limitations.
+ type Array: Copy + std::fmt::Debug + std::ops::Index<usize, Output = Self::Elem>;

/// Combine elements of `self` and `rhs` according to a mask.
///
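For context, a short sketch of what the widened bounds buy generic code. The helper below is hypothetical (it is not part of this commit) and uses a plain `[f32; 4]` to stand in for `Simd::Array`:

```rust
use std::fmt::Debug;
use std::ops::Index;

// With Copy + Debug on the array form, generic code can duplicate and
// debug-print a vector's lanes without adding its own bounds.
fn debug_lanes<A: Copy + Debug + Index<usize, Output = f32>>(lanes: A) {
    let copy = lanes; // allowed because the array type is Copy
    println!("lanes = {:?}, first = {}", copy, copy[0]);
}

fn main() {
    debug_lanes([1.0_f32, 2.0, 3.0, 4.0]); // a plain array models Simd::Array
}
```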
4 changes: 2 additions & 2 deletions rten-vecmath/src/lib.rs
@@ -35,7 +35,7 @@ pub use exp::{
exp, sigmoid, silu, vec_exp, vec_exp_in_place, vec_sigmoid, vec_sigmoid_in_place, vec_silu,
vec_silu_in_place,
};
- pub use shift_scale::vec_shift_scale_in_place;
+ pub use shift_scale::{vec_shift_scale_bias, vec_shift_scale_in_place};
pub use softmax::{vec_softmax, vec_softmax_in_place};
- pub use sum::{vec_sum, vec_sum_square};
+ pub use sum::{vec_sum, vec_sum_square, vec_sum_square_sub};
pub use tanh::{tanh, vec_tanh, vec_tanh_in_place};
84 changes: 83 additions & 1 deletion rten-vecmath/src/shift_scale.rs
@@ -80,9 +80,69 @@ pub fn vec_shift_scale_in_place(
dispatch(simd_op);
}

struct SimdShiftScaleBias<'a> {
data: &'a mut [f32],
x_bias: f32,
scale: f32,
bias: f32,
}

impl<'a> SimdOp for SimdShiftScaleBias<'a> {
type Output = &'a mut [f32];

#[inline(always)]
unsafe fn eval<S: SimdFloat>(self) -> Self::Output {
let Self {
data,
x_bias,
scale,
bias,
} = self;

let mut out_ptr = data.as_mut_ptr();
let mut n = data.len();

let x_bias_vec = S::splat(x_bias);
let scale_vec = S::splat(scale);
let bias_vec = S::splat(bias);

while n >= S::LEN {
let y = S::load(out_ptr)
.sub(x_bias_vec)
.mul_add(scale_vec, bias_vec);
y.store(out_ptr);

out_ptr = out_ptr.add(S::LEN);
n -= S::LEN;
}

if n > 0 {
let y = S::load_partial(out_ptr, n, 0.)
.sub(x_bias_vec)
.mul_add(scale_vec, bias_vec);
y.store_partial(out_ptr, n);
}

data
}
}

/// Shift and scale each element in the input.
///
/// This updates `xs` as `xs[i] = (xs[i] - x_bias) * scale + bias`.
pub fn vec_shift_scale_bias(xs: &mut [f32], x_bias: f32, scale: f32, bias: f32) {
let op = SimdShiftScaleBias {
data: xs,
x_bias,
scale,
bias,
};
dispatch(op);
}

#[cfg(test)]
mod tests {
- use super::vec_shift_scale_in_place;
+ use super::{vec_shift_scale_bias, vec_shift_scale_in_place};

fn reference_shift_scale(
data: &mut [f32],
@@ -95,6 +155,12 @@ mod tests {
}
}

fn reference_shift_scale_bias(data: &mut [f32], x_bias: f32, scale: f32, bias: f32) {
for i in 0..data.len() {
data[i] = (data[i] - x_bias).mul_add(scale, bias);
}
}

#[test]
fn test_vec_shift_scale() {
let data: Vec<_> = (0..10).map(|i| i as f32 * 0.1).collect();
@@ -120,4 +186,20 @@

assert_eq!(actual, expected);
}

#[test]
fn test_vec_shift_scale_bias() {
let data: Vec<_> = (0..10).map(|i| i as f32 * 0.1).collect();
let x_bias = 0.123;
let scale = 0.456;
let bias = 0.89;

let mut expected = data.clone();
reference_shift_scale_bias(&mut expected, x_bias, scale, bias);

let mut actual = data.clone();
vec_shift_scale_bias(&mut actual, x_bias, scale, bias);

assert_eq!(actual, expected);
}
}
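For illustration, a minimal caller of the new `vec_shift_scale_bias` (an editor's sketch, assuming `rten-vecmath` as a dependency; the values are arbitrary):

```rust
fn main() {
    let mut xs = [1.0_f32, 2.0, 3.0];
    // Each element becomes (x - 0.5) * 2.0 + 1.0.
    rten_vecmath::vec_shift_scale_bias(&mut xs, 0.5, 2.0, 1.0);
    assert_eq!(xs, [2.0, 4.0, 6.0]);
}
```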
57 changes: 54 additions & 3 deletions rten-vecmath/src/sum.rs
@@ -64,23 +64,74 @@ pub fn vec_sum_square(xs: &[f32]) -> f32 {
dispatch(op)
}

struct SimdSumSquareSub<'a> {
input: &'a [f32],
offset: f32,
}

impl SimdOp for SimdSumSquareSub<'_> {
type Output = f32;

#[inline(always)]
unsafe fn eval<S: SimdFloat>(self) -> Self::Output {
let offset_vec = S::splat(self.offset);
let vec_sum = simd_fold(
self.input.into(),
S::zero(),
#[inline(always)]
|sum, x| {
let x_offset = x.sub(offset_vec);
x_offset.mul_add(x_offset, sum)
},
// Padding value chosen so that `x - offset` is zero for unused
// positions in the final update, and thus the accumulator is not
// modified in those positions.
self.offset,
);
vec_sum.sum()
}
}
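A scalar model of the fold above (editor's sketch) makes the padding choice concrete: a lane padded with `offset` contributes `(offset - offset)^2 = 0`, so the accumulator is untouched:

```rust
fn sum_square_sub_scalar(xs: &[f32], offset: f32) -> f32 {
    // Same shape as the SIMD fold: accumulate (x - offset)^2 via FMA.
    xs.iter().fold(0.0, |sum, &x| (x - offset).mul_add(x - offset, sum))
}
```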

/// Compute the sum of squares of `xs - offset`.
///
/// This is a variant of [`vec_sum_square`] which subtracts a constant value
/// from each element before squaring it. A typical use case is to compute the
/// variance of a sequence, which is defined as `mean((X - x_mean)^2)`.
pub fn vec_sum_square_sub(xs: &[f32], offset: f32) -> f32 {
let op = SimdSumSquareSub { input: xs, offset };
dispatch(op)
}

#[cfg(test)]
mod tests {
- use super::{vec_sum, vec_sum_square};
+ use super::{vec_sum, vec_sum_square, vec_sum_square_sub};

// Chosen to not be a multiple of vector size, so that tail handling is
// exercised.
const LEN: usize = 100;

#[test]
fn test_vec_sum() {
- let xs: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
+ let xs: Vec<f32> = (0..LEN).map(|i| i as f32 * 0.1).collect();
let expected_sum: f32 = xs.iter().sum();
let sum = vec_sum(&xs);
assert_eq!(sum, expected_sum);
}

#[test]
fn test_vec_sum_square() {
- let xs: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
+ let xs: Vec<f32> = (0..LEN).map(|i| i as f32 * 0.1).collect();
let expected_sum: f32 = xs.iter().copied().map(|x| x * x).sum();
let sum = vec_sum_square(&xs);
assert_eq!(sum, expected_sum);
}

#[test]
fn test_vec_sum_square_sub() {
let xs: Vec<f32> = (0..LEN).map(|i| i as f32 * 0.1).collect();
let mean = xs.iter().sum::<f32>() / xs.len() as f32;
let expected_sum: f32 = xs.iter().copied().map(|x| (x - mean) * (x - mean)).sum();
let sum = vec_sum_square_sub(&xs, mean);
assert_eq!(sum, expected_sum);
}
}
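The variance use case described in the `vec_sum_square_sub` doc comment, written out as a sketch (this mirrors what `instance_normalization_in_place` does in norm.rs below):

```rust
use rten_vecmath::{vec_sum, vec_sum_square_sub};

// Two-pass variance: compute the mean, then the mean of squared deviations.
fn variance(xs: &[f32]) -> f32 {
    let mean = vec_sum(xs) / xs.len() as f32;
    vec_sum_square_sub(xs, mean) / xs.len() as f32
}
```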
113 changes: 82 additions & 31 deletions src/ops/norm.rs
@@ -3,13 +3,64 @@ use std::mem::MaybeUninit;
use rayon::prelude::*;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Tensor, TensorView};
- use rten_vecmath::{vec_shift_scale_in_place, vec_softmax_in_place, vec_sum, vec_sum_square};
+ use rten_vecmath::{
+ vec_shift_scale_bias, vec_shift_scale_in_place, vec_softmax_in_place, vec_sum, vec_sum_square,
+ vec_sum_square_sub,
+ };

use crate::ops::static_dims;
use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Output, OutputList};
use crate::slice_reductions::slice_max;
use crate::tensor_pool::TensorPool;

struct NormalizeOptions {
/// Pre-computed mean of the input data.
mean: f32,

/// Pre-computed variance of the input data.
variance: f32,

/// Epsilon value used to avoid divide-by-zero in sqrt.
epsilon: f32,

/// Constant scale to multiply normalized value by.
scale: f32,

/// Constant bias to add to normalized value.
bias: f32,
}

/// Normalize the mean and variance of elements in `data` and apply a constant
/// scale and bias to the result.
///
/// ```text
/// Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias
/// ```
fn normalize_slice(data: &mut [f32], opts: NormalizeOptions) {
let NormalizeOptions {
mean,
variance,
epsilon,
scale,
bias,
} = opts;

// To avoid divisions in the vectorized loop, we re-arrange:
//
// ```
// Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias
// ```
//
// As:
//
// ```
// scaled_std_dev_reciprocal = scale / (input_var + epsilon).sqrt()
// Y = (X - input_mean) * scaled_std_dev_reciprocal + bias
// ```
let scaled_std_dev_reciprocal = scale / (variance + epsilon).sqrt();
vec_shift_scale_bias(data, mean, scaled_std_dev_reciprocal, bias);
}
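A standalone numeric check of the rearrangement above (editor's sketch, not part of the commit; values arbitrary, chosen so the square root is exact):

```rust
fn main() {
    let (x, mean, var, eps, scale, bias) = (1.5_f32, 0.5, 4.0, 0.0, 3.0, 1.0);
    let direct = (x - mean) / (var + eps).sqrt() * scale + bias;
    let factored = (x - mean) * (scale / (var + eps).sqrt()) + bias;
    assert!((direct - factored).abs() < 1e-6); // both evaluate to 2.5
}
```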

/// Perform in-place batch normalization on the `NC*` tensor `out`.
///
/// See <https://github.com/onnx/onnx/blob/main/docs/Operators.md#batchnormalization>.
@@ -28,23 +79,26 @@ pub fn batch_norm_in_place(
let batch = input.size(0);
let chans = input.size(1);

+ input.make_contiguous();

for n in 0..batch {
for c in 0..chans {
let chan_mean = mean[[c]];
let chan_var = var[[c]];
let chan_scale = scale[[c]];
let chan_bias = bias[[c]];

- // The batch norm formula, from the ONNX spec, is:
- //
- // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias
- //
- // It has been rewritten here to simplify the inner loop below.
- let scaled_std_dev_reciprocal = chan_scale / (chan_var + epsilon).sqrt();

- input
- .slice_mut([n, c])
- .apply(|el| (*el - chan_mean) * scaled_std_dev_reciprocal + chan_bias);
+ let mut chan = input.slice_mut([n, c]);
+ let chan_data = chan.data_mut().unwrap();
+ normalize_slice(
+ chan_data,
+ NormalizeOptions {
+ mean: chan_mean,
+ variance: chan_var,
+ epsilon,
+ scale: chan_scale,
+ bias: chan_bias,
+ },
+ );
}
}

@@ -164,32 +218,29 @@ pub fn instance_normalization_in_place(
));
}

- // Needed for `slice_sum` below.
+ // Needed for `vec_*` ops below.
input.make_contiguous();

for n in 0..batch {
for c in 0..chans {
let mut slice = input.slice_mut([n, c]);
+ let chan_data = slice.data_mut().unwrap();

let chan_scale = scale[[c]];
let chan_bias = bias[[c]];
- let chan_mean = vec_sum(slice.data().unwrap()) / slice.len() as f32;
- let chan_variance = slice
- .iter()
- .map(|x| {
- let diff = *x - chan_mean;
- diff * diff
- })
- .sum::<f32>()
- / slice.len() as f32;

- // The instance norm formula, from the ONNX spec, is:
- //
- // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias
- //
- // It has been rewritten here to optimize the inner loop.
- let scaled_std_dev_reciprocal = chan_scale / (chan_variance + epsilon).sqrt();

- slice.apply(|x| (*x - chan_mean) * scaled_std_dev_reciprocal + chan_bias)
+ let chan_mean = vec_sum(chan_data) / chan_data.len() as f32;
+ let chan_variance = vec_sum_square_sub(chan_data, chan_mean) / chan_data.len() as f32;

+ normalize_slice(
+ chan_data,
+ NormalizeOptions {
+ mean: chan_mean,
+ variance: chan_variance,
+ epsilon,
+ scale: chan_scale,
+ bias: chan_bias,
+ },
+ );
}
}

