Skip to content

Commit

Permalink
[CIR][CodeGen][Lowering] Support Integer overflow with fwrap (llvm#539)
Browse files Browse the repository at this point in the history
This PR fixes some cases when a program compiled with `-fwrapv` fails
with `NYI` .
Basically, the default behavior  is no overlap:
```
void baz(int x, int y) {
  int z = x - y;
}
```
LLVM IR (no CIR enabled):
```
%sub = sub nsw i32 %0, %1
```
and with `-fwrapv` :
```
%sub = sub i32 %0, %1
```
We need something similar in CIR. The only way I see how to implement it
is to add a couple of attributes to the `BinOp` to make things even with
the llvm dialect.

Well, are there any other ideas?

---------

Co-authored-by: Bruno Cardoso Lopes <[email protected]>
  • Loading branch information
2 people authored and lanza committed Oct 1, 2024
1 parent b4bd94b commit 8afc7f3
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 20 deletions.
15 changes: 15 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,21 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return create<mlir::cir::StoreOp>(loc, val, dst, _volatile, order);
}

mlir::Value createSub(mlir::Value lhs, mlir::Value rhs, bool hasNUW = false,
bool hasNSW = false) {
auto op = create<mlir::cir::BinOp>(lhs.getLoc(), lhs.getType(),
mlir::cir::BinOpKind::Sub, lhs, rhs);
if (hasNUW)
op.setNoUnsignedWrap(true);
if (hasNSW)
op.setNoSignedWrap(true);
return op;
}

mlir::Value createNSWSub(mlir::Value lhs, mlir::Value rhs) {
return createSub(lhs, rhs, false, true);
}

//===--------------------------------------------------------------------===//
// Cast/Conversion Operators
//===--------------------------------------------------------------------===//
Expand Down
12 changes: 8 additions & 4 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -934,14 +934,18 @@ def BinOp : CIR_Op<"binop", [Pure,
// TODO: get more accurate than CIR_AnyType
let results = (outs CIR_AnyType:$result);
let arguments = (ins Arg<BinOpKind, "binop kind">:$kind,
CIR_AnyType:$lhs, CIR_AnyType:$rhs);
CIR_AnyType:$lhs, CIR_AnyType:$rhs,
UnitAttr:$no_unsigned_wrap,
UnitAttr:$no_signed_wrap);

let assemblyFormat = [{
`(` $kind `,` $lhs `,` $rhs `)` `:` type($lhs) attr-dict
`(` $kind `,` $lhs `,` $rhs `)`
(`nsw` $no_signed_wrap^)?
(`nuw` $no_unsigned_wrap^)?
`:` type($lhs) attr-dict
}];

