Update fetchBatch() for serial exec.
Signed-off-by: Haruki Imai <[email protected]>
imaihal committed Feb 6, 2025
1 parent b6f25e9 commit 8d17387
Showing 2 changed files with 14 additions and 15 deletions.
src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp (15 changes: 7 additions & 8 deletions)

@@ -888,15 +888,14 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
   ShapedType reducedType = type.clone(reducedShape);
   return fromWideNums(reducedType, [&](MutableArrayRef<WideNum> dstNums) {
     StridesRange<1> sRange(reducedShape, {reducedStrides});
+    StridesRange<1> axesRange(axesShape, {axesStrides});
     SmallVector<std::pair<int64_t, uint64_t>, 4> batch;
     for (auto &idxoffs : sRange)
       batch.emplace_back(std::make_pair(idxoffs.flattenedIndex, idxoffs[0]));

-    StridesRange<1> axesRange(axesShape, {axesStrides});
-
-    auto fetchBatch = [&](size_t threadNumber) {
+    auto fetchBatch = [&](size_t threadNumber, bool parallel) {
       // Return all data without splitting for sequential execution.
-      if (threadNumber == SIZE_MAX)
+      if (!parallel)
         return llvm::make_range(batch.begin(), batch.end());
       // Each thread fetches the same batch size. The leftovers are assigned to
       // the threads with the smallest thread numbers.
@@ -916,10 +915,10 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
           batch.begin() + beginOffset, batch.begin() + endOffset);
     };

-    auto work = [&](size_t threadNumber) {
-      auto batch = fetchBatch(threadNumber);
+    auto work = [&](size_t threadNumber, bool parallel = true) {
+      auto tile = fetchBatch(threadNumber, parallel);
       // Traverse and populate each element d in dstNums.
-      for (auto b : batch) {
+      for (auto b : tile) {
         WideNum &d = dstNums[b.first];
         int64_t srcPos = b.second;
         // Traverse all the elements that reduce together into d.
@@ -948,7 +947,7 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
     constexpr size_t minCount = 2000;
     size_t inputCount = batch.size() * axesRange.size();
     if (inputCount < minCount)
-      work(SIZE_MAX); // Sequential
+      work(0, /*parallel*/ false);
     else
       parallelFor(ctx, 0, ctx->getNumThreads(), work);
   });
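The arithmetic that fetchBatch() uses to split `batch` into per-thread tiles is collapsed in the hunk above (the tileSize/leftovers lines are visible in the .hpp hunk below). As a reading aid, here is a minimal standalone sketch of that tiling scheme, following the comment in the diff: every thread gets floor(n / numThreads) elements and the first n % numThreads threads take one extra. The helper tileBounds() and the driver in main() are illustrative names, not code from onnx-mlir, and the sketch returns index bounds rather than iterator ranges over `batch`.

// A sketch of the per-thread tiling described by the fetchBatch() comments;
// tileBounds() is an illustrative name, not a function in onnx-mlir.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <utility>

// Returns the [begin, end) element offsets owned by `threadNumber`.
std::pair<size_t, size_t> tileBounds(
    size_t n, size_t numThreads, size_t threadNumber, bool parallel) {
  if (!parallel)
    return {0, n}; // Sequential execution: one caller owns everything.
  size_t tileSize = n / numThreads;
  size_t leftovers = n % numThreads;
  // Threads with small thread numbers absorb the leftovers, one element each.
  size_t begin = threadNumber * tileSize + std::min(threadNumber, leftovers);
  size_t end = begin + tileSize + (threadNumber < leftovers ? 1 : 0);
  return {begin, end};
}

int main() {
  const size_t n = 10, numThreads = 4;
  size_t covered = 0;
  for (size_t t = 0; t < numThreads; ++t) {
    auto [b, e] = tileBounds(n, numThreads, t, /*parallel=*/true);
    std::cout << "thread " << t << ": [" << b << ", " << e << ")\n";
    covered += e - b;
  }
  assert(covered == n); // The tiles partition the input with no overlap.
  return 0;
}

Built with any C++17 compiler, the driver prints tile bounds [0, 3), [3, 6), [6, 8), [8, 10) for 10 elements across 4 threads, illustrating how the two leftover elements land in threads 0 and 1.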
src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp (14 changes: 7 additions & 7 deletions)

@@ -249,11 +249,11 @@ class ElementsAttrBuilder {
     mlir::MLIRContext *ctx = disposablePool.getContext();
     return [fun = std::move(fun), ctx](
                llvm::MutableArrayRef<WideNum> data) -> void {
-      auto fetchBatch = [&](size_t threadNumber) {
+      auto fetchBatch = [&](size_t threadNumber, bool parallel) {
         // Return all data without splitting for sequential execution.
-        if (threadNumber == SIZE_MAX)
+        if (!parallel)
           return llvm::make_range(data.begin(), data.end());
-        // Each thread fetches the same batch size. The leftovers are assigned to
+        // Each thread fetches the same data size. The leftovers are assigned to
         // the threads with the smallest thread numbers.
         size_t tileSize = floor(data.size() / ctx->getNumThreads());
         size_t leftovers = data.size() % ctx->getNumThreads();
@@ -272,17 +272,17 @@
             data.begin() + beginOffset, data.begin() + endOffset);
       };

-      auto work = [&](size_t threadNumber) {
-        auto batch = fetchBatch(threadNumber);
-        for (WideNum &n : batch)
+      auto work = [&](size_t threadNumber, bool parallel = true) {
+        auto tile = fetchBatch(threadNumber, parallel);
+        for (WideNum &n : tile)
           n = fun(n);
       };
       // Using 'parallelFor()' introduces large overhead.
       // To avoid this overhead, call work() directly if the input size is less
       // than `minCount`.
       constexpr size_t minCount = 1000;
       if (data.size() < minCount)
-        work(SIZE_MAX); // Sequential
+        work(0, /*parallel*/ false);
       else
         parallelFor(ctx, 0, ctx->getNumThreads(), work);
     };
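To make the before/after behavior concrete, here is a hedged sketch of the dispatch pattern after this change: one work(threadNumber, parallel) lambda serves both paths, so the serial fallback becomes work(0, /*parallel=*/false) instead of the old work(SIZE_MAX) sentinel. std::thread stands in for MLIR's parallelFor, and numThreads, minCount, the data values, and the squaring transform are illustrative choices, not values from the repository.

// A sketch of the serial/parallel dispatch after this commit; std::thread
// stands in for MLIR's parallelFor, and the data/transform are illustrative.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <thread>
#include <utility>
#include <vector>

int main() {
  const size_t numThreads = 4;
  constexpr size_t minCount = 1000;
  std::vector<double> data(32, 3.0); // Small input: takes the serial path.

  // Returns the [begin, end) offsets of this thread's tile, or the whole
  // range unsplit when `parallel` is false.
  auto fetchBatch = [&](size_t threadNumber, bool parallel) {
    if (!parallel)
      return std::pair<size_t, size_t>{0, data.size()};
    size_t tileSize = data.size() / numThreads;
    size_t leftovers = data.size() % numThreads;
    size_t begin = threadNumber * tileSize + std::min(threadNumber, leftovers);
    size_t end = begin + tileSize + (threadNumber < leftovers ? 1 : 0);
    return std::pair<size_t, size_t>{begin, end};
  };

  // The same lambda serves both paths, selected by the `parallel` flag
  // instead of the old SIZE_MAX sentinel.
  auto work = [&](size_t threadNumber, bool parallel = true) {
    auto [begin, end] = fetchBatch(threadNumber, parallel);
    for (size_t i = begin; i < end; ++i)
      data[i] = data[i] * data[i]; // Stand-in for n = fun(n).
  };

  if (data.size() < minCount) {
    work(0, /*parallel=*/false); // Skip thread-pool overhead on tiny inputs.
  } else {
    std::vector<std::thread> threads;
    for (size_t t = 0; t < numThreads; ++t)
      threads.emplace_back(work, t, /*parallel=*/true);
    for (auto &th : threads)
      th.join();
  }
  std::cout << "data[0] = " << data[0] << "\n"; // Prints data[0] = 9.
  return 0;
}

Keeping `parallel` as a defaulted trailing parameter lets the unchanged parallelFor(ctx, 0, ctx->getNumThreads(), work) call keep invoking work with only a thread number, which is presumably why the diff defaults it to true.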
