From 8d1738761670073bb81730d6c99889c90a3988b0 Mon Sep 17 00:00:00 2001 From: Haruki Imai Date: Thu, 6 Feb 2025 00:02:40 -0500 Subject: [PATCH] Update fetchBatch() for serial exec. Signed-off-by: Haruki Imai --- .../ONNX/ElementsAttr/ElementsAttrBuilder.cpp | 15 +++++++-------- .../ONNX/ElementsAttr/ElementsAttrBuilder.hpp | 14 +++++++------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index efac7c6fb3..c0ba82d023 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -888,15 +888,14 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms, ShapedType reducedType = type.clone(reducedShape); return fromWideNums(reducedType, [&](MutableArrayRef dstNums) { StridesRange<1> sRange(reducedShape, {reducedStrides}); + StridesRange<1> axesRange(axesShape, {axesStrides}); SmallVector, 4> batch; for (auto &idxoffs : sRange) batch.emplace_back(std::make_pair(idxoffs.flattenedIndex, idxoffs[0])); - StridesRange<1> axesRange(axesShape, {axesStrides}); - - auto fetchBatch = [&](size_t threadNumber) { + auto fetchBatch = [&](size_t threadNumber, bool parallel) { // retrun all data without spliting for sequential execution. - if (threadNumber == SIZE_MAX) + if (!parallel) return llvm::make_range(batch.begin(), batch.end()); // Each thread fetches the same batch size. The leftovers are set in the // threads with small thread number. @@ -916,10 +915,10 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms, batch.begin() + beginOffset, batch.begin() + endOffset); }; - auto work = [&](size_t threadNumber) { - auto batch = fetchBatch(threadNumber); + auto work = [&](size_t threadNumber, bool parallel = true) { + auto tile = fetchBatch(threadNumber, parallel); // Traverse and populate each element d in dstNums. - for (auto b : batch) { + for (auto b : tile) { WideNum &d = dstNums[b.first]; int64_t srcPos = b.second; // Traverse all the elements that reduce together into d. @@ -948,7 +947,7 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms, constexpr size_t minCount = 2000; size_t inputCount = batch.size() * axesRange.size(); if (inputCount < minCount) - work(SIZE_MAX); // Sequential + work(0, /*parallel*/ false); else parallelFor(ctx, 0, ctx->getNumThreads(), work); }); diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp index 8c9abe9831..6242da6139 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp @@ -249,11 +249,11 @@ class ElementsAttrBuilder { mlir::MLIRContext *ctx = disposablePool.getContext(); return [fun = std::move(fun), ctx]( llvm::MutableArrayRef data) -> void { - auto fetchBatch = [&](size_t threadNumber) { + auto fetchBatch = [&](size_t threadNumber, bool parallel) { // retrun all data without spliting for sequential execution. - if (threadNumber == SIZE_MAX) + if (!parallel) return llvm::make_range(data.begin(), data.end()); - // Each thread fetches the same batch size. The leftovers are set in the + // Each thread fetches the same data size. The leftovers are set in the // threads with small thread number. size_t tileSize = floor(data.size() / ctx->getNumThreads()); size_t leftovers = data.size() % ctx->getNumThreads(); @@ -272,9 +272,9 @@ class ElementsAttrBuilder { data.begin() + beginOffset, data.begin() + endOffset); }; - auto work = [&](size_t threadNumber) { - auto batch = fetchBatch(threadNumber); - for (WideNum &n : batch) + auto work = [&](size_t threadNumber, bool parallel = true) { + auto tile = fetchBatch(threadNumber, parallel); + for (WideNum &n : tile) n = fun(n); }; // Using 'parallelFor()' introduces large overhead. @@ -282,7 +282,7 @@ class ElementsAttrBuilder { // `minCount`. constexpr size_t minCount = 1000; if (data.size() < minCount) - work(SIZE_MAX); // Sequential + work(0, /*parallel*/ false); else parallelFor(ctx, 0, ctx->getNumThreads(), work); };