Add TFloat data type for neural network #3486

Merged (4 commits, Jul 24, 2021)
23 changes: 23 additions & 0 deletions src/arch/dotproductavx.cpp
@@ -29,6 +29,28 @@ namespace tesseract {

// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel AVX intrinsics to access the SIMD instruction set.
#if defined(FAST_FLOAT)
float DotProductAVX(const float *u, const float *v, int n) {
const unsigned quot = n / 8;
const unsigned rem = n % 8;
__m256 t0 = _mm256_setzero_ps();
for (unsigned k = 0; k < quot; k++) {
__m256 f0 = _mm256_loadu_ps(u);
__m256 f1 = _mm256_loadu_ps(v);
f0 = _mm256_mul_ps(f0, f1);
t0 = _mm256_add_ps(t0, f0);
u += 8;
v += 8;
}
alignas(32) float tmp[8];
_mm256_store_ps(tmp, t0);
float result = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7];
for (unsigned k = 0; k < rem; k++) {
result += *u++ * *v++;
}
return result;
}
#else
double DotProductAVX(const double *u, const double *v, int n) {
const unsigned quot = n / 8;
const unsigned rem = n % 8;
@@ -57,6 +79,7 @@ double DotProductAVX(const double *u, const double *v, int n) {
}
return result;
}
#endif

} // namespace tesseract.
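
Review note: the new FAST_FLOAT path consumes 8 floats per iteration and reduces the vector accumulator through an aligned store. For unit-testing it, a plain scalar reference like the sketch below is useful (the name DotProductScalar is hypothetical, not part of this PR); since float summation order differs between the scalar and SIMD versions, comparisons need a small tolerance rather than exact equality.

// Hypothetical scalar reference for testing; not part of this PR.
static float DotProductScalar(const float *u, const float *v, int n) {
  float total = 0.0f;
  for (int k = 0; k < n; k++) {
    total += u[k] * v[k];
  }
  return total;
}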

29 changes: 29 additions & 0 deletions src/arch/dotproductfma.cpp
@@ -29,6 +29,34 @@ namespace tesseract {

// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel FMA intrinsics to access the SIMD instruction set.
#if defined(FAST_FLOAT)
float DotProductFMA(const float *u, const float *v, int n) {
const unsigned quot = n / 16;
const unsigned rem = n % 16;
__m256 t0 = _mm256_setzero_ps();
__m256 t1 = _mm256_setzero_ps();
for (unsigned k = 0; k < quot; k++) {
__m256 f0 = _mm256_loadu_ps(u);
__m256 f1 = _mm256_loadu_ps(v);
t0 = _mm256_fmadd_ps(f0, f1, t0);
u += 8;
v += 8;
__m256 f2 = _mm256_loadu_ps(u);
__m256 f3 = _mm256_loadu_ps(v);
t1 = _mm256_fmadd_ps(f2, f3, t1);
u += 8;
v += 8;
}
t0 = _mm256_hadd_ps(t0, t1);
alignas(32) float tmp[8];
_mm256_store_ps(tmp, t0);
float result = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7];
for (unsigned k = 0; k < rem; k++) {
result += *u++ * *v++;
}
return result;
}
#else
double DotProductFMA(const double *u, const double *v, int n) {
const unsigned quot = n / 8;
const unsigned rem = n % 8;
@@ -55,6 +83,7 @@ double DotProductFMA(const double *u, const double *v, int n) {
}
return result;
}
#endif

} // namespace tesseract.
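
Review note: the single _mm256_hadd_ps(t0, t1) is enough here because the code then sums all 8 lanes of the result anyway. Below is a standalone check of that reduction, compiled with -mavx; the lane layout in the comment is my reading of the intrinsics reference and worth double-checking.

#include <cstdio>
#include <immintrin.h>

// _mm256_hadd_ps(a, b) sums adjacent pairs within each 128-bit lane:
//   { a0+a1, a2+a3, b0+b1, b2+b3, a4+a5, a6+a7, b4+b5, b6+b7 }
// so adding all 8 lanes of the result equals the sum of every element
// of both accumulators, which is exactly what DotProductFMA needs.
int main() {
  alignas(32) float a[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  alignas(32) float b[8] = {10, 20, 30, 40, 50, 60, 70, 80};
  __m256 t = _mm256_hadd_ps(_mm256_load_ps(a), _mm256_load_ps(b));
  alignas(32) float tmp[8];
  _mm256_store_ps(tmp, t);
  float sum = 0.0f;
  for (float x : tmp) {
    sum += x;
  }
  std::printf("%g\n", sum); // prints 396 = (1+...+8) + (10+...+80)
  return 0;
}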

64 changes: 63 additions & 1 deletion src/arch/dotproductsse.cpp
@@ -30,6 +30,66 @@ namespace tesseract {

// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel SSE intrinsics to access the SIMD instruction set.
#if defined(FAST_FLOAT)
float DotProductSSE(const float *u, const float *v, int n) {
int max_offset = n - 4;
int offset = 0;
// Accumulate a set of 4 sums in sum by loading 4 values at a time from each
// of u and v and multiplying them together in parallel.
__m128 sum = _mm_setzero_ps();
if (offset <= max_offset) {
offset = 4;
// Aligned load is reputedly faster but requires 16 byte aligned input.
if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 &&
(reinterpret_cast<uintptr_t>(v) & 15) == 0) {
// Use aligned load.
sum = _mm_load_ps(u);
__m128 floats2 = _mm_load_ps(v);
// Multiply.
sum = _mm_mul_ps(sum, floats2);
while (offset <= max_offset) {
__m128 floats1 = _mm_load_ps(u + offset);
floats2 = _mm_load_ps(v + offset);
floats1 = _mm_mul_ps(floats1, floats2);
sum = _mm_add_ps(sum, floats1);
offset += 4;
}
} else {
// Use unaligned load.
sum = _mm_loadu_ps(u);
__m128 floats2 = _mm_loadu_ps(v);
// Multiply.
sum = _mm_mul_ps(sum, floats2);
while (offset <= max_offset) {
__m128 floats1 = _mm_loadu_ps(u + offset);
floats2 = _mm_loadu_ps(v + offset);
floats1 = _mm_mul_ps(floats1, floats2);
sum = _mm_add_ps(sum, floats1);
offset += 4;
}
}
}
// Add the 4 sums in sum horizontally.
#if 0
alignas(32) float tmp[4];
_mm_store_ps(tmp, sum);
float result = tmp[0] + tmp[1] + tmp[2] + tmp[3];
#else
__m128 zero = _mm_setzero_ps();
// https://www.felixcloutier.com/x86/haddps
sum = _mm_hadd_ps(sum, zero);
sum = _mm_hadd_ps(sum, zero);
// Extract the low result.
float result = _mm_cvtss_f32(sum);
#endif
// Add on any left-over products.
while (offset < n) {
result += u[offset] * v[offset];
++offset;
}
return result;
}
#else
double DotProductSSE(const double *u, const double *v, int n) {
int max_offset = n - 2;
int offset = 0;
@@ -39,7 +99,8 @@ double DotProductSSE(const double *u, const double *v, int n) {
if (offset <= max_offset) {
offset = 2;
// Aligned load is reputedly faster but requires 16 byte aligned input.
if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 && (reinterpret_cast<uintptr_t>(v) & 15) == 0) {
if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 &&
(reinterpret_cast<uintptr_t>(v) & 15) == 0) {
// Use aligned load.
sum = _mm_load_pd(u);
__m128d floats2 = _mm_load_pd(v);
@@ -78,6 +139,7 @@ double DotProductSSE(const double *u, const double *v, int n) {
}
return result;
}
#endif

} // namespace tesseract.
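
Review note: the enabled branch of the horizontal add relies on _mm_hadd_ps, which is an SSE3 instruction, so this path assumes at least SSE3. Factored out for illustration only (no such helper exists in the PR), the reduction works like this:

#include <immintrin.h>

// Minimal sketch of the two-step hadd reduction used above.
static inline float HorizontalSum(__m128 sum) {
  __m128 zero = _mm_setzero_ps();
  // First hadd: { s0+s1, s2+s3, 0, 0 }
  sum = _mm_hadd_ps(sum, zero);
  // Second hadd: { s0+s1+s2+s3, 0, 0, 0 }
  sum = _mm_hadd_ps(sum, zero);
  // Extract lane 0 as a scalar.
  return _mm_cvtss_f32(sum);
}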

2 changes: 1 addition & 1 deletion src/arch/intsimdmatrix.h
@@ -115,7 +115,7 @@ struct TESS_API IntSimdMatrix {
static const IntSimdMatrix *intSimdMatrix;
// Only available with NEON.
static const IntSimdMatrix intSimdMatrixNEON;
// Only available with AVX2 / SSE.
// Only available with AVX2 / AVX / FMA / SSE.
static const IntSimdMatrix intSimdMatrixAVX2;
static const IntSimdMatrix intSimdMatrixSSE;
};
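
Review note: the dot-product files all switch on FAST_FLOAT, which suggests the TFloat type named in the PR title is selected the same way. A sketch of the presumed definition follows; the exact form and header location are assumptions, not shown in this diff.

// Assumed TFloat selection; not part of the visible diff.
#if defined(FAST_FLOAT)
using TFloat = float;  // single precision: smaller and usually faster
#else
using TFloat = double; // double precision: the previous behavior
#endif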