Skip to content

Commit

Permalink
[CIR][ThroughMLIR] Support lowering CastOp to arith (#577)
Browse files Browse the repository at this point in the history
This commit introduce CIRCastOpLowering for lowering to arith.
  • Loading branch information
ShivaChen authored May 3, 2024
1 parent b361bbe commit 1efee91
Show file tree
Hide file tree
Showing 2 changed files with 278 additions and 1 deletion.
132 changes: 131 additions & 1 deletion clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,135 @@ class CIRGetGlobalOpLowering
}
};

static mlir::Value createIntCast(mlir::ConversionPatternRewriter &rewriter,
mlir::Value src, mlir::Type dstTy,
bool isSigned = false) {
auto srcTy = src.getType();
assert(isa<mlir::IntegerType>(srcTy));
assert(isa<mlir::IntegerType>(dstTy));

auto srcWidth = srcTy.cast<mlir::IntegerType>().getWidth();
auto dstWidth = dstTy.cast<mlir::IntegerType>().getWidth();
auto loc = src.getLoc();

if (dstWidth > srcWidth && isSigned)
return rewriter.create<mlir::arith::ExtSIOp>(loc, dstTy, src);
else if (dstWidth > srcWidth)
return rewriter.create<mlir::arith::ExtUIOp>(loc, dstTy, src);
else if (dstWidth < srcWidth)
return rewriter.create<mlir::arith::TruncIOp>(loc, dstTy, src);
else
return rewriter.create<mlir::arith::BitcastOp>(loc, dstTy, src);
}

class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
public:
using OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern;

inline mlir::Type convertTy(mlir::Type ty) const {
return getTypeConverter()->convertType(ty);
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::CastOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
if (isa<mlir::cir::VectorType>(op.getSrc().getType()))
llvm_unreachable("CastOp lowering for vector type is not supported yet");
auto src = adaptor.getSrc();
auto dstType = op.getResult().getType();
using CIR = mlir::cir::CastKind;
switch (op.getKind()) {
case CIR::int_to_bool: {
auto zero = rewriter.create<mlir::cir::ConstantOp>(
src.getLoc(), op.getSrc().getType(),
mlir::cir::IntAttr::get(op.getSrc().getType(), 0));
rewriter.replaceOpWithNewOp<mlir::cir::CmpOp>(
op, mlir::cir::BoolType::get(getContext()), mlir::cir::CmpOpKind::ne,
op.getSrc(), zero);
return mlir::success();
}
case CIR::integral: {
auto newDstType = convertTy(dstType);
auto srcType = op.getSrc().getType();
mlir::cir::IntType srcIntType = srcType.cast<mlir::cir::IntType>();
auto newOp =
createIntCast(rewriter, src, newDstType, srcIntType.isSigned());
rewriter.replaceOp(op, newOp);
return mlir::success();
}
case CIR::floating: {
auto newDstType = convertTy(dstType);
auto srcTy = op.getSrc().getType();
auto dstTy = op.getResult().getType();

if (!dstTy.isa<mlir::cir::CIRFPTypeInterface>() ||
!srcTy.isa<mlir::cir::CIRFPTypeInterface>())
return op.emitError() << "NYI cast from " << srcTy << " to " << dstTy;

auto getFloatWidth = [](mlir::Type ty) -> unsigned {
return ty.cast<mlir::cir::CIRFPTypeInterface>().getWidth();
};

if (getFloatWidth(srcTy) > getFloatWidth(dstTy))
rewriter.replaceOpWithNewOp<mlir::arith::TruncFOp>(op, newDstType, src);
else
rewriter.replaceOpWithNewOp<mlir::arith::ExtFOp>(op, newDstType, src);
return mlir::success();
}
case CIR::float_to_bool: {
auto dstTy = op.getType().cast<mlir::cir::BoolType>();
auto newDstType = convertTy(dstTy);
auto kind = mlir::arith::CmpFPredicate::UNE;

// Check if float is not equal to zero.
auto zeroFloat = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), src.getType(), mlir::FloatAttr::get(src.getType(), 0.0));

// Extend comparison result to either bool (C++) or int (C).
mlir::Value cmpResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), kind, src, zeroFloat);
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, newDstType,
cmpResult);
return mlir::success();
}
case CIR::bool_to_int: {
auto dstTy = op.getType().cast<mlir::cir::IntType>();
auto newDstType = convertTy(dstTy).cast<mlir::IntegerType>();
auto newOp = createIntCast(rewriter, src, newDstType);
rewriter.replaceOp(op, newOp);
return mlir::success();
}
case CIR::bool_to_float: {
auto dstTy = op.getType();
auto newDstType = convertTy(dstTy);
rewriter.replaceOpWithNewOp<mlir::arith::UIToFPOp>(op, newDstType, src);
return mlir::success();
}
case CIR::int_to_float: {
auto dstTy = op.getType();
auto newDstType = convertTy(dstTy);
if (op.getSrc().getType().cast<mlir::cir::IntType>().isSigned())
rewriter.replaceOpWithNewOp<mlir::arith::SIToFPOp>(op, newDstType, src);
else
rewriter.replaceOpWithNewOp<mlir::arith::UIToFPOp>(op, newDstType, src);
return mlir::success();
}
case CIR::float_to_int: {
auto dstTy = op.getType();
auto newDstType = convertTy(dstTy);
if (op.getResult().getType().cast<mlir::cir::IntType>().isSigned())
rewriter.replaceOpWithNewOp<mlir::arith::FPToSIOp>(op, newDstType, src);
else
rewriter.replaceOpWithNewOp<mlir::arith::FPToUIOp>(op, newDstType, src);
return mlir::success();
}
default:
break;
}
return mlir::failure();
}
};

