Optimize MP::parallelFor (#102)
ling0322 authored Oct 18, 2024
1 parent 618f7d7 commit ff9b942
Showing 25 changed files with 318 additions and 448 deletions.
4 changes: 2 additions & 2 deletions src/libllm/benchmark_main.cc
@@ -153,15 +153,15 @@ void benchmarkLlama(std::shared_ptr<llama::LlamaModel> model, int ctxLength, DTy

   float tokenPerSec = benchmarkPromptForward(&r, model, 32000, ctxLength);
   printf(
-      "llama2_7B %-8s %-8s prompt@len:%-5d %-7.1f\n",
+      "llama2_7B %-8s %-8s prefill@len:%-5d %-7.1f\n",
       model->getCtx().getDevice().getName().c_str(),
       weightType.toString().c_str(),
       ctxLength,
       tokenPerSec);
 
   tokenPerSec = benchmarkTokenGeneration(&r, model, 32000, ctxLength);
   printf(
-      "llama2_7B %-8s %-8s tokengen@ctx:%-5d %-7.1f\n",
+      "llama2_7B %-8s %-8s decode@ctx:%-5d %-7.1f\n",
       model->getCtx().getDevice().getName().c_str(),
       weightType.toString().c_str(),
       ctxLength,
67 changes: 61 additions & 6 deletions src/libllm/cpu/accessor.h
@@ -123,23 +123,50 @@ class TensorList {
     return _shape[d].shape;
   }
   int getLength() const {
-    return static_cast<int>(_pointerList.size());
+    if (_basePtr) {
+      return _size;
+    } else {
+      return static_cast<int>(_pointerList.size());
+    }
   }
-  lut::Span<T *const> getDataPtrList() const {
+  lut::Span<T *const> getDataPtrList() {
+    if (_basePtr && _pointerList.empty()) {
+      for (int i = 0; i < _size; ++i) {
+        _pointerList.push_back(_basePtr + i * _stride);
+      }
+    }
     return lut::makeConstSpan(_pointerList);
   }
 
   TensorAccessor<T, DIM> getTensor(int index) const {
-    return TensorAccessor<T, DIM>(_shape, _pointerList[index]);
+    if (_basePtr) {
+      return TensorAccessor<T, DIM>(_shape, _basePtr + index * _stride);
+    } else {
+      return TensorAccessor<T, DIM>(_shape, _pointerList[index]);
+    }
   }
 
  private:
   const TensorShape::Elem *_shape;
   std::vector<T *> _pointerList;
 
+  int64_t _stride;
+  int _size;
+  T *_basePtr;
+
   TensorList(const TensorShape::Elem *shape, std::vector<T *> &&pointerList)
       : _shape(shape),
-        _pointerList(std::move(pointerList)) {
+        _pointerList(std::move(pointerList)),
+        _stride(0),
+        _size(0),
+        _basePtr(nullptr) {
   }
+
+  TensorList(const TensorShape::Elem *shape, T *data, int size, int stride)
+      : _shape(shape),
+        _basePtr(data),
+        _size(size),
+        _stride(stride) {
+  }
 };

@@ -175,7 +202,21 @@ TensorList<T, DIM> TensorList<T, DIM>::fromTensor(const Tensor &src) {
   getDataPointerList<T, DIM>(src.getData<T>(), shape, pointerList);
 
   const TensorShape::Elem *tensorShape = shape.data() + (shape.size() - DIM);
-  return TensorList<T, DIM>(tensorShape, std::move(pointerList));
+  if (src.isContiguous()) {
+    int numTensor = 1;
+    for (int i = 0; i < src.getDim() - DIM; ++i) {
+      numTensor *= src.getShape(i);
+    }
+
+    int stride = 1;
+    for (int i = 0; i < DIM; ++i) {
+      stride *= tensorShape[i].shape;
+    }
+
+    return TensorList<T, DIM>(tensorShape, src.getData<T>(), numTensor, stride);
+  } else {
+    return TensorList<T, DIM>(tensorShape, std::move(pointerList));
+  }
 }
 
 template<typename T, int DIM>
@@ -185,7 +226,21 @@ TensorList<T, DIM> TensorList<T, DIM>::fromTensor(Tensor &src) {
   getDataPointerList<T, DIM>(src.getData<T>(), shape, pointerList);
 
   const TensorShape::Elem *tensorShape = shape.data() + (shape.size() - DIM);
-  return TensorList<T, DIM>(tensorShape, std::move(pointerList));
+  if (src.isContiguous()) {
+    int numTensor = 1;
+    for (int i = 0; i < src.getDim() - DIM; ++i) {
+      numTensor *= src.getShape(i);
+    }
+
+    int stride = 1;
+    for (int i = 0; i < DIM; ++i) {
+      stride *= tensorShape[i].shape;
+    }
+
+    return TensorList<T, DIM>(tensorShape, src.getData<T>(), numTensor, stride);
+  } else {
+    return TensorList<T, DIM>(tensorShape, std::move(pointerList));
+  }
 }
 
 } // namespace cpu
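
Note on the accessor.h change above: for a contiguous source tensor, TensorList now stores just a base pointer, a sub-tensor count, and a stride instead of materializing one pointer per sub-tensor (the pointer list is only built lazily if getDataPtrList() is called). A minimal standalone sketch of the index arithmetic, using an assumed shape of [2, 3, 4] viewed as a list of 1-D tensors; the shape and variable names here are illustrative, not from the repository:

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Assumed example: a contiguous tensor of shape [2, 3, 4], DIM = 1.
  std::vector<int> shape = {2, 3, 4};
  const int DIM = 1;

  // numTensor: product of the outer (batch) dimensions -> 2 * 3 = 6.
  int numTensor = 1;
  for (size_t i = 0; i + DIM < shape.size(); ++i) numTensor *= shape[i];

  // stride: elements per sub-tensor, product of the trailing DIM dims -> 4.
  int stride = 1;
  for (size_t i = shape.size() - DIM; i < shape.size(); ++i) stride *= shape[i];

  // Sub-tensor i starts at basePtr + i * stride, so no pointer vector
  // has to be allocated up front.
  printf("numTensor=%d stride=%d\n", numTensor, stride);
  return 0;
}
```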
24 changes: 11 additions & 13 deletions src/libllm/cpu/binary_op.cc
@@ -45,20 +45,18 @@ Tensor binaryOpKernel(const Tensor &A, const Tensor &B, BinaryOp op) {
   TensorList<T, 1> vC = TensorList<T, 1>::fromTensor(C);
   CHECK(vA.getLength() == vB.getLength() && vC.getLength() == vB.getLength());
 
-  MP::parallelFor({vA.getLength()}, [&vA, &vB, &vC, op](MP::Partition partition) {
-    for (int j : partition.getRange()) {
-      TensorAccessor<const T, 1> a = vA.getTensor(j);
-      TensorAccessor<const T, 1> b = vB.getTensor(j);
-      TensorAccessor<T, 1> c = vC.getTensor(j);
+  MP::parallelFor(vA.getLength(), [&vA, &vB, &vC, op](MP::Context ctx) {
+    TensorAccessor<const T, 1> a = vA.getTensor(ctx.getBlockIdx());
+    TensorAccessor<const T, 1> b = vB.getTensor(ctx.getBlockIdx());
+    TensorAccessor<T, 1> c = vC.getTensor(ctx.getBlockIdx());
 
-      for (int i = 0; i < a.getShape(0); ++i) {
-        if (op == BinaryOp::ADD) {
-          c[i] = a[i] + b[i];
-        } else if (op == BinaryOp::MUL) {
-          c[i] = a[i] * b[i];
-        } else {
-          NOT_IMPL();
-        }
+    for (int i = 0; i < a.getShape(0); ++i) {
+      if (op == BinaryOp::ADD) {
+        c[i] = a[i] + b[i];
+      } else if (op == BinaryOp::MUL) {
+        c[i] = a[i] * b[i];
+      } else {
+        NOT_IMPL();
       }
     }
   });
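
Note: the definitions of MP::parallelFor and MP::Context are not part of the hunks shown on this page. The sketch below is only a guess at the interface these call sites imply, assuming an OpenMP backing (the Mode::OMP branches in block.h and gemm.h below suggest one). The constructor argument order (blockIdx, numBlocks, attachedThreadIdx) is inferred from the MP::Context(i, numBlock, 0) fallback calls; all internals are assumptions, not the repository's actual implementation.

```cpp
#include <functional>
#include <omp.h>

namespace MP {

// Assumed per-invocation context: one block index per closure call.
class Context {
 public:
  Context(int blockIdx, int numBlocks, int attachedThreadIdx)
      : _blockIdx(blockIdx),
        _numBlocks(numBlocks),
        _attachedThreadIdx(attachedThreadIdx) {}

  int getBlockIdx() const { return _blockIdx; }
  int getNumBlocks() const { return _numBlocks; }
  int getAttachedThreadIdx() const { return _attachedThreadIdx; }

 private:
  int _blockIdx;
  int _numBlocks;
  int _attachedThreadIdx;
};

inline int getMaxThreads() {
  return omp_get_max_threads();
}

// Invokes the closure once per block index; OpenMP decides which worker
// thread runs which block.
inline void parallelFor(int numBlocks, const std::function<void(Context)> &closure) {
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < numBlocks; ++i) {
    closure(Context(i, numBlocks, omp_get_thread_num()));
  }
}

}  // namespace MP
```

Under this interface a kernel such as binaryOpKernel above reads its work item from ctx.getBlockIdx() instead of looping over a partition's range, which removes one nesting level at every call site.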
10 changes: 4 additions & 6 deletions src/libllm/cpu/copy.cc
@@ -34,13 +34,11 @@ void copyKernel(const Tensor &src, Tensor &dest) {
   TensorList<T, 1> vC = TensorList<T, 1>::fromTensor(dest);
   CHECK(vA.getLength() == vC.getLength());
 
-  MP::parallelFor({vA.getLength()}, [&vA, &vC](MP::Partition partition) {
-    for (int j : partition.getRange()) {
-      TensorAccessor<const T, 1> a = vA.getTensor(j);
-      TensorAccessor<T, 1> c = vC.getTensor(j);
+  MP::parallelFor(vA.getLength(), [&vA, &vC](MP::Context ctx) {
+    TensorAccessor<const T, 1> a = vA.getTensor(ctx.getBlockIdx());
+    TensorAccessor<T, 1> c = vC.getTensor(ctx.getBlockIdx());
 
-      copyVector(c, a);
-    }
+    copyVector(c, a);
   });
 }

10 changes: 4 additions & 6 deletions src/libllm/cpu/fill.cc
@@ -32,13 +32,11 @@ namespace cpu {
 template<typename T>
 void fillKernel(Tensor A, float value) {
   TensorList<T, 1> vC = TensorList<T, 1>::fromTensor(A);
-  MP::parallelFor({vC.getLength()}, [&vC, value](MP::Partition partition) {
-    for (int j : partition.getRange()) {
-      TensorAccessor<T, 1> c = vC.getTensor(j);
+  MP::parallelFor(vC.getLength(), [&vC, value](MP::Context ctx) {
+    TensorAccessor<T, 1> c = vC.getTensor(ctx.getBlockIdx());
 
-      for (int i = 0; i < c.getShape(0); ++i) {
-        c[i] = value;
-      }
+    for (int i = 0; i < c.getShape(0); ++i) {
+      c[i] = value;
     }
   });
 }
18 changes: 8 additions & 10 deletions src/libllm/cpu/gelu.cc
@@ -40,17 +40,15 @@ Tensor geluKernel(const Tensor &A) {
   TensorList<T, 1> vC = TensorList<T, 1>::fromTensor(C);
   CHECK(vA.getLength() == vC.getLength());
 
-  MP::parallelFor({vA.getLength()}, [&vA, &vC](MP::Partition partition) {
-    for (int j : partition.getRange()) {
-      TensorAccessor<const T, 1> a = vA.getTensor(j);
-      TensorAccessor<T, 1> c = vC.getTensor(j);
+  MP::parallelFor(vA.getLength(), [&vA, &vC](MP::Context ctx) {
+    TensorAccessor<const T, 1> a = vA.getTensor(ctx.getBlockIdx());
+    TensorAccessor<T, 1> c = vC.getTensor(ctx.getBlockIdx());
 
-      int n = c.getShape(0);
-      for (int i = 0; i < n; ++i) {
-        float x = a[i];
-        x = x * 0.5f * (1.0f + erf(x / Sqrt2));
-        c[i] = T(x);
-      }
+    int n = c.getShape(0);
+    for (int i = 0; i < n; ++i) {
+      float x = a[i];
+      x = x * 0.5f * (1.0f + erf(x / Sqrt2));
+      c[i] = T(x);
     }
   });

17 changes: 9 additions & 8 deletions src/libllm/cpu/kernel/block.h
@@ -63,18 +63,19 @@ PackedBlock<T> Pack(Block<T> src, Block<T> buf, int pack_size) {
   PackedBlock<T> tgt{buf.data, pack_size, kc, numBlock};
   CHECK(pack_size * numBlock * kc <= buf.numCols * buf.numRows);
 
-  auto closure = [src, tgt, pack_size](MP::Partition partition) {
-    for (int b : partition.getRange()) {
-      Block<T> srcBlock = src.sliceCol(b * pack_size, pack_size);
-      Block<T> tgtBlock = tgt.block(b);
-      srcBlock.copyTo(tgtBlock);
-    }
+  auto closure = [src, tgt, pack_size](MP::Context ctx) {
+    int b = ctx.getBlockIdx();
+    Block<T> srcBlock = src.sliceCol(b * pack_size, pack_size);
+    Block<T> tgtBlock = tgt.block(b);
+    srcBlock.copyTo(tgtBlock);
   };
 
   if (MODE == Mode::OMP) {
-    MP::parallelFor({numBlock}, closure);
+    MP::parallelFor(numBlock, closure);
   } else {
-    closure(MP::Partition(lut::Range(numBlock)));
+    for (int i = 0; i < numBlock; ++i) {
+      closure(MP::Context(i, numBlock, 0));
+    }
   }
 
   int nc = src.numCols % pack_size;
19 changes: 9 additions & 10 deletions src/libllm/cpu/kernel/cvt.h
@@ -36,16 +36,15 @@ void cvt(int64_t n, const ElementA *x, int64_t offsetX, ElementC *y, int64_t off
   int nr = (n - 1) % CvtMinElemPerThread + 1;
   int numThreads = std::min(nb, MP::getMaxThreads());
 
-  MP::parallelFor({nb}, numThreads, [nb, nr, x, offsetX, y, offsetY](MP::Partition partition) {
-    for (int i : partition.getRange()) {
-      int ne = (i == nb - 1) ? nr : CvtMinElemPerThread;
-      cvtKernel<ElementA, ElementC, TYPE>(
-          ne,
-          x,
-          offsetX + i * CvtMinElemPerThread,
-          y,
-          offsetY + i * CvtMinElemPerThread);
-    }
+  MP::parallelFor(nb, [nb, nr, x, offsetX, y, offsetY](MP::Context ctx) {
+    int i = ctx.getBlockIdx();
+    int ne = (i == nb - 1) ? nr : CvtMinElemPerThread;
+    cvtKernel<ElementA, ElementC, TYPE>(
+        ne,
+        x,
+        offsetX + i * CvtMinElemPerThread,
+        y,
+        offsetY + i * CvtMinElemPerThread);
     });
   } else {
     cvtKernel<ElementA, ElementC, TYPE>(n, x, offsetX, y, offsetY);
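
A quick worked example of the chunking arithmetic in the cvt.h hunk (only the two formulas come from the diff; the values of CvtMinElemPerThread and n are assumptions for illustration): nb is a ceiling division, and nr is the size of the final, possibly partial, block.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int CvtMinElemPerThread = 256;  // assumed value, for illustration only
  int64_t n = 1000;

  int nb = static_cast<int>((n - 1) / CvtMinElemPerThread + 1);  // ceil(1000/256) = 4
  int nr = static_cast<int>((n - 1) % CvtMinElemPerThread + 1);  // last block: 232

  // Block i converts ne elements starting at element i * CvtMinElemPerThread,
  // exactly as the parallelFor closure above does.
  for (int i = 0; i < nb; ++i) {
    int ne = (i == nb - 1) ? nr : CvtMinElemPerThread;
    printf("block %d: %d elements at offset %d\n", i, ne, i * CvtMinElemPerThread);
  }
  return 0;  // 256 + 256 + 256 + 232 = 1000
}
```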
29 changes: 15 additions & 14 deletions src/libllm/cpu/kernel/gemm.h
@@ -131,25 +131,26 @@ class Gemm {
   int lastNr = C.numCols % NR;
   int lastMr = C.numRows % MR;
 
-  auto closure = [this, &A, &B, &C, mp, np, lastNr, lastMr](MP::Partition partition) {
-    for (int i : partition.getRange()) {
-      for (int j = 0; j < mp; ++j) {
-        int nr = (i != np - 1 || lastNr == 0) ? NR : lastNr;
-        int mr = (j != mp - 1 || lastMr == 0) ? MR : lastMr;
-
-        Block<T> Aj = A.block(j);
-        Block<T> Bi = B.block(i);
-        Block<T> Cji = C.slice(j * MR, i * NR, mr, nr);
-
-        microKernel(Aj, Bi, Cji);
-      }
+  auto closure = [this, &A, &B, &C, mp, np, lastNr, lastMr](MP::Context ctx) {
+    int i = ctx.getBlockIdx();
+    for (int j = 0; j < mp; ++j) {
+      int nr = (i != np - 1 || lastNr == 0) ? NR : lastNr;
+      int mr = (j != mp - 1 || lastMr == 0) ? MR : lastMr;
+
+      Block<T> Aj = A.block(j);
+      Block<T> Bi = B.block(i);
+      Block<T> Cji = C.slice(j * MR, i * NR, mr, nr);
+
+      microKernel(Aj, Bi, Cji);
     }
   };
 
   if (MODE == Mode::OMP) {
-    MP::parallelFor({np}, closure);
+    MP::parallelFor(np, closure);
   } else {
-    closure(MP::Partition(lut::Range(np)));
+    for (int i = 0; i < np; ++i) {
+      closure(MP::Context(i, np, 0));
+    }
   }
 }

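
In the Gemm hunk above, each parallelFor block owns one column panel i of C, so concurrent blocks write disjoint column strips of the output and need no synchronization; the mp row panels are walked serially inside a block. A sketch of the edge-tile arithmetic: the nr/mr expressions mirror the diff, while the tile sizes (NR = 16, MR = 6), the 100x50 output, and the ceiling-division forms of np/mp (defined above the visible hunk) are assumptions.

```cpp
#include <cstdio>

int main() {
  const int NR = 16, MR = 6;        // assumed register-tile sizes
  int numRows = 100, numCols = 50;  // assumed dimensions of the output C

  int np = (numCols + NR - 1) / NR;  // column panels: 4
  int mp = (numRows + MR - 1) / MR;  // row panels: 17
  int lastNr = numCols % NR;         // last column panel is 2 wide
  int lastMr = numRows % MR;         // last row panel is 4 tall

  for (int i = 0; i < np; ++i) {    // one parallelFor block per column panel
    for (int j = 0; j < mp; ++j) {  // row panels handled serially inside a block
      int nr = (i != np - 1 || lastNr == 0) ? NR : lastNr;
      int mr = (j != mp - 1 || lastMr == 0) ? MR : lastMr;
      if (i == np - 1 && j == mp - 1) printf("corner tile: %d x %d\n", mr, nr);
    }
  }
  return 0;
}
```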
34 changes: 14 additions & 20 deletions src/libllm/cpu/kernel/gemv.h
@@ -45,14 +45,12 @@ void gemvContigousN(const GemvArgs<ElementA, ElementB, ElementC> &args) {
           m * args.lda);
     }
   } else if (MODE == Mode::OMP) {
-    MP::parallelFor({args.M}, [args](MP::Partition partition) {
-      for (int m : partition.getRange()) {
-        args.y[m] += dotKernel<ElementC, ElementB, ElementA, TYPE>(
-            args.N,
-            args.x,
-            args.A,
-            m * args.lda);
-      }
+    MP::parallelFor(args.M, [args](MP::Context ctx) {
+      args.y[ctx.getBlockIdx()] += dotKernel<ElementC, ElementB, ElementA, TYPE>(
+          args.N,
+          args.x,
+          args.A,
+          ctx.getBlockIdx() * args.lda);
     });
   } else {
     NOT_IMPL();
@@ -61,10 +59,7 @@

 template<typename ElementA, typename ElementB, typename ElementC, CpuMathBackend TYPE, Mode MODE>
 void gemvContigousT(const GemvArgs<ElementA, ElementB, ElementC> &args) {
-  int mp = (args.M + GEMVMinRowsPerThread - 1) / GEMVMinRowsPerThread;
-  int numThreads = std::min(mp, MP::getMaxThreads());
-
-  if (MODE == Mode::SingleThread || numThreads <= 1) {
+  if (MODE == Mode::SingleThread) {
     lut::c_ptr<float> y = alignedAlloc<float>(args.N);
     memset(y.get(), 0, args.N * sizeof(float));

@@ -77,20 +72,19 @@
   } else if (MODE == Mode::OMP) {
     // initialize numThreads y buffers.
     // TODO: sfill
-    lut::c_ptr<float> ys = alignedAlloc<float>(args.N * numThreads);
-    memset(ys.get(), 0, args.N * numThreads * sizeof(float));
+    lut::c_ptr<float> ys = alignedAlloc<float>(args.N * MP::getMaxThreads());
+    memset(ys.get(), 0, args.N * MP::getMaxThreads() * sizeof(float));
 
     // compute axpy.
-    MP::parallelFor({args.M}, numThreads, [args, &ys](MP::Partition partition) {
-      for (int m : partition.getRange()) {
-        float *py = ys.get() + partition.getPartitionIdx() * args.N;
-        axpyKernel<ElementB, ElementA, float, TYPE>(args.N, args.x[m], args.A, m * args.lda, py);
-      }
+    MP::parallelFor(args.M, [args, &ys](MP::Context ctx) {
+      int m = ctx.getBlockIdx();
+      float *py = ys.get() + ctx.getAttachedThreadIdx() * args.N;
+      axpyKernel<ElementB, ElementA, float, TYPE>(args.N, args.x[m], args.A, m * args.lda, py);
     });
 
     // accumulate ys.
     // TODO: vAdd.
-    for (int p = 0; p < numThreads; ++p) {
+    for (int p = 0; p < MP::getMaxThreads(); ++p) {
       float *py = ys.get() + p * args.N;
       for (int i = 0; i < args.N; ++i) {
         args.y[i] += py[i];
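
The gemvContigousT hunk is the one place where ctx.getAttachedThreadIdx() matters: concurrent axpy updates to a shared y would race, so every thread accumulates into its own slice of ys and the slices are summed serially afterwards. A self-contained sketch of that reduction pattern under an assumed OpenMP backing; the dimensions and values are illustrative, not from the repository:

```cpp
#include <cstdio>
#include <vector>
#include <omp.h>

int main() {
  // Assumed example: y = A^T * x with an 8x4 row-major A, x filled with 2
  // and A filled with 1, so every y[i] should come out as 16.
  const int M = 8, N = 4;
  std::vector<float> A(M * N, 1.0f), x(M, 2.0f), y(N, 0.0f);

  int numThreads = omp_get_max_threads();
  std::vector<float> ys(static_cast<size_t>(N) * numThreads, 0.0f);

#pragma omp parallel for
  for (int m = 0; m < M; ++m) {
    // axpy into this thread's private buffer: py += x[m] * A[m, :].
    float *py = ys.data() + static_cast<size_t>(omp_get_thread_num()) * N;
    for (int i = 0; i < N; ++i) py[i] += x[m] * A[m * N + i];
  }

  // Accumulate the per-thread buffers into y (serial, as in the diff).
  for (int p = 0; p < numThreads; ++p)
    for (int i = 0; i < N; ++i) y[i] += ys[static_cast<size_t>(p) * N + i];

  printf("y[0] = %g\n", y[0]);  // 16
  return 0;
}
```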