// Already covered by the traits
let hasVerifier = 0;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
19 changes: 9 additions & 10 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
auto &builder = CGF.getBuilder();
auto amt = builder.getSInt32(amount, loc);
if (CGF.getLangOpts().isSignedOverflowDefined()) {
llvm_unreachable("NYI");
value = builder.create<mlir::cir::PtrStrideOp>(loc, value.getType(),
value, amt);
} else {
value = builder.create<mlir::cir::PtrStrideOp>(loc, value.getType(),
value, amt);
Expand Down Expand Up @@ -1207,7 +1208,8 @@ static mlir::Value buildPointerArithmetic(CIRGenFunction &CGF,

mlir::Type elemTy = CGF.convertTypeForMem(elementType);
if (CGF.getLangOpts().isSignedOverflowDefined())
llvm_unreachable("ptr arithmetic with signed overflow is NYI");
return CGF.getBuilder().create<mlir::cir::PtrStrideOp>(
CGF.getLoc(op.E->getExprLoc()), pointer.getType(), pointer, index);

return CGF.buildCheckedInBoundsGEP(elemTy, pointer, index, isSigned,
isSubtraction, op.E->getExprLoc());
Expand Down Expand Up @@ -1245,20 +1247,17 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
if (Ops.CompType->isSignedIntegerOrEnumerationType()) {
switch (CGF.getLangOpts().getSignedOverflowBehavior()) {
case LangOptions::SOB_Defined: {
llvm_unreachable("NYI");
return Builder.create<mlir::cir::BinOp>(
CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
mlir::cir::BinOpKind::Sub, Ops.LHS, Ops.RHS);
if (!CGF.SanOpts.has(SanitizerKind::SignedIntegerOverflow))
return Builder.createSub(Ops.LHS, Ops.RHS);
[[fallthrough]];
}
case LangOptions::SOB_Undefined:
if (!CGF.SanOpts.has(SanitizerKind::SignedIntegerOverflow))
return Builder.create<mlir::cir::BinOp>(
CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
mlir::cir::BinOpKind::Sub, Ops.LHS, Ops.RHS);
return Builder.createNSWSub(Ops.LHS, Ops.RHS);
[[fallthrough]];
case LangOptions::SOB_Trapping:
if (CanElideOverflowCheck(CGF.getContext(), Ops))
llvm_unreachable("NYI");
return Builder.createNSWSub(Ops.LHS, Ops.RHS);
llvm_unreachable("NYI");
}
}
Expand Down
18 changes: 18 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2950,6 +2950,24 @@ LogicalResult AtomicFetch::verify() {
return mlir::success();
}

LogicalResult BinOp::verify() {
bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();

if (!isa<mlir::cir::IntType>(getType()) && noWrap)
return emitError()
<< "only operations on integer values may have nsw/nuw flags";

bool noWrapOps = getKind() == mlir::cir::BinOpKind::Add ||
getKind() == mlir::cir::BinOpKind::Sub ||
getKind() == mlir::cir::BinOpKind::Mul;

if (noWrap && !noWrapOps)
return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
"'sub' and 'mul'";

return mlir::success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 18 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1886,6 +1886,18 @@ class CIRUnaryOpLowering
};

class CIRBinOpLowering : public mlir::OpConversionPattern<mlir::cir::BinOp> {

mlir::LLVM::IntegerOverflowFlags
getIntOverflowFlag(mlir::cir::BinOp op) const {
if (op.getNoUnsignedWrap())
return mlir::LLVM::IntegerOverflowFlags::nuw;

if (op.getNoSignedWrap())
return mlir::LLVM::IntegerOverflowFlags::nsw;

return mlir::LLVM::IntegerOverflowFlags::none;
}

public:
using OpConversionPattern<mlir::cir::BinOp>::OpConversionPattern;

Expand All @@ -1908,19 +1920,22 @@ class CIRBinOpLowering : public mlir::OpConversionPattern<mlir::cir::BinOp> {
switch (op.getKind()) {
case mlir::cir::BinOpKind::Add:
if (type.isa<mlir::cir::IntType>())
rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs,
getIntOverflowFlag(op));
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmTy, lhs, rhs);
break;
case mlir::cir::BinOpKind::Sub:
if (type.isa<mlir::cir::IntType>())
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs,
getIntOverflowFlag(op));
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, llvmTy, lhs, rhs);
break;
case mlir::cir::BinOpKind::Mul:
if (type.isa<mlir::cir::IntType>())
rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs,
getIntOverflowFlag(op));
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, llvmTy, lhs, rhs);
break;
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/binop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void b0(int a, int b) {
// CHECK: = cir.binop(div, %6, %7) : !s32i
// CHECK: = cir.binop(rem, %9, %10) : !s32i
// CHECK: = cir.binop(add, %12, %13) : !s32i
// CHECK: = cir.binop(sub, %15, %16) : !s32i
// CHECK: = cir.binop(sub, %15, %16) nsw : !s32i
// CHECK: = cir.shift( right, %18 : !s32i, %19 : !s32i) -> !s32i
// CHECK: = cir.shift(left, %21 : !s32i, %22 : !s32i) -> !s32i
// CHECK: = cir.binop(and, %24, %25) : !s32i
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ int f() {
// CHECK: %1 = cir.call @_Z1pv() : () -> !cir.ptr<!s32i>
// CHECK: %2 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK: %3 = cir.const(#cir.int<22> : !s32i) : !s32i
// CHECK: %4 = cir.binop(sub, %2, %3) : !s32i
// CHECK: %4 = cir.binop(sub, %2, %3) nsw : !s32i
30 changes: 30 additions & 0 deletions clang/test/CIR/CodeGen/int-wrap.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fwrapv -fclangir -emit-cir %s -o - 2>&1 | FileCheck %s --check-prefix=WRAP
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - 2>&1 | FileCheck %s --check-prefix=NOWRAP

#define N 42

typedef struct {
const char* ptr;
} A;

// WRAP: cir.binop(sub, {{.*}}, {{.*}}) : !s32i
// NOWRAP: cir.binop(sub, {{.*}}, {{.*}}) nsw : !s32i
void foo(int* ar, int len) {
int x = ar[len - N];
}

// check that the ptr_stride is generated in both cases (i.e. no NYI fails)

// WRAP: cir.ptr_stride
// NOWRAP: cir.ptr_stride
void bar(A* a, unsigned n) {
a->ptr = a->ptr + n;
}

// WRAP cir.ptr_stride
// NOWRAP: cir.ptr_stride
void baz(A* a) {
a->ptr--;
}


18 changes: 17 additions & 1 deletion clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,22 @@ cir.func @bad_fetch(%x: !cir.ptr<!cir.float>, %y: !cir.float) -> () {

// -----

cir.func @bad_operands_for_nowrap(%x: !cir.float, %y: !cir.float) {
// expected-error@+1 {{only operations on integer values may have nsw/nuw flags}}
%0 = cir.binop(add, %x, %y) nsw : !cir.float
}

// -----

!u32i = !cir.int<u, 32>

cir.func @bad_binop_for_nowrap(%x: !u32i, %y: !u32i) {
// expected-error@+1 {{The nsw/nuw flags are applicable to opcodes: 'add', 'sub' and 'mul'}}
%0 = cir.binop(div, %x, %y) nsw : !u32i
}

// -----

!s32i = !cir.int<s, 32>

module {
Expand All @@ -1046,4 +1062,4 @@ module {
%0 = cir.get_global thread_local @batata : cir.ptr <!s32i>
cir.return
}
}
}
24 changes: 24 additions & 0 deletions clang/test/CIR/Lowering/int-wrap.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: cir-opt %s -cir-to-llvm -o - | FileCheck %s -check-prefix=MLIR
// RUN: cir-translate %s -cir-to-llvmir | FileCheck %s -check-prefix=LLVM

!s32i = !cir.int<s, 32>
module {
cir.func @test(%arg0: !s32i) {
%0 = cir.alloca !s32i, cir.ptr <!s32i>, ["len", init] {alignment = 4 : i64}
cir.store %arg0, %0 : !s32i, cir.ptr <!s32i>
%1 = cir.load %0 : cir.ptr <!s32i>, !s32i
%2 = cir.const(#cir.int<42> : !s32i) : !s32i
%3 = cir.binop(sub, %1, %2) nsw : !s32i
%4 = cir.binop(sub, %1, %2) nuw : !s32i
%5 = cir.binop(sub, %1, %2) : !s32i
cir.return
}
}

// MLIR: llvm.sub {{.*}}, {{.*}} overflow<nsw> : i32
// MLIR-NEXT: llvm.sub {{.*}}, {{.*}} overflow<nuw> : i32
// MLIR-NEXT: llvm.sub {{.*}}, {{.*}} : i32

// LLVM: sub nsw i32 {{.*}}, {{.*}}, !dbg !9
// LLVM-NEXT: sub nuw i32 {{.*}}, {{.*}}, !dbg !10
// LLVM-NEXT: sub i32 {{.*}}, {{.*}}, !dbg !11

0 comments on commit 8afc7f3

Please sign in to comment.