From ff9b9424340a18680837392a463521b9cc263f9c Mon Sep 17 00:00:00 2001 From: Xiaoyang Chen Date: Sat, 19 Oct 2024 01:11:59 +0800 Subject: [PATCH] Optimize MP::parallelFor (#102) --- src/libllm/benchmark_main.cc | 4 +- src/libllm/cpu/accessor.h | 67 +++++++++++++++++-- src/libllm/cpu/binary_op.cc | 24 ++++--- src/libllm/cpu/copy.cc | 10 ++- src/libllm/cpu/fill.cc | 10 ++- src/libllm/cpu/gelu.cc | 18 +++-- src/libllm/cpu/kernel/block.h | 17 ++--- src/libllm/cpu/kernel/cvt.h | 19 +++--- src/libllm/cpu/kernel/gemm.h | 29 ++++---- src/libllm/cpu/kernel/gemv.h | 34 ++++------ src/libllm/cpu/matmul.cc | 31 +++++---- src/libllm/cpu/normalizations.cc | 84 ++++++++++++------------ src/libllm/cpu/rand.cc | 26 +++++--- src/libllm/cpu/reduce.cc | 30 ++++----- src/libllm/cpu/repetition_penalty.cc | 44 ++++++------- src/libllm/cpu/softmax.cc | 32 +++++---- src/libllm/cpu/swiglu.cc | 21 +++--- src/libllm/cpu/transform.cc | 12 ++-- src/libllm/cpu/unfold.cc | 28 ++++---- src/libllm/mp.cc | 53 +-------------- src/libllm/mp.h | 44 ++++++------- src/libllm/mp_openmp.cc | 15 ++--- src/libllm/mp_thread_pool.cc | 14 ++-- src/lutil/range.h | 98 ---------------------------- src/lutil/thread_pool.h | 2 - 25 files changed, 318 insertions(+), 448 deletions(-) delete mode 100644 src/lutil/range.h diff --git a/src/libllm/benchmark_main.cc b/src/libllm/benchmark_main.cc index 12855702..6e631dd7 100644 --- a/src/libllm/benchmark_main.cc +++ b/src/libllm/benchmark_main.cc @@ -153,7 +153,7 @@ void benchmarkLlama(std::shared_ptr model, int ctxLength, DTy float tokenPerSec = benchmarkPromptForward(&r, model, 32000, ctxLength); printf( - "llama2_7B %-8s %-8s prompt@len:%-5d %-7.1f\n", + "llama2_7B %-8s %-8s prefill@len:%-5d %-7.1f\n", model->getCtx().getDevice().getName().c_str(), weightType.toString().c_str(), ctxLength, @@ -161,7 +161,7 @@ void benchmarkLlama(std::shared_ptr model, int ctxLength, DTy tokenPerSec = benchmarkTokenGeneration(&r, model, 32000, ctxLength); printf( - "llama2_7B %-8s %-8s tokengen@ctx:%-5d %-7.1f\n", + "llama2_7B %-8s %-8s decode@ctx:%-5d %-7.1f\n", model->getCtx().getDevice().getName().c_str(), weightType.toString().c_str(), ctxLength, diff --git a/src/libllm/cpu/accessor.h b/src/libllm/cpu/accessor.h index 7eb3ab16..95bd1798 100644 --- a/src/libllm/cpu/accessor.h +++ b/src/libllm/cpu/accessor.h @@ -123,23 +123,50 @@ class TensorList { return _shape[d].shape; } int getLength() const { - return static_cast(_pointerList.size()); + if (_basePtr) { + return _size; + } else { + return static_cast(_pointerList.size()); + } } - lut::Span getDataPtrList() const { + lut::Span getDataPtrList() { + if (_basePtr && _pointerList.empty()) { + for (int i = 0; i < _size; ++i) { + _pointerList.push_back(_basePtr + i * _stride); + } + } return lut::makeConstSpan(_pointerList); } TensorAccessor getTensor(int index) const { - return TensorAccessor(_shape, _pointerList[index]); + if (_basePtr) { + return TensorAccessor(_shape, _basePtr + index * _stride); + } else { + return TensorAccessor(_shape, _pointerList[index]); + } } private: const TensorShape::Elem *_shape; std::vector _pointerList; + int64_t _stride; + int _size; + T *_basePtr; + TensorList(const TensorShape::Elem *shape, std::vector &&pointerList) : _shape(shape), - _pointerList(std::move(pointerList)) { + _pointerList(std::move(pointerList)), + _stride(0), + _size(0), + _basePtr(nullptr) { + } + + TensorList(const TensorShape::Elem *shape, T *data, int size, int stride) + : _shape(shape), + _basePtr(data), + _size(size), + _stride(stride) { } }; @@ -175,7 +202,21 @@ TensorList TensorList::fromTensor(const Tensor &src) { getDataPointerList(src.getData(), shape, pointerList); const TensorShape::Elem *tensorShape = shape.data() + (shape.size() - DIM); - return TensorList(tensorShape, std::move(pointerList)); + if (src.isContiguous()) { + int numTensor = 1; + for (int i = 0; i < src.getDim() - DIM; ++i) { + numTensor *= src.getShape(i); + } + + int stride = 1; + for (int i = 0; i < DIM; ++i) { + stride *= tensorShape[i].shape; + } + + return TensorList(tensorShape, src.getData(), numTensor, stride); + } else { + return TensorList(tensorShape, std::move(pointerList)); + } } template @@ -185,7 +226,21 @@ TensorList TensorList::fromTensor(Tensor &src) { getDataPointerList(src.getData(), shape, pointerList); const TensorShape::Elem *tensorShape = shape.data() + (shape.size() - DIM); - return TensorList(tensorShape, std::move(pointerList)); + if (src.isContiguous()) { + int numTensor = 1; + for (int i = 0; i < src.getDim() - DIM; ++i) { + numTensor *= src.getShape(i); + } + + int stride = 1; + for (int i = 0; i < DIM; ++i) { + stride *= tensorShape[i].shape; + } + + return TensorList(tensorShape, src.getData(), numTensor, stride); + } else { + return TensorList(tensorShape, std::move(pointerList)); + } } } // namespace cpu diff --git a/src/libllm/cpu/binary_op.cc b/src/libllm/cpu/binary_op.cc index 84223b7b..945450ba 100644 --- a/src/libllm/cpu/binary_op.cc +++ b/src/libllm/cpu/binary_op.cc @@ -45,20 +45,18 @@ Tensor binaryOpKernel(const Tensor &A, const Tensor &B, BinaryOp op) { TensorList vC = TensorList::fromTensor(C); CHECK(vA.getLength() == vB.getLength() && vC.getLength() == vB.getLength()); - MP::parallelFor({vA.getLength()}, [&vA, &vB, &vC, op](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor a = vA.getTensor(j); - TensorAccessor b = vB.getTensor(j); - TensorAccessor c = vC.getTensor(j); + MP::parallelFor(vA.getLength(), [&vA, &vB, &vC, op](MP::Context ctx) { + TensorAccessor a = vA.getTensor(ctx.getBlockIdx()); + TensorAccessor b = vB.getTensor(ctx.getBlockIdx()); + TensorAccessor c = vC.getTensor(ctx.getBlockIdx()); - for (int i = 0; i < a.getShape(0); ++i) { - if (op == BinaryOp::ADD) { - c[i] = a[i] + b[i]; - } else if (op == BinaryOp::MUL) { - c[i] = a[i] * b[i]; - } else { - NOT_IMPL(); - } + for (int i = 0; i < a.getShape(0); ++i) { + if (op == BinaryOp::ADD) { + c[i] = a[i] + b[i]; + } else if (op == BinaryOp::MUL) { + c[i] = a[i] * b[i]; + } else { + NOT_IMPL(); } } }); diff --git a/src/libllm/cpu/copy.cc b/src/libllm/cpu/copy.cc index d48330bf..da02c19b 100644 --- a/src/libllm/cpu/copy.cc +++ b/src/libllm/cpu/copy.cc @@ -34,13 +34,11 @@ void copyKernel(const Tensor &src, Tensor &dest) { TensorList vC = TensorList::fromTensor(dest); CHECK(vA.getLength() == vC.getLength()); - MP::parallelFor({vA.getLength()}, [&vA, &vC](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor a = vA.getTensor(j); - TensorAccessor c = vC.getTensor(j); + MP::parallelFor(vA.getLength(), [&vA, &vC](MP::Context ctx) { + TensorAccessor a = vA.getTensor(ctx.getBlockIdx()); + TensorAccessor c = vC.getTensor(ctx.getBlockIdx()); - copyVector(c, a); - } + copyVector(c, a); }); } diff --git a/src/libllm/cpu/fill.cc b/src/libllm/cpu/fill.cc index 1af87a7b..f5fbbfdc 100644 --- a/src/libllm/cpu/fill.cc +++ b/src/libllm/cpu/fill.cc @@ -32,13 +32,11 @@ namespace cpu { template void fillKernel(Tensor A, float value) { TensorList vC = TensorList::fromTensor(A); - MP::parallelFor({vC.getLength()}, [&vC, value](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor c = vC.getTensor(j); + MP::parallelFor(vC.getLength(), [&vC, value](MP::Context ctx) { + TensorAccessor c = vC.getTensor(ctx.getBlockIdx()); - for (int i = 0; i < c.getShape(0); ++i) { - c[i] = value; - } + for (int i = 0; i < c.getShape(0); ++i) { + c[i] = value; } }); } diff --git a/src/libllm/cpu/gelu.cc b/src/libllm/cpu/gelu.cc index 6db49e5f..233d1e7d 100644 --- a/src/libllm/cpu/gelu.cc +++ b/src/libllm/cpu/gelu.cc @@ -40,17 +40,15 @@ Tensor geluKernel(const Tensor &A) { TensorList vC = TensorList::fromTensor(C); CHECK(vA.getLength() == vC.getLength()); - MP::parallelFor({vA.getLength()}, [&vA, &vC](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor a = vA.getTensor(j); - TensorAccessor c = vC.getTensor(j); + MP::parallelFor(vA.getLength(), [&vA, &vC](MP::Context ctx) { + TensorAccessor a = vA.getTensor(ctx.getBlockIdx()); + TensorAccessor c = vC.getTensor(ctx.getBlockIdx()); - int n = c.getShape(0); - for (int i = 0; i < n; ++i) { - float x = a[i]; - x = x * 0.5f * (1.0f + erf(x / Sqrt2)); - c[i] = T(x); - } + int n = c.getShape(0); + for (int i = 0; i < n; ++i) { + float x = a[i]; + x = x * 0.5f * (1.0f + erf(x / Sqrt2)); + c[i] = T(x); } }); diff --git a/src/libllm/cpu/kernel/block.h b/src/libllm/cpu/kernel/block.h index 2565561d..4ff027ae 100644 --- a/src/libllm/cpu/kernel/block.h +++ b/src/libllm/cpu/kernel/block.h @@ -63,18 +63,19 @@ PackedBlock Pack(Block src, Block buf, int pack_size) { PackedBlock tgt{buf.data, pack_size, kc, numBlock}; CHECK(pack_size * numBlock * kc <= buf.numCols * buf.numRows); - auto closure = [src, tgt, pack_size](MP::Partition partition) { - for (int b : partition.getRange()) { - Block srcBlock = src.sliceCol(b * pack_size, pack_size); - Block tgtBlock = tgt.block(b); - srcBlock.copyTo(tgtBlock); - } + auto closure = [src, tgt, pack_size](MP::Context ctx) { + int b = ctx.getBlockIdx(); + Block srcBlock = src.sliceCol(b * pack_size, pack_size); + Block tgtBlock = tgt.block(b); + srcBlock.copyTo(tgtBlock); }; if (MODE == Mode::OMP) { - MP::parallelFor({numBlock}, closure); + MP::parallelFor(numBlock, closure); } else { - closure(MP::Partition(lut::Range(numBlock))); + for (int i = 0; i < numBlock; ++i) { + closure(MP::Context(i, numBlock, 0)); + } } int nc = src.numCols % pack_size; diff --git a/src/libllm/cpu/kernel/cvt.h b/src/libllm/cpu/kernel/cvt.h index 45d2f2ad..dec4789a 100644 --- a/src/libllm/cpu/kernel/cvt.h +++ b/src/libllm/cpu/kernel/cvt.h @@ -36,16 +36,15 @@ void cvt(int64_t n, const ElementA *x, int64_t offsetX, ElementC *y, int64_t off int nr = (n - 1) % CvtMinElemPerThread + 1; int numThreads = std::min(nb, MP::getMaxThreads()); - MP::parallelFor({nb}, numThreads, [nb, nr, x, offsetX, y, offsetY](MP::Partition partition) { - for (int i : partition.getRange()) { - int ne = (i == nb - 1) ? nr : CvtMinElemPerThread; - cvtKernel( - ne, - x, - offsetX + i * CvtMinElemPerThread, - y, - offsetY + i * CvtMinElemPerThread); - } + MP::parallelFor(nb, [nb, nr, x, offsetX, y, offsetY](MP::Context ctx) { + int i = ctx.getBlockIdx(); + int ne = (i == nb - 1) ? nr : CvtMinElemPerThread; + cvtKernel( + ne, + x, + offsetX + i * CvtMinElemPerThread, + y, + offsetY + i * CvtMinElemPerThread); }); } else { cvtKernel(n, x, offsetX, y, offsetY); diff --git a/src/libllm/cpu/kernel/gemm.h b/src/libllm/cpu/kernel/gemm.h index 14190f43..f23b9c0f 100644 --- a/src/libllm/cpu/kernel/gemm.h +++ b/src/libllm/cpu/kernel/gemm.h @@ -131,25 +131,26 @@ class Gemm { int lastNr = C.numCols % NR; int lastMr = C.numRows % MR; - auto closure = [this, &A, &B, &C, mp, np, lastNr, lastMr](MP::Partition partition) { - for (int i : partition.getRange()) { - for (int j = 0; j < mp; ++j) { - int nr = (i != np - 1 || lastNr == 0) ? NR : lastNr; - int mr = (j != mp - 1 || lastMr == 0) ? MR : lastMr; - - Block Aj = A.block(j); - Block Bi = B.block(i); - Block Cji = C.slice(j * MR, i * NR, mr, nr); - - microKernel(Aj, Bi, Cji); - } + auto closure = [this, &A, &B, &C, mp, np, lastNr, lastMr](MP::Context ctx) { + int i = ctx.getBlockIdx(); + for (int j = 0; j < mp; ++j) { + int nr = (i != np - 1 || lastNr == 0) ? NR : lastNr; + int mr = (j != mp - 1 || lastMr == 0) ? MR : lastMr; + + Block Aj = A.block(j); + Block Bi = B.block(i); + Block Cji = C.slice(j * MR, i * NR, mr, nr); + + microKernel(Aj, Bi, Cji); } }; if (MODE == Mode::OMP) { - MP::parallelFor({np}, closure); + MP::parallelFor(np, closure); } else { - closure(MP::Partition(lut::Range(np))); + for (int i = 0; i < np; ++i) { + closure(MP::Context(i, np, 0)); + } } } diff --git a/src/libllm/cpu/kernel/gemv.h b/src/libllm/cpu/kernel/gemv.h index 6e0c15d5..e919d893 100644 --- a/src/libllm/cpu/kernel/gemv.h +++ b/src/libllm/cpu/kernel/gemv.h @@ -45,14 +45,12 @@ void gemvContigousN(const GemvArgs &args) { m * args.lda); } } else if (MODE == Mode::OMP) { - MP::parallelFor({args.M}, [args](MP::Partition partition) { - for (int m : partition.getRange()) { - args.y[m] += dotKernel( - args.N, - args.x, - args.A, - m * args.lda); - } + MP::parallelFor(args.M, [args](MP::Context ctx) { + args.y[ctx.getBlockIdx()] += dotKernel( + args.N, + args.x, + args.A, + ctx.getBlockIdx() * args.lda); }); } else { NOT_IMPL(); @@ -61,10 +59,7 @@ void gemvContigousN(const GemvArgs &args) { template void gemvContigousT(const GemvArgs &args) { - int mp = (args.M + GEMVMinRowsPerThread - 1) / GEMVMinRowsPerThread; - int numThreads = std::min(mp, MP::getMaxThreads()); - - if (MODE == Mode::SingleThread || numThreads <= 1) { + if (MODE == Mode::SingleThread) { lut::c_ptr y = alignedAlloc(args.N); memset(y.get(), 0, args.N * sizeof(float)); @@ -77,20 +72,19 @@ void gemvContigousT(const GemvArgs &args) { } else if (MODE == Mode::OMP) { // initialize numThreads y buffers. // TODO: sfill - lut::c_ptr ys = alignedAlloc(args.N * numThreads); - memset(ys.get(), 0, args.N * numThreads * sizeof(float)); + lut::c_ptr ys = alignedAlloc(args.N * MP::getMaxThreads()); + memset(ys.get(), 0, args.N * MP::getMaxThreads() * sizeof(float)); // compute axpy. - MP::parallelFor({args.M}, numThreads, [args, &ys](MP::Partition partition) { - for (int m : partition.getRange()) { - float *py = ys.get() + partition.getPartitionIdx() * args.N; - axpyKernel(args.N, args.x[m], args.A, m * args.lda, py); - } + MP::parallelFor(args.M, [args, &ys](MP::Context ctx) { + int m = ctx.getBlockIdx(); + float *py = ys.get() + ctx.getAttachedThreadIdx() * args.N; + axpyKernel(args.N, args.x[m], args.A, m * args.lda, py); }); // accumulate ys. // TODO: vAdd. - for (int p = 0; p < numThreads; ++p) { + for (int p = 0; p < MP::getMaxThreads(); ++p) { float *py = ys.get() + p * args.N; for (int i = 0; i < args.N; ++i) { args.y[i] += py[i]; diff --git a/src/libllm/cpu/matmul.cc b/src/libllm/cpu/matmul.cc index b3c624bc..111cd7ae 100644 --- a/src/libllm/cpu/matmul.cc +++ b/src/libllm/cpu/matmul.cc @@ -210,22 +210,21 @@ Tensor bmm(const Tensor &A, const Tensor &B) { const T *const *mBp = mB.getDataPtrList().data(); T *const *mCp = mC.getDataPtrList().data(); - MP::parallelFor({mA.getLength()}, [mAp, mBp, mCp, gemmArgs](MP::Partition partition) { - for (int i : partition.getRange()) { - callGemm( - gemmArgs.transA, - gemmArgs.transB, - gemmArgs.M, - gemmArgs.N, - gemmArgs.K, - mAp[i], - gemmArgs.lda, - mBp[i], - gemmArgs.ldb, - mCp[i], - gemmArgs.ldc, - kernel::Mode::SingleThread); - } + MP::parallelFor(mA.getLength(), [mAp, mBp, mCp, gemmArgs](MP::Context ctx) { + int i = ctx.getBlockIdx(); + callGemm( + gemmArgs.transA, + gemmArgs.transB, + gemmArgs.M, + gemmArgs.N, + gemmArgs.K, + mAp[i], + gemmArgs.lda, + mBp[i], + gemmArgs.ldb, + mCp[i], + gemmArgs.ldc, + kernel::Mode::SingleThread); }); return C; diff --git a/src/libllm/cpu/normalizations.cc b/src/libllm/cpu/normalizations.cc index 0344737f..92229cb0 100644 --- a/src/libllm/cpu/normalizations.cc +++ b/src/libllm/cpu/normalizations.cc @@ -44,25 +44,23 @@ Tensor rmsNormKernel(const Tensor &tensor, const Tensor &weight, float eps) { TensorAccessor w = weight; - MP::parallelFor({vA.getLength()}, [&vA, &vC, w, eps](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor a = vA.getTensor(j); - TensorAccessor c = vC.getTensor(j); - - double sum = 0.0; - for (int i = 0; i < a.getShape(0); ++i) { - double va = a[i]; - sum += va * va; - } - double mean = sum / a.getShape(0); - double rms = std::sqrt(mean + eps); - - // compute rms-norm - for (int i = 0; i < a.getShape(0); ++i) { - double va = a[i]; - double vw = w[i]; - c[i] = static_cast(a[i] * w[i] / rms); - } + MP::parallelFor(vA.getLength(), [&vA, &vC, w, eps](MP::Context ctx) { + TensorAccessor a = vA.getTensor(ctx.getBlockIdx()); + TensorAccessor c = vC.getTensor(ctx.getBlockIdx()); + + float sum = 0.0; + for (int i = 0; i < a.getShape(0); ++i) { + float va = a[i]; + sum += va * va; + } + float mean = sum / a.getShape(0); + float rms = std::sqrt(mean + eps); + + // compute rms-norm + for (int i = 0; i < a.getShape(0); ++i) { + float va = a[i]; + float vw = w[i]; + c[i] = static_cast(a[i] * w[i] / rms); } }); @@ -83,31 +81,29 @@ Tensor layerNormKernel(const Tensor &tensor, const Tensor &weight, const Tensor TensorAccessor w = weight; TensorAccessor b = bias; - MP::parallelFor({vA.getLength()}, [&vA, &vC, w, b, eps](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor a = vA.getTensor(j); - TensorAccessor c = vC.getTensor(j); - - double sum = 0.0f; - for (int i = 0; i < a.getShape(0); ++i) { - sum += a[i]; - } - double mean = sum / a.getShape(0); - - // var (unbiased) - sum = 0.0; - for (int i = 0; i < a.getShape(0); ++i) { - double d = a[i] - mean; - sum += d * d; - } - double var = sum / a.getShape(0); - double sd = sqrt(var + eps); - - // compute layer-norm - for (int i = 0; i < a.getShape(0); ++i) { - float elem = static_cast((a[i] - mean) / sd); - c[i] = elem * w[i] + b[i]; - } + MP::parallelFor(vA.getLength(), [&vA, &vC, w, b, eps](MP::Context ctx) { + TensorAccessor a = vA.getTensor(ctx.getBlockIdx()); + TensorAccessor c = vC.getTensor(ctx.getBlockIdx()); + + double sum = 0.0f; + for (int i = 0; i < a.getShape(0); ++i) { + sum += a[i]; + } + double mean = sum / a.getShape(0); + + // var (unbiased) + sum = 0.0; + for (int i = 0; i < a.getShape(0); ++i) { + double d = a[i] - mean; + sum += d * d; + } + double var = sum / a.getShape(0); + double sd = sqrt(var + eps); + + // compute layer-norm + for (int i = 0; i < a.getShape(0); ++i) { + float elem = static_cast((a[i] - mean) / sd); + c[i] = elem * w[i] + b[i]; } }); diff --git a/src/libllm/cpu/rand.cc b/src/libllm/cpu/rand.cc index 081635c2..d74f07ff 100644 --- a/src/libllm/cpu/rand.cc +++ b/src/libllm/cpu/rand.cc @@ -44,15 +44,23 @@ Tensor randFp32(lut::Span shape, lut::Random *generator, float min, f generator->fill(tensorData, min, max); } else { // if no generator specified, we could go parallel. - MP::parallelFor( - {static_cast(tensorData.size())}, - [&tensorData, min, max](MP::Partition partition) { - lut::Random random(time(nullptr) + partition.getPartitionIdx()); - for (int i : partition.getRange()) { - float nextR = random.nextFloat(); - tensorData[i] = min + (max - min) * nextR; - } - }); + std::vector rs; + lut::Random rseed; + for (int i = 0; i < MP::getMaxThreads(); ++i) { + rs.emplace_back(rseed.nextInt()); + } + + int blockSize = 1024; + int nb = static_cast((tensorData.size() + blockSize - 1) / blockSize); + MP::parallelFor(nb, [&tensorData, &rs, min, max, blockSize](MP::Context ctx) { + int64_t b = ctx.getBlockIdx(); + int64_t begin = b * blockSize; + int64_t end = std::min(b * blockSize + blockSize, static_cast(tensorData.size())); + for (int i = begin; i < end; ++i) { + float nextR = rs[ctx.getAttachedThreadIdx()].nextFloat(); + tensorData[i] = min + (max - min) * nextR; + } + }); } return x; diff --git a/src/libllm/cpu/reduce.cc b/src/libllm/cpu/reduce.cc index 4a5c405e..0c2eee28 100644 --- a/src/libllm/cpu/reduce.cc +++ b/src/libllm/cpu/reduce.cc @@ -74,24 +74,22 @@ Tensor reduceKernel(Tensor A) { TensorList vC = TensorList::fromTensor(C); CHECK(vA.getLength() == vC.getLength()); - MP::parallelFor({vA.getLength()}, [&vA, &vC](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor a = vA.getTensor(j); - TensorAccessor c = vC.getTensor(j); - - float accumulator = getReduceInitial(); - for (int i = 0; i < a.getShape(0); i++) { - if (REDUCE_TYPE == ReduceType::SUM) { - accumulator += a[i]; - } else if (REDUCE_TYPE == ReduceType::MAX) { - if (a[i] > accumulator) accumulator = a[i]; - } else { - NOT_IMPL(); - } + MP::parallelFor(vA.getLength(), [&vA, &vC](MP::Context ctx) { + TensorAccessor a = vA.getTensor(ctx.getBlockIdx()); + TensorAccessor c = vC.getTensor(ctx.getBlockIdx()); + + float accumulator = getReduceInitial(); + for (int i = 0; i < a.getShape(0); i++) { + if (REDUCE_TYPE == ReduceType::SUM) { + accumulator += a[i]; + } else if (REDUCE_TYPE == ReduceType::MAX) { + if (a[i] > accumulator) accumulator = a[i]; + } else { + NOT_IMPL(); } - - c[0] = accumulator; } + + c[0] = accumulator; }); return C; diff --git a/src/libllm/cpu/repetition_penalty.cc b/src/libllm/cpu/repetition_penalty.cc index f031f2ac..e0389365 100644 --- a/src/libllm/cpu/repetition_penalty.cc +++ b/src/libllm/cpu/repetition_penalty.cc @@ -37,33 +37,31 @@ void repetitionPenalty2DKernel(Tensor logits, Tensor history, float weight) { TensorList vH = TensorList::fromTensor(history); CHECK(vA.getLength() == vH.getLength()); - MP::parallelFor({vA.getLength()}, [&vA, &vH, weight](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor a = vA.getTensor(j); - TensorAccessor h = vH.getTensor(j); + MP::parallelFor(vA.getLength(), [&vA, &vH, weight](MP::Context ctx) { + TensorAccessor a = vA.getTensor(ctx.getBlockIdx()); + TensorAccessor h = vH.getTensor(ctx.getBlockIdx()); - // gather. Avoid the same logit penalizing twice. - std::vector scores(h.getShape(0)); - for (int i = 0; i < h.getShape(0); ++i) { - LongType logitsIdx = h[i]; - CHECK(logitsIdx < a.getShape(0)); + // gather. Avoid the same logit penalizing twice. + std::vector scores(h.getShape(0)); + for (int i = 0; i < h.getShape(0); ++i) { + LongType logitsIdx = h[i]; + CHECK(logitsIdx < a.getShape(0)); - T v = a[logitsIdx]; - if (v > 0) { - v /= weight; - } else if (v < 0) { - v *= weight; - } + T v = a[logitsIdx]; + if (v > 0) { + v /= weight; + } else if (v < 0) { + v *= weight; + } - scores[i] = v; - }; + scores[i] = v; + }; - // scatter - for (int i = 0; i < h.getShape(0); ++i) { - LongType logitsIdx = h[i]; - a[logitsIdx] = scores[i]; - }; - } + // scatter + for (int i = 0; i < h.getShape(0); ++i) { + LongType logitsIdx = h[i]; + a[logitsIdx] = scores[i]; + }; }); } diff --git a/src/libllm/cpu/softmax.cc b/src/libllm/cpu/softmax.cc index 746e1438..371051c4 100644 --- a/src/libllm/cpu/softmax.cc +++ b/src/libllm/cpu/softmax.cc @@ -36,24 +36,22 @@ Tensor softmaxKernel(Tensor A) { TensorList vC = TensorList::fromTensor(C); CHECK(vA.getLength() == vC.getLength()); - MP::parallelFor({vA.getLength()}, [&vA, &vC](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor a = vA.getTensor(j); - TensorAccessor c = vC.getTensor(j); + MP::parallelFor({vA.getLength()}, [&vA, &vC](MP::Context ctx) { + TensorAccessor a = vA.getTensor(ctx.getBlockIdx()); + TensorAccessor c = vC.getTensor(ctx.getBlockIdx()); - std::vector m(a.getShape(0) + 1); - std::vector d(a.getShape(0) + 1); - m[0] = -1e10; - d[0] = 0; - for (int i = 0; i < a.getShape(0); i++) { - T x = a[i]; - m[i + 1] = fmaxf(m[i], x); - d[i + 1] = d[i] * expf(m[i] - m[i + 1]) + expf(x - m[i + 1]); - } - for (int i = 0; i < a.getShape(0); i++) { - float x = a[i]; - c[i] = static_cast(expf(x - m[a.getShape(0)]) / d[a.getShape(0)]); - } + std::vector m(a.getShape(0) + 1); + std::vector d(a.getShape(0) + 1); + m[0] = -1e10; + d[0] = 0; + for (int i = 0; i < a.getShape(0); i++) { + T x = a[i]; + m[i + 1] = fmaxf(m[i], x); + d[i + 1] = d[i] * expf(m[i] - m[i + 1]) + expf(x - m[i + 1]); + } + for (int i = 0; i < a.getShape(0); i++) { + float x = a[i]; + c[i] = static_cast(expf(x - m[a.getShape(0)]) / d[a.getShape(0)]); } }); diff --git a/src/libllm/cpu/swiglu.cc b/src/libllm/cpu/swiglu.cc index 95d23afd..668fc712 100644 --- a/src/libllm/cpu/swiglu.cc +++ b/src/libllm/cpu/swiglu.cc @@ -40,18 +40,17 @@ Tensor swigluKernel(const Tensor &A) { TensorList vC = TensorList::fromTensor(C); CHECK(vA.getLength() == vC.getLength()); - MP::parallelFor({vA.getLength()}, [&vA, &vC](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor a = vA.getTensor(j); - TensorAccessor c = vC.getTensor(j); + MP::parallelFor(vA.getLength(), [&vA, &vC](MP::Context ctx) { + int j = ctx.getBlockIdx(); + TensorAccessor a = vA.getTensor(j); + TensorAccessor c = vC.getTensor(j); - int n = c.getShape(0); - for (int i = 0; i < n; ++i) { - T x = a[i]; - x *= 1.0f / (1 + expf(-x)); - x *= a[i + n]; - c[i] = x; - } + int n = c.getShape(0); + for (int i = 0; i < n; ++i) { + T x = a[i]; + x *= 1.0f / (1 + expf(-x)); + x *= a[i + n]; + c[i] = x; } }); diff --git a/src/libllm/cpu/transform.cc b/src/libllm/cpu/transform.cc index 85b4e6bc..fb6b8241 100644 --- a/src/libllm/cpu/transform.cc +++ b/src/libllm/cpu/transform.cc @@ -37,14 +37,12 @@ Tensor transformKernel(const Tensor &A, float alpha, float beta) { TensorList vC = TensorList::fromTensor(C); CHECK(vA.getLength() == vC.getLength()); - MP::parallelFor({vA.getLength()}, [&vA, &vC, alpha, beta](MP::Partition partition) { - for (int j : partition.getRange()) { - TensorAccessor a = vA.getTensor(j); - TensorAccessor c = vC.getTensor(j); + MP::parallelFor(vA.getLength(), [&vA, &vC, alpha, beta](MP::Context ctx) { + TensorAccessor a = vA.getTensor(ctx.getBlockIdx()); + TensorAccessor c = vC.getTensor(ctx.getBlockIdx()); - for (int i = 0; i < a.getShape(0); ++i) { - c[i] = a[i] * static_cast(alpha) + static_cast(beta); - } + for (int i = 0; i < a.getShape(0); ++i) { + c[i] = a[i] * static_cast(alpha) + static_cast(beta); } }); diff --git a/src/libllm/cpu/unfold.cc b/src/libllm/cpu/unfold.cc index 64610384..02574b1b 100644 --- a/src/libllm/cpu/unfold.cc +++ b/src/libllm/cpu/unfold.cc @@ -40,24 +40,22 @@ void unfold1DKernel(const Tensor &src, Tensor &dest, int kernelSize, int stride) TensorAccessor vC = mC.getTensor(i); CHECK(vA.getShape(0) / stride == vC.getShape(0)); - MP::parallelFor({vC.getShape(0)}, [&vA, &vC, kernelSize, stride](MP::Partition partition) { + MP::parallelFor(vC.getShape(0), [&vA, &vC, kernelSize, stride](MP::Context ctx) { + int j = ctx.getBlockIdx(); int kernekIdxBegin = -(kernelSize / 2); int kernekIdxEnd = (kernelSize - 1) / 2; + int numChannels = vA.getShape(1); + int numInFrames = vA.getShape(0); - for (int j : partition.getRange()) { - int numChannels = vA.getShape(1); - int numInFrames = vA.getShape(0); - - for (int d = 0; d < numChannels; ++d) { - for (int k = kernekIdxBegin; k <= kernekIdxEnd; ++k) { - int srcIdx = j * stride + k; - int offset = k - kernekIdxBegin; - if (srcIdx < 0 || srcIdx >= numInFrames) { - // padding. - vC[j][d * kernelSize + offset] = 0.0f; - } else { - vC[j][d * kernelSize + offset] = vA[srcIdx][d]; - } + for (int d = 0; d < numChannels; ++d) { + for (int k = kernekIdxBegin; k <= kernekIdxEnd; ++k) { + int srcIdx = j * stride + k; + int offset = k - kernekIdxBegin; + if (srcIdx < 0 || srcIdx >= numInFrames) { + // padding. + vC[j][d * kernelSize + offset] = 0.0f; + } else { + vC[j][d * kernelSize + offset] = vA[srcIdx][d]; } } } diff --git a/src/libllm/mp.cc b/src/libllm/mp.cc index 4a54eb99..b58609ee 100644 --- a/src/libllm/mp.cc +++ b/src/libllm/mp.cc @@ -22,55 +22,4 @@ #include #include -#include "lutil/range.h" - -namespace libllm { - -lut::Range MP::splitRange(lut::Range range, int chunkIdx, int numChunks) { - int numel = (range.getEnd() - range.getBegin()) / range.getStep(); - int partitionSize = numel / numChunks; - int remain = numel % numChunks; - - int begin = chunkIdx * partitionSize; - begin += std::min(chunkIdx, remain); - int end = begin + partitionSize; - if (chunkIdx < remain) ++end; - - begin = begin * range.getStep() + range.getBegin(); - end = std::min(end * range.getStep() + range.getBegin(), range.getEnd()); - - return lut::Range(begin, end, range.getStep()); -} - -MP::Partition::Partition( - lut::Range range, - int partitionIdx, - int numPartitions, - int attachedThreadIdx) - : _range(range), - _partitionIdx(partitionIdx), - _numPartitions(numPartitions), - _attachedThreadIdx(attachedThreadIdx) { -} - -MP::Partition::Partition(lut::Range range) - : Partition(range, 0, 1, -1) { -} - -lut::Range MP::Partition::getRange() const { - return _range; -} - -int MP::Partition::getPartitionIdx() const { - return _partitionIdx; -} - -int MP::Partition::getNumPartitions() const { - return _numPartitions; -} - -int MP::Partition::getAttachedThreadIdx() const { - return _attachedThreadIdx; -} - -} // namespace libllm +namespace libllm {} // namespace libllm diff --git a/src/libllm/mp.h b/src/libllm/mp.h index 64e07bb3..1b903676 100644 --- a/src/libllm/mp.h +++ b/src/libllm/mp.h @@ -21,14 +21,13 @@ #include -#include "lutil/range.h" - namespace libllm { // wrapper for OpenMP or other implementations class MP { public: class Partition; + class Context; static void init(); static void destroy(); @@ -36,37 +35,34 @@ class MP { /// @brief split range into N parts and apply each part in the closure. N is the number of /// workers in the thread pool. - /// @param range the range. + /// @param numBlocks number of blocks. /// @param closure the closure. Since we need to invoke the closure multiple times, we use it /// by value here. - /// @param numThreads number of threads to use. -1 means using all threads in the pool. - static void parallelFor(lut::Range range, int numThreads, std::function closure); - static void parallelFor(lut::Range range, std::function closure); - - private: - /// @brief Split the range into N parts, and returns the i-th part. - /// @param range the range to split. - /// @param chunkIdx the i-th part to get. - /// @param numChunks number of parts to split (the N). - /// @return i-th part of the range after split. - static lut::Range splitRange(lut::Range range, int chunkIdx, int numChunks); + static void parallelFor(int numBlocks, std::function closure); }; /// @brief Store a partition info for a parallelFor function call. -class MP::Partition { +class MP::Context { public: - Partition(lut::Range range); - Partition(lut::Range range, int partitionIdx, int numPartitions, int attachedThreadIdx); + Context(int blockIdx, int numBlocks, int attachedThreadIdx) + : _blockIdx(blockIdx), + _numBlocks(numBlocks), + _attachedThreadIdx(attachedThreadIdx) { + } - lut::Range getRange() const; - int getPartitionIdx() const; - int getNumPartitions() const; - int getAttachedThreadIdx() const; + int getBlockIdx() const { + return _blockIdx; + } + int getNumBlocks() const { + return _numBlocks; + } + int getAttachedThreadIdx() const { + return _attachedThreadIdx; + } private: - lut::Range _range; - int _partitionIdx; - int _numPartitions; + int _blockIdx; + int _numBlocks; int _attachedThreadIdx; }; diff --git a/src/libllm/mp_openmp.cc b/src/libllm/mp_openmp.cc index 197644f8..7f025e8a 100644 --- a/src/libllm/mp_openmp.cc +++ b/src/libllm/mp_openmp.cc @@ -28,7 +28,6 @@ #include #include "lutil/log.h" -#include "lutil/range.h" #include "lutil/thread_pool.h" namespace libllm { @@ -44,17 +43,11 @@ int MP::getMaxThreads() { return omp_get_max_threads(); } -void MP::parallelFor(lut::Range range, int numThreads, std::function closure) { - int n = numThreads > 0 ? numThreads : getMaxThreads(); - -#pragma omp parallel for num_threads(n) - for (int i = 0; i < n; ++i) { - closure(Partition(splitRange(range, i, n), i, n, omp_get_thread_num())); +void MP::parallelFor(int numBlocks, std::function closure) { +#pragma omp parallel for num_threads(getMaxThreads()) schedule(dynamic, 1) + for (int i = 0; i < numBlocks; ++i) { + closure(Context(i, numBlocks, omp_get_thread_num())); } } -void MP::parallelFor(lut::Range range, std::function closure) { - return parallelFor(range, -1, closure); -} - } // namespace libllm diff --git a/src/libllm/mp_thread_pool.cc b/src/libllm/mp_thread_pool.cc index 69ea0637..c115643d 100644 --- a/src/libllm/mp_thread_pool.cc +++ b/src/libllm/mp_thread_pool.cc @@ -49,14 +49,16 @@ int MP::getMaxThreads() { return gThreadPoolMP->getNumThreads(); } -void MP::parallelFor(lut::Range range, int numThreads, std::function closure) { +void MP::parallelFor2(int numBlocks, std::function closure) { CHECK(gThreadPoolMP) << "call MP::parallelFor() before MP::init()"; - int n = numThreads > 0 ? numThreads : gThreadPoolMP->getNumThreads(); + int n = gThreadPoolMP->getNumThreads(); std::atomic numDone{0}; for (int i = 0; i < n; ++i) { - gThreadPoolMP->apply([range, closure, i, n, &numDone]() { - closure(Partition(splitRange(range, i, n), i, n, lut::ThreadPool::getThreadId())); + gThreadPoolMP->apply([numBlocks, closure, i, n, &numDone]() { + for (int j = i; j < numBlocks; j += n) { + closure(Context(j, numBlocks, lut::ThreadPool::getThreadId())); + } numDone.fetch_add(1); }); } @@ -66,8 +68,4 @@ void MP::parallelFor(lut::Range range, int numThreads, std::function closure) { - return parallelFor(range, -1, closure); -} - } // namespace libllm diff --git a/src/lutil/range.h b/src/lutil/range.h deleted file mode 100644 index e4e364c1..00000000 --- a/src/lutil/range.h +++ /dev/null @@ -1,98 +0,0 @@ -// The MIT License (MIT) -// -// Copyright (c) 2024 Xiaoyang Chen -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software -// and associated documentation files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING -// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -#pragma once - -#include - -#include - -namespace lut { - -class Range { - public: - class iterator { - public: - iterator(int begin, int end, int step) - : _current(begin), - _end(end), - _step(step) { - } - - int operator*() const { - return _current; - } - - iterator& operator++() { - _current += _step; - return *this; - } - - bool operator!=(const iterator& other) const { - // two iterator equals only when both are finished. - return !((_current >= _end) && (other._current >= other._end)); - } - - private: - int _current; - int _step; - int _end; - }; - - Range(int end) - : _begin(0), - _end(end), - _step(1) { - } - Range(int begin, int end) - : _begin(begin), - _end(end), - _step(1) { - } - Range(int begin, int end, int step) - : _begin(begin), - _end(end), - _step(step) { - } - - iterator begin() const { - return iterator(_begin, _end, _step); - } - - iterator end() const { - return iterator(_end, _end, _step); - } - - int getBegin() const { - return _begin; - } - int getEnd() const { - return _end; - } - int getStep() const { - return _step; - } - - private: - int _begin; - int _end; - int _step; -}; - -} // namespace lut diff --git a/src/lutil/thread_pool.h b/src/lutil/thread_pool.h index 3705553f..80ea2e38 100644 --- a/src/lutil/thread_pool.h +++ b/src/lutil/thread_pool.h @@ -22,8 +22,6 @@ #include #include -#include "lutil/range.h" - namespace lut { class ThreadPool {