Skip to content

Commit

Permalink
HMMM. This is where the float/double co-existence stuff starts to bec…
Browse files Browse the repository at this point in the history
…ome NOT NICE: code repetition at another level.

TODO: Better idea? --> Maybe namespaces and double kernel projects or compile via #define+#include-all-source-files hack collective source code pages? (Latter approach may become a problem when debugging, or will the compiler suite cope well? Will know only once done & tested.)

At least this is about the point where the function template solution stops to be useful. The run-time switching desire between float and double is doable, but not by using #ifdef/#else throughout, nor templating all the way up the TFloat usage calltree.
  • Loading branch information
GerHobbelt committed Jul 13, 2021
1 parent 5d16bab commit 8d40552
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
9 changes: 6 additions & 3 deletions src/arch/intsimdmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,12 @@ struct TESS_API IntSimdMatrix {
// RoundInputs above.
// The input will be over-read to the extent of the padding. There are no
// alignment requirements.
using MatrixDotVectorFunction = void (*)(int, int, const int8_t *, const TFloat *, const int8_t *,
TFloat *);
MatrixDotVectorFunction matrixDotVectorFunction;
using MatrixDotVectorFunctionFP32 = void (*)(int, int, const int8_t *, const float *, const int8_t *,
float *);
using MatrixDotVectorFunctionFP64 = void (*)(int, int, const int8_t *, const double *, const int8_t *,
double *);
MatrixDotVectorFunctionFP32 matrixDotVectorFunctionFP32;
MatrixDotVectorFunctionFP64 matrixDotVectorFunctionFP64;

// Number of 32 bit outputs held in each register.
int num_outputs_per_register_;
Expand Down
14 changes: 12 additions & 2 deletions src/lstm/weightmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,15 +351,25 @@ void WeightMatrix::MatrixDotVector(const TFloat *u, TFloat *v) const {
MatrixDotVectorInternal(wf_, true, false, u, v);
}

void WeightMatrix::MatrixDotVector(const int8_t *u, TFloat *v) const {
void WeightMatrix::MatrixDotVector(const int8_t *u, float *v) const {
assert(int_mode_);
if (IntSimdMatrix::intSimdMatrix) {
IntSimdMatrix::intSimdMatrix->matrixDotVectorFunction(wi_.dim1(), wi_.dim2(), &shaped_w_[0],
IntSimdMatrix::intSimdMatrix->matrixDotVectorFunctionFP32(wi_.dim1(), wi_.dim2(), &shaped_w_[0],
&scales_[0], u, v);
} else {
IntSimdMatrix::MatrixDotVector(wi_, scales_, u, v);
}
}
void WeightMatrix::MatrixDotVector(const int8_t *u, double *v) const {
assert(int_mode_);
if (IntSimdMatrix::intSimdMatrix) {
IntSimdMatrix::intSimdMatrix->matrixDotVectorFunctionFP64(wi_.dim1(), wi_.dim2(), &shaped_w_[0],
&scales_[0], u, v);
} else {
IntSimdMatrix::MatrixDotVector(wi_, scales_, u, v);
}
}


// MatrixDotVector for peep weights, MultiplyAccumulate adds the
// component-wise products of *this[0] and v to inout.
Expand Down

0 comments on commit 8d40552

Please sign in to comment.