Skip to content

Commit

Permalink
Slightly reduce the inv cost in mulEach
Browse files Browse the repository at this point in the history
  • Loading branch information
herumi committed Dec 13, 2024
1 parent 22f15b9 commit e728ca5
Showing 1 changed file with 66 additions and 78 deletions.
144 changes: 66 additions & 78 deletions src/msm_avx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1008,16 +1008,11 @@ struct EcMT {
v2 = t1.isEqualAll(t2);
return kandb(v1, v2);
}
//#define SIGNED_TABLE // a little slower (32.1Mclk->32.4Mclk)
template<size_t bitLen, size_t w>
static void makeNAFtbl(V *idxTbl, VM *negTbl, const V a[2])
{
const Vec mask = vpbroadcastq((1<<w)-1);
#ifdef SIGNED_TABLE
(void)negTbl;
#else
const Vec Fu = vpbroadcastq(1<<w);
#endif
const Vec H = vpbroadcastq(1<<(w-1));
const Vec one = vpbroadcastq(1);
size_t pos = 0;
Expand All @@ -1027,84 +1022,69 @@ struct EcMT {
V idx = getUnitAt(a, 2, pos);
idx = vpandq(idx, mask);
idx = vpaddq(idx, CF);
#ifdef SIGNED_TABLE
V masked = vpandq(idx, mask);
VM v = vpcmpgtq(masked, H);
idxTbl[i] = masked; //vselect(negTbl[i], vpsubq(Fu, masked), masked); // idx >= H ? Fu - idx : idx;
CF = vpsrlq(idx, w);
CF = vpaddq(v, CF, one);
#else
V masked = vpandq(idx, mask);
negTbl[i] = vpcmpgtq(masked, H);
idxTbl[i] = vselect(negTbl[i], vpsubq(Fu, masked), masked); // idx >= H ? F - idx : idx;
CF = vpsrlq(idx, w);
CF = vpaddq(negTbl[i], CF, one);
#endif
pos += w;
}
}
template<bool isProj=true, bool mixed=false>
static void mulGLV(T& Q, const T& P, const mcl::msm::FrA *y)
static void mulGLV(T *Q, const T *P, const mcl::msm::FrA *y, size_t n = 1)
{
const size_t m = sizeof(V)/8;
const size_t w = 5;
const size_t halfN = (1<<(w-1))+1; // [0, 2^(w-1)]
#ifdef SIGNED_TABLE
const size_t tblN = 1<<w;
#else
const size_t tblN = halfN;
#endif
V a[2], b[2];
T tbl1[tblN], tbl2[tblN];
makeTable<isProj, mixed>(tbl1, halfN, P);
if (!isProj && mixed) normalizeJacobiVec<T>(tbl1+1, halfN-1);
for (size_t i = 0; i < halfN; i++) {
mulLambda(tbl2[i], tbl1[i]);
}
#ifdef SIGNED_TABLE
for (size_t i = halfN; i < tblN; i++) {
T::neg(tbl1[i], tbl1[tblN-i]);
T::neg(tbl2[i], tbl2[tblN-i]);
}
#endif
Unit *pa = (Unit*)a;
Unit *pb = (Unit*)b;
for (size_t i = 0; i < m; i++) {
Unit buf[4];
g_func.fr->fromMont(buf, y[i].v);
Unit aa[2], bb[2];
mcl::ec::local::optimizedSplitRawForBLS12_381(aa, bb, buf);
pa[i+m*0] = aa[0]; pa[i+m*1] = aa[1];
pb[i+m*0] = bb[0]; pb[i+m*1] = bb[1];
const size_t tblN = (1<<(w-1))+1; // [0, 2^(w-1)]

T tbl1s[tblN*n];
for (size_t k = 0; k < n; k++) {
makeTable<isProj, mixed>(tbl1s + tblN*k, tblN, P[k]);
}
const size_t bitLen = 128;
const size_t n = (bitLen + w-1)/w;
V aTbl[n], bTbl[n];
VM aNegTbl[n], bNegTbl[n];
makeNAFtbl<bitLen, w>(aTbl, aNegTbl, a);
makeNAFtbl<bitLen, w>(bTbl, bNegTbl, b);
if (!isProj && mixed) normalizeJacobiVec<T>(tbl1s+1, tblN*n-1);

for (size_t i = 0; i < n; i++) {
if (i > 0) for (size_t k = 0; k < w; k++) T::template dbl<isProj>(Q, Q);
const size_t pos = n-1-i;
for (size_t k = 0; k < n; k++) {
const T *tbl1 = &tbl1s[tblN*k];
T tbl2[tblN];
for (size_t i = 0; i < tblN; i++) {
mulLambda(tbl2[i], tbl1[i]);
}
V a[2], b[2];
Unit *pa = (Unit*)a;
Unit *pb = (Unit*)b;
for (size_t i = 0; i < m; i++) {
Unit buf[4];
g_func.fr->fromMont(buf, y[k*m+i].v);
Unit aa[2], bb[2];
mcl::ec::local::optimizedSplitRawForBLS12_381(aa, bb, buf);
pa[i+m*0] = aa[0]; pa[i+m*1] = aa[1];
pb[i+m*0] = bb[0]; pb[i+m*1] = bb[1];
}
const size_t bitLen = 128;
const size_t nw = (bitLen + w-1)/w;
V aTbl[nw], bTbl[nw];
VM aNegTbl[nw], bNegTbl[nw];
makeNAFtbl<bitLen, w>(aTbl, aNegTbl, a);
makeNAFtbl<bitLen, w>(bTbl, bNegTbl, b);

T t;
V idx = bTbl[pos];
t.gather(tbl2, idx);
#ifndef SIGNED_TABLE
t.y = F::select(bNegTbl[pos], t.y.neg(), t.y);
#endif
if (i == 0) {
Q = t;
} else {
add<isProj, mixed>(Q, Q, t);
for (size_t i = 0; i < nw; i++) {
if (i > 0) for (size_t j = 0; j < w; j++) T::template dbl<isProj>(Q[k], Q[k]);
const size_t pos = nw-1-i;

T t;
V idx = bTbl[pos];
t.gather(tbl2, idx);
t.y = F::select(bNegTbl[pos], t.y.neg(), t.y);
if (i == 0) {
Q[k] = t;
} else {
add<isProj, mixed>(Q[k], Q[k], t);
}
idx = aTbl[pos];
t.gather(tbl1, idx);
t.y = F::select(aNegTbl[pos], t.y.neg(), t.y);
add<isProj, mixed>(Q[k], Q[k], t);
}
idx = aTbl[pos];
t.gather(tbl1, idx);
#ifndef SIGNED_TABLE
t.y = F::select(aNegTbl[pos], t.y.neg(), t.y);
#endif
add<isProj, mixed>(Q, Q, t);
}
}
};
Expand Down Expand Up @@ -1509,22 +1489,30 @@ void mulEachAVX512(Unit *_x, const Unit *_y, size_t n)
mcl::msm::G1A *x = (mcl::msm::G1A*)_x;
const mcl::msm::FrA *y = (const mcl::msm::FrA*)_y;
if (!isProj && mixed) g_func.normalizeVecG1(x, x, n);
const size_t u = 4;
const size_t q = n / u;
#if 1
// 30.6Mclk at n=1024
for (size_t i = 0; i < n; i += 16) {
EcMA P;
P.setG1A(x+i, isProj);
EcMA::mulGLV<isProj, mixed>(P, P, y+i);
P.getG1A(x+i, isProj);
}
typedef EcMA V;
#else
for (size_t i = 0; i < n; i += 8) {
EcM P;
typedef EcM V;
#endif
const size_t m = sizeof(V)/(sizeof(FpM)*3)*8;
for (size_t i = 0; i < n; i += m*u) {
V P[u];
for (size_t k = 0; k < u; k++) {
P[k].setG1A(x+i+k*m, isProj);
}
V::mulGLV<isProj, mixed>(P, P, y+i, u);
for (size_t k = 0; k < u; k++) {
P[k].getG1A(x+i+k*m, isProj);
}
}
for (size_t i = q*m*u; i < n; i += m) {
V P;
P.setG1A(x+i, isProj);
EcM::mulGLV<isProj, mixed>(P, P, y+i);
V::mulGLV<isProj, mixed>(&P, &P, y+i);
P.getG1A(x+i, isProj);
}
#endif
}

bool initMsm(const mcl::CurveParam& cp, const mcl::msm::Func *func)
Expand Down

0 comments on commit e728ca5

Please sign in to comment.