Skip to content

Commit

Permalink
[SwitchToIf] Empty yielded result (#8087)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 authored Jan 15, 2025
1 parent 09790a7 commit 2d87d21
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
11 changes: 8 additions & 3 deletions lib/Transforms/IndexSwitchToIf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {
auto loc = switchOp.getLoc();

Region &defaultRegion = switchOp.getDefaultRegion();
bool hasResults = !switchOp.getResultTypes().empty();

Value finalResult;
scf::IfOp prevIfOp = nullptr;
Expand Down Expand Up @@ -66,17 +67,21 @@ struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {
ifOp.getElseRegion().end());
}

if (prevIfOp) {
if (prevIfOp && hasResults) {
rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, ifOp.getResult(0));
}

if (i == 0)
if (i == 0 && hasResults)
finalResult = ifOp.getResult(0);

prevIfOp = ifOp;
}

rewriter.replaceOp(switchOp, finalResult);
if (hasResults)
rewriter.replaceOp(switchOp, finalResult);
else
rewriter.eraseOp(switchOp);

return success();
}
Expand Down
48 changes: 48 additions & 0 deletions test/Transforms/switch-to-if.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,51 @@ module {
return %0 : i32
}
}

// Switch to nested if-else when the yielded result is empty

// -----

module {
// CHECK-LABEL: func.func @main(
// CHECK-SAME: %[[VAL_0:.*]]: index,
// CHECK-SAME: %[[VAL_1:.*]]: memref<2xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2xi32>) {
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_0]] : index
// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_6:.*]] = arith.cmpi eq, %[[VAL_4]], %[[VAL_5]] : index
// CHECK: scf.if %[[VAL_6]] {
// CHECK: %[[VAL_7:.*]] = arith.constant 10 : i32
// CHECK: memref.store %[[VAL_7]], %[[VAL_1]]{{\[}}%[[VAL_3]]] : memref<2xi32>
// CHECK: } else {
// CHECK: %[[VAL_8:.*]] = arith.constant 5 : index
// CHECK: %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_4]], %[[VAL_8]] : index
// CHECK: scf.if %[[VAL_9]] {
// CHECK: %[[VAL_10:.*]] = arith.constant 20 : i32
// CHECK: memref.store %[[VAL_10]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<2xi32>
// CHECK: } else {
// CHECK: }
// CHECK: }
// CHECK: return
// CHECK: }
func.func @main(%arg0 : index, %arg1 : memref<2xi32>, %arg2 : memref<2xi32>) {
%one = arith.constant 1 : index
%cond = arith.addi %one, %arg0 : index
scf.index_switch %cond
case 2 {
%1 = arith.constant 10 : i32
memref.store %1, %arg1[%one] : memref<2xi32>
scf.yield
}
case 5 {
%2 = arith.constant 20 : i32
memref.store %2, %arg2[%one] : memref<2xi32>
scf.yield
}
default {
scf.yield
}
return
}
}

0 comments on commit 2d87d21

Please sign in to comment.