From 8d405521b3b9a216ad94c4199a684434251ec135 Mon Sep 17 00:00:00 2001 From: Ger Hobbelt Date: Tue, 13 Jul 2021 15:15:08 +0200 Subject: [PATCH] HMMM. This is where the float/double co-existence stuff starts to become NOT NICE: code repetition at another level. TODO: Better idea? --> Maybe namespaces and double kernel projects or compile via #define+#include-all-source-files hack collective source code pages? (Latter approach may become a problem when debugging, or will the compiler suite cope well? Will know only once done & tested.) At least this is about the point where the function template solution stops being useful. The desired run-time switching between float and double is doable, but not by using #ifdef/#else throughout, nor by templating all the way up the TFloat usage calltree. --- src/arch/intsimdmatrix.h | 9 ++++++--- src/lstm/weightmatrix.cpp | 14 ++++++++++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/arch/intsimdmatrix.h b/src/arch/intsimdmatrix.h index 98a894f0ca..a2ee585def 100644 --- a/src/arch/intsimdmatrix.h +++ b/src/arch/intsimdmatrix.h @@ -97,9 +97,12 @@ struct TESS_API IntSimdMatrix { // RoundInputs above. // The input will be over-read to the extent of the padding. There are no // alignment requirements. - using MatrixDotVectorFunction = void (*)(int, int, const int8_t *, const TFloat *, const int8_t *, - TFloat *); - MatrixDotVectorFunction matrixDotVectorFunction; + using MatrixDotVectorFunctionFP32 = void (*)(int, int, const int8_t *, const float *, const int8_t *, + float *); + using MatrixDotVectorFunctionFP64 = void (*)(int, int, const int8_t *, const double *, const int8_t *, + double *); + MatrixDotVectorFunctionFP32 matrixDotVectorFunctionFP32; + MatrixDotVectorFunctionFP64 matrixDotVectorFunctionFP64; // Number of 32 bit outputs held in each register. 
int num_outputs_per_register_; diff --git a/src/lstm/weightmatrix.cpp b/src/lstm/weightmatrix.cpp index 46e10433ca..7b90aba23b 100644 --- a/src/lstm/weightmatrix.cpp +++ b/src/lstm/weightmatrix.cpp @@ -351,15 +351,25 @@ void WeightMatrix::MatrixDotVector(const TFloat *u, TFloat *v) const { MatrixDotVectorInternal(wf_, true, false, u, v); } -void WeightMatrix::MatrixDotVector(const int8_t *u, TFloat *v) const { +void WeightMatrix::MatrixDotVector(const int8_t *u, float *v) const { assert(int_mode_); if (IntSimdMatrix::intSimdMatrix) { - IntSimdMatrix::intSimdMatrix->matrixDotVectorFunction(wi_.dim1(), wi_.dim2(), &shaped_w_[0], + IntSimdMatrix::intSimdMatrix->matrixDotVectorFunctionFP32(wi_.dim1(), wi_.dim2(), &shaped_w_[0], &scales_[0], u, v); } else { IntSimdMatrix::MatrixDotVector(wi_, scales_, u, v); } } +void WeightMatrix::MatrixDotVector(const int8_t *u, double *v) const { + assert(int_mode_); + if (IntSimdMatrix::intSimdMatrix) { + IntSimdMatrix::intSimdMatrix->matrixDotVectorFunctionFP64(wi_.dim1(), wi_.dim2(), &shaped_w_[0], + &scales_[0], u, v); + } else { + IntSimdMatrix::MatrixDotVector(wi_, scales_, u, v); + } +} + // MatrixDotVector for peep weights, MultiplyAccumulate adds the // component-wise products of *this[0] and v to inout.