Add AVX-512 implementations of SimdFloat, SimdInt. Use them for GEMM. #65

Merged 2 commits on Mar 25, 2024
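Routing GEMM through the `SimdFloat` trait means one kernel body can be instantiated at several vector widths. As a rough illustration only (this helper is not part of the PR; it assumes the trait bounds permit `Copy`, which holds for the raw vector types), a width-generic loop over the trait looks like:

use rten_vecmath::simd_vec::SimdFloat;

/// Compute `out[i] = a * x[i] + y[i]` with SIMD vectors of type `S`.
///
/// Safety: the caller must ensure the current CPU supports the vector type `S`.
unsafe fn axpy<S: SimdFloat + Copy>(a: f32, x: &[f32], y: &[f32], out: &mut [f32]) {
    assert!(x.len() == y.len() && y.len() == out.len());
    let a_vec = S::splat(a);
    let mut i = 0;
    // Process whole vectors of `S::LEN` lanes at a time.
    while i + S::LEN <= x.len() {
        let v = a_vec.mul_add(S::load(x.as_ptr().add(i)), S::load(y.as_ptr().add(i)));
        v.store(out.as_mut_ptr().add(i));
        i += S::LEN;
    }
    // Scalar tail for the remaining `x.len() % S::LEN` elements.
    while i < x.len() {
        out[i] = a * x[i] + y[i];
        i += 1;
    }
}

With the impls added below, `axpy::<__m512>` handles 16 floats per iteration, while the existing `__m256` impl handles 8.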
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -47,7 +47,7 @@ crate-type = ["lib", "cdylib"]
 
 [features]
 # Use AVX-512 instructions if available. Requires nightly Rust for AVX-512 intrinsics.
-avx512 = []
+avx512 = ["rten-vecmath/avx512"]
 # Generate WebAssembly API using wasm-bindgen.
 wasm_api = []
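With this change the root crate's `avx512` feature forwards to the matching feature in `rten-vecmath`, so a single flag enables the AVX-512 code paths in both crates. As the comment notes, the intrinsics need nightly Rust, so a typical invocation is `cargo +nightly build --features avx512`.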
3 changes: 3 additions & 0 deletions rten-vecmath/Cargo.toml
@@ -19,3 +19,6 @@ crate-type = ["lib"]
 # See comments about `needless_range_loop` in root Cargo.toml.
 needless_range_loop = "allow"
 manual_memcpy = "allow"
+
+[features]
+avx512 = []
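The new feature carries no dependencies of its own; it acts purely as a compile-time switch for the nightly-gated intrinsics and trait impls added below.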
6 changes: 6 additions & 0 deletions rten-vecmath/src/lib.rs
@@ -17,6 +17,12 @@
 //!
 //! See the source code for comments on accuracy.
 
+#![cfg_attr(
+    feature = "avx512",
+    feature(stdarch_x86_avx512),
+    feature(avx512_target_feature)
+)]
+
 mod erf;
 mod exp;
 pub mod simd_vec;
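Both attributes are nightly feature gates: `stdarch_x86_avx512` unlocks the AVX-512 intrinsics in `std::arch`, and `avx512_target_feature` permits `#[target_feature(enable = "avx512f")]` and friends. Wrapping them in `cfg_attr` keeps the crate compiling on stable toolchains when the `avx512` feature is disabled.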
164 changes: 164 additions & 0 deletions rten-vecmath/src/simd_vec/x86_64.rs
@@ -213,3 +213,167 @@ impl SimdFloat for __m256 {
         _mm_prefetch(data as *const i8, _MM_HINT_ET0);
     }
 }
+
+#[cfg(feature = "avx512")]
+use std::arch::x86_64::{
+    __m512, __m512i, __mmask16, _mm512_abs_ps, _mm512_add_epi32, _mm512_add_ps,
+    _mm512_castsi512_ps, _mm512_cmp_epi32_mask, _mm512_cmp_ps_mask, _mm512_cvttps_epi32,
+    _mm512_div_ps, _mm512_fmadd_ps, _mm512_loadu_ps, _mm512_loadu_si512, _mm512_mask_blend_epi32,
+    _mm512_mask_blend_ps, _mm512_max_ps, _mm512_mul_ps, _mm512_set1_epi32, _mm512_set1_ps,
+    _mm512_setzero_si512, _mm512_sllv_epi32, _mm512_storeu_ps, _mm512_storeu_si512,
+    _mm512_sub_epi32, _mm512_sub_ps, _MM_CMPINT_LT,
+};
+
+#[cfg(feature = "avx512")]
+impl SimdInt for __m512i {
+    type Float = __m512;
+    type Mask = __mmask16;
+
+    const LEN: usize = 16;
+
+    #[inline]
+    unsafe fn zero() -> Self {
+        _mm512_setzero_si512()
+    }
+
+    #[inline]
+    unsafe fn splat(val: i32) -> Self {
+        _mm512_set1_epi32(val)
+    }
+
+    #[inline]
+    unsafe fn gt(self, other: Self) -> Self::Mask {
+        // `self > other` expressed as `other < self`. AVX-512 comparisons
+        // return a bitmask (`__mmask16`) rather than a vector mask.
+        _mm512_cmp_epi32_mask(other, self, _MM_CMPINT_LT)
+    }
+
+    #[inline]
+    unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self {
+        _mm512_mask_blend_epi32(mask, self, other)
+    }
+
+    #[inline]
+    unsafe fn add(self, rhs: Self) -> Self {
+        _mm512_add_epi32(self, rhs)
+    }
+
+    #[inline]
+    unsafe fn sub(self, rhs: Self) -> Self {
+        _mm512_sub_epi32(self, rhs)
+    }
+
+    #[inline]
+    unsafe fn shl<const COUNT: i32>(self) -> Self {
+        // Shift left via a per-lane variable shift with a broadcast count.
+        let count = Self::splat(COUNT);
+        _mm512_sllv_epi32(self, count)
+    }
+
+    #[inline]
+    unsafe fn reinterpret_as_float(self) -> Self::Float {
+        _mm512_castsi512_ps(self)
+    }
+
+    #[inline]
+    unsafe fn load(ptr: *const i32) -> Self {
+        _mm512_loadu_si512(ptr)
+    }
+
+    #[inline]
+    unsafe fn store(self, ptr: *mut i32) {
+        _mm512_storeu_si512(ptr, self)
+    }
+}
+
+#[cfg(feature = "avx512")]
+impl SimdFloat for __m512 {
+    type Int = __m512i;
+    type Mask = __mmask16;
+
+    const LEN: usize = 16;
+
+    #[inline]
+    unsafe fn splat(val: f32) -> Self {
+        _mm512_set1_ps(val)
+    }
+
+    #[inline]
+    unsafe fn abs(self) -> Self {
+        _mm512_abs_ps(self)
+    }
+
+    #[inline]
+    unsafe fn mul_add(self, a: Self, b: Self) -> Self {
+        // Fused multiply-add: `self * a + b` with a single rounding.
+        _mm512_fmadd_ps(self, a, b)
+    }
+
+    #[inline]
+    unsafe fn sub(self, rhs: Self) -> Self {
+        _mm512_sub_ps(self, rhs)
+    }
+
+    #[inline]
+    unsafe fn add(self, rhs: Self) -> Self {
+        _mm512_add_ps(self, rhs)
+    }
+
+    #[inline]
+    unsafe fn to_int_trunc(self) -> Self::Int {
+        _mm512_cvttps_epi32(self)
+    }
+
+    #[inline]
+    unsafe fn mul(self, rhs: Self) -> Self {
+        _mm512_mul_ps(self, rhs)
+    }
+
+    #[inline]
+    unsafe fn div(self, rhs: Self) -> Self {
+        _mm512_div_ps(self, rhs)
+    }
+
+    // The `_OQ` comparisons are ordered (false if either operand is NaN)
+    // and quiet (they do not raise FP exceptions); they return a 16-bit mask.
+    #[inline]
+    unsafe fn ge(self, rhs: Self) -> Self::Mask {
+        _mm512_cmp_ps_mask(self, rhs, _CMP_GE_OQ)
+    }
+
+    #[inline]
+    unsafe fn le(self, rhs: Self) -> Self::Mask {
+        _mm512_cmp_ps_mask(self, rhs, _CMP_LE_OQ)
+    }
+
+    #[inline]
+    unsafe fn lt(self, rhs: Self) -> Self::Mask {
+        _mm512_cmp_ps_mask(self, rhs, _CMP_LT_OQ)
+    }
+
+    #[inline]
+    unsafe fn max(self, rhs: Self) -> Self {
+        _mm512_max_ps(self, rhs)
+    }
+
+    #[inline]
+    unsafe fn blend(self, rhs: Self, mask: Self::Mask) -> Self {
+        _mm512_mask_blend_ps(mask, self, rhs)
+    }
+
+    #[inline]
+    unsafe fn load(ptr: *const f32) -> Self {
+        _mm512_loadu_ps(ptr)
+    }
+
+    #[inline]
+    unsafe fn store(self, ptr: *mut f32) {
+        _mm512_storeu_ps(ptr, self)
+    }
+
+    /// Prefetch the cache line containing `data`, for reading.
+    #[inline]
+    unsafe fn prefetch(data: *const f32) {
+        _mm_prefetch(data as *const i8, _MM_HINT_T0);
+    }
+
+    /// Prefetch the cache line containing `data`, for writing.
+    #[inline]
+    unsafe fn prefetch_write(data: *mut f32) {
+        _mm_prefetch(data as *const i8, _MM_HINT_ET0);
+    }
+}
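These impls compile to AVX-512 instructions, so they must only run on CPUs that support them. A minimal caller-side sketch, not from the rten codebase: it assumes a crate built with the `avx512` feature on a nightly toolchain, and `add_one_avx512` is a hypothetical helper:

use rten_vecmath::simd_vec::SimdFloat;
use std::arch::x86_64::__m512;

// Hypothetical helper: adds 1.0 to every element using 16-lane vectors.
#[target_feature(enable = "avx512f")]
unsafe fn add_one_avx512(xs: &mut [f32]) {
    const LEN: usize = <__m512 as SimdFloat>::LEN;
    let one = <__m512 as SimdFloat>::splat(1.0);
    let n = xs.len() / LEN * LEN;
    for i in (0..n).step_by(LEN) {
        let v = <__m512 as SimdFloat>::load(xs.as_ptr().add(i));
        v.add(one).store(xs.as_mut_ptr().add(i));
    }
    for x in &mut xs[n..] {
        *x += 1.0; // Scalar tail.
    }
}

fn add_one(xs: &mut [f32]) {
    if is_x86_feature_detected!("avx512f") {
        // SAFETY: AVX-512F support was verified at runtime just above.
        unsafe { add_one_avx512(xs) }
    } else {
        for x in xs.iter_mut() {
            *x += 1.0;
        }
    }
}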
127 changes: 11 additions & 116 deletions src/gemm/kernels/x86_64.rs
@@ -1,5 +1,8 @@
 use std::arch::x86_64::__m256;
 
+#[cfg(feature = "avx512")]
+use std::arch::x86_64::__m512;
+
 use rten_tensor::Matrix;
 use rten_vecmath::simd_vec::SimdFloat;
 
@@ -132,32 +135,9 @@ impl Kernel for Avx512Kernel {
         false
     }
 
-    unsafe fn kernel(
-        tile_ptr: *mut f32,
-        tile_row_stride: usize,
-        a: &[f32],
-        b: &[f32],
-        depth: usize,
-        alpha: f32,
-        beta: f32,
-    ) {
-        Self::kernel_avx_512(tile_ptr, tile_row_stride, a, b, depth, alpha, beta)
-    }
-
     #[target_feature(enable = "avx512f")]
     #[target_feature(enable = "avx512vl")]
-    unsafe fn gemv_kernel(out: &mut [f32], a: &[f32], b: Matrix, alpha: f32, beta: f32) {
-        // Re-use the AVX2 / FMA kernel because rten_vecmath doesn't provide
-        // AVX-512 implementations for `SimdFloat` yet.
-        FmaKernel::gemv_kernel(out, a, b, alpha, beta);
-    }
-}
-
-#[cfg(feature = "avx512")]
-impl Avx512Kernel {
-    #[target_feature(enable = "avx512f")]
-    #[target_feature(enable = "avx512vl")]
-    unsafe fn kernel_avx_512(
+    unsafe fn kernel(
         tile_ptr: *mut f32,
         tile_row_stride: usize,
         a: &[f32],
@@ -166,102 +146,17 @@ impl Avx512Kernel {
         alpha: f32,
         beta: f32,
     ) {
-        use core::arch::x86_64::{
-            __m512, _mm512_add_ps, _mm512_fmadd_ps, _mm512_loadu_ps, _mm512_mul_ps, _mm512_set1_ps,
-            _mm512_setzero_ps, _mm512_storeu_ps, _mm_prefetch, _MM_HINT_ET0, _MM_HINT_T0,
-        };
-        use std::mem::size_of;
-
         const MR: usize = Avx512Kernel::MR;
         const NR: usize = Avx512Kernel::NR;
+        const NR_REGS: usize = NR / <__m512 as SimdFloat>::LEN;
 
-        const REG_SIZE: usize = size_of::<__m512>() / size_of::<f32>();
-        const NR_REGS: usize = NR / REG_SIZE;
-        assert!(NR % REG_SIZE == 0);
-
-        // Check that buffer accesses below are going to be valid.
-        assert!(a.len() >= depth * MR);
-        assert!(b.len() >= depth * NR);
-
-        let a_ptr = a.as_ptr();
-        let b_ptr = b.as_ptr();
-
-        let mut tmp = [[_mm512_setzero_ps(); NR_REGS]; MR];
-        let mut b_rows = [_mm512_setzero_ps(); NR_REGS];
-
-        // Perform first `depth - 1` outer product updates.
-        for k in 0..depth - 1 {
-            let a_off = k * MR;
-            let b_off = k * NR;
-
-            // Prefetch B for the next iteration.
-            _mm_prefetch(b_ptr.add((k + 1) * NR) as *const i8, _MM_HINT_T0);
-
-            for i in 0..NR_REGS {
-                b_rows[i] = _mm512_loadu_ps(b_ptr.add(b_off + i * REG_SIZE));
-            }
-
-            for i in 0..MR {
-                let a_val = *a_ptr.add(a_off + i);
-                let a_broadcast = _mm512_set1_ps(a_val);
-
-                for j in 0..NR_REGS {
-                    tmp[i][j] = _mm512_fmadd_ps(a_broadcast, b_rows[j], tmp[i][j]);
-                }
-            }
-        }
-
-        // Prefetch output before the final computation loop.
-        for i in 0..MR {
-            _mm_prefetch(tile_ptr.add(tile_row_stride * i) as *const i8, _MM_HINT_ET0);
-        }
-
-        // Perform final outer product update.
-        let k = depth - 1;
-        let a_off = k * MR;
-        let b_off = k * NR;
-
-        for i in 0..NR_REGS {
-            b_rows[i] = _mm512_loadu_ps(b_ptr.add(b_off + i * REG_SIZE));
-        }
-
-        for i in 0..MR {
-            let a_val = *a_ptr.add(a_off + i);
-            let a_broadcast = _mm512_set1_ps(a_val);
-
-            for j in 0..NR_REGS {
-                tmp[i][j] = _mm512_fmadd_ps(a_broadcast, b_rows[j], tmp[i][j]);
-            }
-        }
-
-        // Write to output tile.
-        if beta == 0. && alpha == 1. {
-            for i in 0..MR {
-                for j in 0..NR_REGS {
-                    let out_ptr = tile_ptr.add(tile_row_stride * i + j * REG_SIZE);
-                    _mm512_storeu_ps(out_ptr, tmp[i][j]);
-                }
-            }
-        } else if beta == 1. && alpha == 1. {
-            for i in 0..MR {
-                for j in 0..NR_REGS {
-                    let out_ptr = tile_ptr.add(tile_row_stride * i + j * REG_SIZE);
-                    let out_val = _mm512_add_ps(_mm512_loadu_ps(out_ptr), tmp[i][j]);
-                    _mm512_storeu_ps(out_ptr, out_val);
-                }
-            }
-        } else {
-            let alpha_broadcast = _mm512_set1_ps(alpha);
-            let beta_broadcast = _mm512_set1_ps(beta);
-            for i in 0..MR {
-                for j in 0..NR_REGS {
-                    let out_ptr = tile_ptr.add(tile_row_stride * i + j * REG_SIZE);
-                    let out_val = _mm512_mul_ps(_mm512_loadu_ps(out_ptr), beta_broadcast);
-                    let out_val = _mm512_fmadd_ps(tmp[i][j], alpha_broadcast, out_val);
-                    _mm512_storeu_ps(out_ptr, out_val);
-                }
-            }
-        }
+        simd_gemm::<__m512, MR, NR_REGS>(tile_ptr, tile_row_stride, a, b, depth, alpha, beta)
     }
+
+    #[target_feature(enable = "avx512f")]
+    #[target_feature(enable = "avx512vl")]
+    unsafe fn gemv_kernel(out: &mut [f32], a: &[f32], b: Matrix, alpha: f32, beta: f32) {
+        simd_gemv::<__m512, 2>(out, a, b, alpha, beta);
+    }
 }
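The body of the old hand-written kernel now lives behind `simd_gemm` in rten-vecmath, generic over the `SimdFloat` type. Its actual implementation is not part of this diff; the sketch below reconstructs the core from the deleted code above using only trait methods shown in this PR, with the `alpha`/`beta` fast paths and prefetch hints of the original elided:

// Sketch only: reconstructed from the deleted kernel above; the real
// `simd_gemm` in rten-vecmath may differ in signature and detail.
use rten_vecmath::simd_vec::SimdFloat;

unsafe fn simd_gemm_sketch<S: SimdFloat + Copy, const MR: usize, const NR_REGS: usize>(
    tile_ptr: *mut f32,
    tile_row_stride: usize,
    a: &[f32],
    b: &[f32],
    depth: usize,
    alpha: f32,
    beta: f32,
) {
    // Check that buffer accesses below are going to be valid.
    assert!(a.len() >= depth * MR);
    assert!(b.len() >= depth * NR_REGS * S::LEN);

    // MR x NR accumulator tile, held entirely in vector registers.
    let mut tmp = [[S::splat(0.); NR_REGS]; MR];
    for k in 0..depth {
        // Load the current row of packed B once per `k`.
        let mut b_rows = [S::splat(0.); NR_REGS];
        for j in 0..NR_REGS {
            b_rows[j] = S::load(b.as_ptr().add(k * NR_REGS * S::LEN + j * S::LEN));
        }
        // Outer product update: broadcast each A element against the B row.
        for i in 0..MR {
            let a_broadcast = S::splat(a[k * MR + i]);
            for j in 0..NR_REGS {
                tmp[i][j] = a_broadcast.mul_add(b_rows[j], tmp[i][j]);
            }
        }
    }

    // out = alpha * tmp + beta * out. The deleted code above also
    // special-cases `alpha == 1.` with `beta == 0.` or `beta == 1.`,
    // skipping the scaling (and the read of `out` when `beta == 0.`).
    let alpha_vec = S::splat(alpha);
    let beta_vec = S::splat(beta);
    for i in 0..MR {
        for j in 0..NR_REGS {
            let out_ptr = tile_ptr.add(tile_row_stride * i + j * S::LEN);
            let out_val = S::load(out_ptr).mul(beta_vec);
            tmp[i][j].mul_add(alpha_vec, out_val).store(out_ptr);
        }
    }
}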