Skip to content

Commit

Permalink
[SCFToCalyx] Wrap blocks with scf::ExecuteRegion when creating the ne…
Browse files Browse the repository at this point in the history
…w scf::Parallel (#8098)
  • Loading branch information
jiahanxie353 authored Jan 20, 2025
1 parent a9a7a75 commit 6cfac8c
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 28 deletions.
62 changes: 46 additions & 16 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
/// SCF
scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
scf::ParallelOp, scf::ReduceOp,
scf::ExecuteRegionOp,
/// memref
memref::AllocOp, memref::AllocaOp, memref::LoadOp,
memref::StoreOp, memref::GetGlobalOp,
Expand Down Expand Up @@ -389,6 +390,8 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
scf::ReduceOp reduceOp) const;
LogicalResult buildOp(PatternRewriter &rewriter,
scf::ParallelOp parallelOp) const;
LogicalResult buildOp(PatternRewriter &rewriter,
scf::ExecuteRegionOp executeRegionOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, CallOp callOp) const;

/// buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the
Expand Down Expand Up @@ -1470,6 +1473,15 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
return success();
}

/// Handles an `scf.execute_region` op encountered while building op groups.
/// Nothing is built for the op itself: the function unconditionally reports
/// success and uses neither `rewriter` nor `executeRegionOp`.
LogicalResult
BuildOpGroups::buildOp(PatternRewriter &rewriter,
scf::ExecuteRegionOp executeRegionOp) const {
// Simply return success because the only remaining `scf.execute_region` ops
// are the ones generated by the `BuildParGroups` pattern - the rest of them
// are inlined by the `InlineExecuteRegionOpPattern`.
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
CallOp callOp) const {
std::string instanceName = calyx::getInstanceName(callOp);
Expand Down Expand Up @@ -1902,8 +1914,8 @@ class BuildParGroups : public calyx::FuncOpPartialLoweringPattern {
auto loc = newParOp.getLoc();
rewriter.insert(newParOp);
OpBuilder insideBuilder(newParOp);
Block *currBlock = nullptr;
auto &region = newParOp.getRegion();
auto *newParBodyBlock = &region.emplaceBlock();

// extract lower bounds, upper bounds, and steps as integer index values
SmallVector<int64_t> lbVals, ubVals, stepVals;
Expand All @@ -1929,25 +1941,32 @@ class BuildParGroups : public calyx::FuncOpPartialLoweringPattern {
SmallVector<int64_t> indices = lbVals;

while (true) {
insideBuilder.setInsertionPointToEnd(newParBodyBlock);
// Create an `scf.execute_region` to wrap each unrolled block since
// `scf.parallel` requires only one block in the body region.
auto execRegionOp =
insideBuilder.create<scf::ExecuteRegionOp>(loc, TypeRange{});
auto &execRegion = execRegionOp.getRegion();
Block *execBlock = &execRegion.emplaceBlock();
OpBuilder regionBuilder(execRegionOp);
// Each iteration starts with a fresh mapping, so the block arguments of
// every cloned region-based operation (such as `scf.for`) get re-mapped
// independently.
IRMapping operandMap;

// Create a new block in the region for the current combination of indices
currBlock = &region.emplaceBlock();
insideBuilder.setInsertionPointToEnd(currBlock);

regionBuilder.setInsertionPointToEnd(execBlock);
// Map induction variables to constant indices
for (unsigned i = 0; i < indices.size(); ++i) {
Value ivConstant =
insideBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
regionBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
operandMap.map(parOpIVs[i], ivConstant);
}

for (auto it = body->begin(); it != std::prev(body->end()); ++it)
insideBuilder.clone(*it, operandMap);
regionBuilder.clone(*it, operandMap);

// A terminator should always be inserted in `scf.execute_region`'s block.
// Use `scf.yield` here: `scf.reduce` is only valid as the terminator of an
// `scf.parallel` body, not of an `scf.execute_region` block.
regionBuilder.create<scf::YieldOp>(loc);
// Increment indices using `step`
bool done = false;
for (int dim = indices.size() - 1; dim >= 0; --dim) {
Expand Down Expand Up @@ -2043,15 +2062,26 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
} else if (auto *parSchedPtr = std::get_if<ParScheduleable>(&group)) {
auto parOp = parSchedPtr->parOp;
auto calyxParOp = rewriter.create<calyx::ParOp>(parOp.getLoc());
for (auto &innerBlock : parOp.getRegion().getBlocks()) {
rewriter.setInsertionPointToEnd(calyxParOp.getBodyBlock());
auto seqOp = rewriter.create<calyx::SeqOp>(parOp.getLoc());
rewriter.setInsertionPointToEnd(seqOp.getBodyBlock());
if (LogicalResult res = scheduleBasicBlock(
rewriter, path, seqOp.getBodyBlock(), &innerBlock);
res.failed())
return res;
}

WalkResult walkResult =
parOp.walk([&](scf::ExecuteRegionOp execRegion) {
rewriter.setInsertionPointToEnd(calyxParOp.getBodyBlock());
auto seqOp = rewriter.create<calyx::SeqOp>(execRegion.getLoc());
rewriter.setInsertionPointToEnd(seqOp.getBodyBlock());

for (auto &execBlock : execRegion.getRegion().getBlocks()) {
if (LogicalResult res = scheduleBasicBlock(
rewriter, path, seqOp.getBodyBlock(), &execBlock);
res.failed()) {
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});

if (walkResult.wasInterrupted())
return failure();

} else if (auto *forSchedPtr = std::get_if<ForScheduleable>(&group);
forSchedPtr) {
auto forOp = forSchedPtr->forOp;
Expand Down
24 changes: 12 additions & 12 deletions test/Conversion/SCFToCalyx/convert_simple.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ module {
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb1_0 {
// CHECK-DAG: calyx.group @bb0_2 {
// CHECK-DAG: calyx.assign %std_slice_5.in = %c4_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_5.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
Expand All @@ -351,15 +351,15 @@ module {
// CHECK-DAG: calyx.assign %load_1_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_1_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb1_1 {
// CHECK-DAG: calyx.group @bb0_3 {
// CHECK-DAG: calyx.assign %std_slice_4.in = %c1_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_4.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_1_reg.out : i32
// CHECK-DAG: calyx.assign %mem_0.write_en = %true : i1
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb2_0 {
// CHECK-DAG: calyx.group @bb0_4 {
// CHECK-DAG: calyx.assign %std_slice_3.in = %c2_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_3.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
Expand All @@ -368,15 +368,15 @@ module {
// CHECK-DAG: calyx.assign %load_2_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_2_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb2_1 {
// CHECK-DAG: calyx.group @bb0_5 {
// CHECK-DAG: calyx.assign %std_slice_2.in = %c4_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_2.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_2_reg.out : i32
// CHECK-DAG: calyx.assign %mem_0.write_en = %true : i1
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb3_0 {
// CHECK-DAG: calyx.group @bb0_6 {
// CHECK-DAG: calyx.assign %std_slice_1.in = %c6_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_1.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
Expand All @@ -385,7 +385,7 @@ module {
// CHECK-DAG: calyx.assign %load_3_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_3_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb3_1 {
// CHECK-DAG: calyx.group @bb0_7 {
// CHECK-DAG: calyx.assign %std_slice_0.in = %c5_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_0.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_3_reg.out : i32
Expand All @@ -402,16 +402,16 @@ module {
// CHECK-DAG: calyx.enable @bb0_1
// CHECK-DAG: }
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb1_0
// CHECK-DAG: calyx.enable @bb1_1
// CHECK-DAG: calyx.enable @bb0_2
// CHECK-DAG: calyx.enable @bb0_3
// CHECK-DAG: }
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb2_0
// CHECK-DAG: calyx.enable @bb2_1
// CHECK-DAG: calyx.enable @bb0_4
// CHECK-DAG: calyx.enable @bb0_5
// CHECK-DAG: }
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb3_0
// CHECK-DAG: calyx.enable @bb3_1
// CHECK-DAG: calyx.enable @bb0_6
// CHECK-DAG: calyx.enable @bb0_7
// CHECK-DAG: }
// CHECK-DAG: }
// CHECK-DAG: }
Expand Down

0 comments on commit 6cfac8c

Please sign in to comment.