void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
Expand All @@ -718,7 +847,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
CIRGetGlobalOpLowering>(converter, patterns.getContext());
CIRGetGlobalOpLowering, CIRCastOpLowering>(
converter, patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down
147 changes: 147 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/cast.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM

!s32i = !cir.int<s, 32>
!s16i = !cir.int<s, 16>
!u32i = !cir.int<u, 32>
!u16i = !cir.int<u, 16>
!u8i = !cir.int<u, 8>
module {
// MLIR-LABEL: func.func @cast_int_to_bool(%arg0: i32) -> i8
// LLVM-LABEL: define i8 @cast_int_to_bool(i32 %0)
cir.func @cast_int_to_bool(%i : !u32i) -> !cir.bool {
// MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0 : i32
// MLIR-NEXT: arith.cmpi ne, %arg0, %[[ZERO]]
// LLVM-NEXT: icmp ne i32 %0, 0

%1 = cir.cast(int_to_bool, %i : !u32i), !cir.bool
cir.return %1 : !cir.bool
}
// MLIR-LABEL: func.func @cast_integral_trunc(%arg0: i32) -> i16
// LLVM-LABEL: define i16 @cast_integral_trunc(i32 %0)
cir.func @cast_integral_trunc(%i : !u32i) -> !u16i {
// MLIR-NEXT: arith.trunci %arg0 : i32 to i16
// LLVM-NEXT: trunc i32 %0 to i16

%1 = cir.cast(integral, %i : !u32i), !u16i
cir.return %1 : !u16i
}
// MLIR-LABEL: func.func @cast_integral_extu(%arg0: i16) -> i32
// LLVM-LABEL: define i32 @cast_integral_extu(i16 %0)
cir.func @cast_integral_extu(%i : !u16i) -> !u32i {
// MLIR-NEXT: arith.extui %arg0 : i16 to i32
// LLVM-NEXT: zext i16 %0 to i32

%1 = cir.cast(integral, %i : !u16i), !u32i
cir.return %1 : !u32i
}
// MLIR-LABEL: func.func @cast_integral_exts(%arg0: i16) -> i32
// LLVM-LABEL: define i32 @cast_integral_exts(i16 %0)
cir.func @cast_integral_exts(%i : !s16i) -> !s32i {
// MLIR-NEXT: arith.extsi %arg0 : i16 to i32
// LLVM-NEXT: sext i16 %0 to i32

%1 = cir.cast(integral, %i : !s16i), !s32i
cir.return %1 : !s32i
}
// MLIR-LABEL: func.func @cast_integral_same_size(%arg0: i32) -> i32
// LLVM-LABEL: define i32 @cast_integral_same_size(i32 %0)
cir.func @cast_integral_same_size(%i : !u32i) -> !s32i {
// MLIR-NEXT: %0 = arith.bitcast %arg0 : i32 to i32
// LLVM-NEXT: ret i32 %0

%1 = cir.cast(integral, %i : !u32i), !s32i
cir.return %1 : !s32i
}
// MLIR-LABEL: func.func @cast_floating_trunc(%arg0: f64) -> f32
// LLVM-LABEL: define float @cast_floating_trunc(double %0)
cir.func @cast_floating_trunc(%d : !cir.double) -> !cir.float {
// MLIR-NEXT: arith.truncf %arg0 : f64 to f32
// LLVM-NEXT: fptrunc double %0 to float

%1 = cir.cast(floating, %d : !cir.double), !cir.float
cir.return %1 : !cir.float
}
// MLIR-LABEL: func.func @cast_floating_extf(%arg0: f32) -> f64
// LLVM-LABEL: define double @cast_floating_extf(float %0)
cir.func @cast_floating_extf(%f : !cir.float) -> !cir.double {
// MLIR-NEXT: arith.extf %arg0 : f32 to f64
// LLVM-NEXT: fpext float %0 to double

%1 = cir.cast(floating, %f : !cir.float), !cir.double
cir.return %1 : !cir.double
}
// MLIR-LABEL: func.func @cast_float_to_bool(%arg0: f32) -> i8
// LLVM-LABEL: define i8 @cast_float_to_bool(float %0)
cir.func @cast_float_to_bool(%f : !cir.float) -> !cir.bool {
// MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// MLIR-NEXT: arith.cmpf une, %arg0, %[[ZERO]] : f32
// LLVM-NEXT: fcmp une float %0, 0.000000e+00

%1 = cir.cast(float_to_bool, %f : !cir.float), !cir.bool
cir.return %1 : !cir.bool
}
// MLIR-LABEL: func.func @cast_bool_to_int8(%arg0: i8) -> i8
// LLVM-LABEL: define i8 @cast_bool_to_int8(i8 %0)
cir.func @cast_bool_to_int8(%b : !cir.bool) -> !u8i {
// MLIR-NEXT: arith.bitcast %arg0 : i8 to i8
// LLVM-NEXT: ret i8 %0

%1 = cir.cast(bool_to_int, %b : !cir.bool), !u8i
cir.return %1 : !u8i
}
// MLIR-LABEL: func.func @cast_bool_to_int(%arg0: i8) -> i32
// LLVM-LABEL: define i32 @cast_bool_to_int(i8 %0)
cir.func @cast_bool_to_int(%b : !cir.bool) -> !u32i {
// MLIR-NEXT: arith.extui %arg0 : i8 to i32
// LLVM-NEXT: zext i8 %0 to i32

%1 = cir.cast(bool_to_int, %b : !cir.bool), !u32i
cir.return %1 : !u32i
}
// MLIR-LABEL: func.func @cast_bool_to_float(%arg0: i8) -> f32
// LLVM-LABEL: define float @cast_bool_to_float(i8 %0)
cir.func @cast_bool_to_float(%b : !cir.bool) -> !cir.float {
// MLIR-NEXT: arith.uitofp %arg0 : i8 to f32
// LLVM-NEXT: uitofp i8 %0 to float

%1 = cir.cast(bool_to_float, %b : !cir.bool), !cir.float
cir.return %1 : !cir.float
}
// MLIR-LABEL: func.func @cast_signed_int_to_float(%arg0: i32) -> f32
// LLVM-LABEL: define float @cast_signed_int_to_float(i32 %0)
cir.func @cast_signed_int_to_float(%i : !s32i) -> !cir.float {
// MLIR-NEXT: arith.sitofp %arg0 : i32 to f32
// LLVM-NEXT: sitofp i32 %0 to float

%1 = cir.cast(int_to_float, %i : !s32i), !cir.float
cir.return %1 : !cir.float
}
// MLIR-LABEL: func.func @cast_unsigned_int_to_float(%arg0: i32) -> f32
// LLVM-LABEL: define float @cast_unsigned_int_to_float(i32 %0)
cir.func @cast_unsigned_int_to_float(%i : !u32i) -> !cir.float {
// MLIR-NEXT: arith.uitofp %arg0 : i32 to f32
// LLVM-NEXT: uitofp i32 %0 to float

%1 = cir.cast(int_to_float, %i : !u32i), !cir.float
cir.return %1 : !cir.float
}
// MLIR-LABEL: func.func @cast_float_to_int_signed(%arg0: f32) -> i32
// LLVM-LABEL: define i32 @cast_float_to_int_signed(float %0)
cir.func @cast_float_to_int_signed(%f : !cir.float) -> !s32i {
// MLIR-NEXT: arith.fptosi %arg0 : f32 to i32
// LLVM-NEXT: fptosi float %0 to i32

%1 = cir.cast(float_to_int, %f : !cir.float), !s32i
cir.return %1 : !s32i
}
// MLIR-LABEL: func.func @cast_float_to_int_unsigned(%arg0: f32) -> i32
// LLVM-LABEL: define i32 @cast_float_to_int_unsigned(float %0)
cir.func @cast_float_to_int_unsigned(%f : !cir.float) -> !u32i {
// MLIR-NEXT: arith.fptoui %arg0 : f32 to i32
// LLVM-NEXT: fptoui float %0 to i32

%1 = cir.cast(float_to_int, %f : !cir.float), !u32i
cir.return %1 : !u32i
}
}

0 comments on commit 1efee91

Please sign in to comment.