Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][ThroughMLIR] Support lowering CastOp to arith #577

Merged
merged 7 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 131 additions & 1 deletion clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,135 @@ class CIRGetGlobalOpLowering
}
};

static mlir::Value createIntCast(mlir::ConversionPatternRewriter &rewriter,
mlir::Value src, mlir::Type dstTy,
bool isSigned = false) {
auto srcTy = src.getType();
assert(isa<mlir::IntegerType>(srcTy));
assert(isa<mlir::IntegerType>(dstTy));

auto srcWidth = srcTy.cast<mlir::IntegerType>().getWidth();
auto dstWidth = dstTy.cast<mlir::IntegerType>().getWidth();
auto loc = src.getLoc();

if (dstWidth > srcWidth && isSigned)
return rewriter.create<mlir::arith::ExtSIOp>(loc, dstTy, src);
else if (dstWidth > srcWidth)
return rewriter.create<mlir::arith::ExtUIOp>(loc, dstTy, src);
else if (dstWidth < srcWidth)
return rewriter.create<mlir::arith::TruncIOp>(loc, dstTy, src);
else
return rewriter.create<mlir::arith::BitcastOp>(loc, dstTy, src);
}

class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
public:
using OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern;

inline mlir::Type convertTy(mlir::Type ty) const {
return getTypeConverter()->convertType(ty);
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::CastOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
if (isa<mlir::cir::VectorType>(op.getSrc().getType()))
llvm_unreachable("CastOp lowering for vector type is not supported yet");
auto src = adaptor.getSrc();
auto dstType = op.getResult().getType();
using CIR = mlir::cir::CastKind;
switch (op.getKind()) {
case CIR::int_to_bool: {
auto zero = rewriter.create<mlir::cir::ConstantOp>(
src.getLoc(), op.getSrc().getType(),
mlir::cir::IntAttr::get(op.getSrc().getType(), 0));
rewriter.replaceOpWithNewOp<mlir::cir::CmpOp>(
op, mlir::cir::BoolType::get(getContext()), mlir::cir::CmpOpKind::ne,
op.getSrc(), zero);
return mlir::success();
}
case CIR::integral: {
auto newDstType = convertTy(dstType);
auto srcType = op.getSrc().getType();
mlir::cir::IntType srcIntType = srcType.cast<mlir::cir::IntType>();
auto newOp =
createIntCast(rewriter, src, newDstType, srcIntType.isSigned());
rewriter.replaceOp(op, newOp);
return mlir::success();
}
case CIR::floating: {
auto newDstType = convertTy(dstType);
auto srcTy = op.getSrc().getType();
auto dstTy = op.getResult().getType();

if (!dstTy.isa<mlir::cir::CIRFPTypeInterface>() ||
!srcTy.isa<mlir::cir::CIRFPTypeInterface>())
return op.emitError() << "NYI cast from " << srcTy << " to " << dstTy;

auto getFloatWidth = [](mlir::Type ty) -> unsigned {
return ty.cast<mlir::cir::CIRFPTypeInterface>().getWidth();
};

if (getFloatWidth(srcTy) > getFloatWidth(dstTy))
rewriter.replaceOpWithNewOp<mlir::arith::TruncFOp>(op, newDstType, src);
else
rewriter.replaceOpWithNewOp<mlir::arith::ExtFOp>(op, newDstType, src);
return mlir::success();
}
case CIR::float_to_bool: {
auto dstTy = op.getType().cast<mlir::cir::BoolType>();
auto newDstType = convertTy(dstTy);
auto kind = mlir::arith::CmpFPredicate::UNE;

// Check if float is not equal to zero.
auto zeroFloat = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), src.getType(), mlir::FloatAttr::get(src.getType(), 0.0));

// Extend comparison result to either bool (C++) or int (C).
mlir::Value cmpResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), kind, src, zeroFloat);
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, newDstType,
cmpResult);
return mlir::success();
}
case CIR::bool_to_int: {
auto dstTy = op.getType().cast<mlir::cir::IntType>();
auto newDstType = convertTy(dstTy).cast<mlir::IntegerType>();
auto newOp = createIntCast(rewriter, src, newDstType);
rewriter.replaceOp(op, newOp);
return mlir::success();
}
case CIR::bool_to_float: {
auto dstTy = op.getType();
auto newDstType = convertTy(dstTy);
rewriter.replaceOpWithNewOp<mlir::arith::UIToFPOp>(op, newDstType, src);
return mlir::success();
}
case CIR::int_to_float: {
auto dstTy = op.getType();
auto newDstType = convertTy(dstTy);
if (op.getSrc().getType().cast<mlir::cir::IntType>().isSigned())
rewriter.replaceOpWithNewOp<mlir::arith::SIToFPOp>(op, newDstType, src);
else
rewriter.replaceOpWithNewOp<mlir::arith::UIToFPOp>(op, newDstType, src);
return mlir::success();
}
case CIR::float_to_int: {
auto dstTy = op.getType();
auto newDstType = convertTy(dstTy);
if (op.getResult().getType().cast<mlir::cir::IntType>().isSigned())
rewriter.replaceOpWithNewOp<mlir::arith::FPToSIOp>(op, newDstType, src);
else
rewriter.replaceOpWithNewOp<mlir::arith::FPToUIOp>(op, newDstType, src);
return mlir::success();
}
default:
break;
}
return mlir::failure();
}
};

void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
Expand All @@ -718,7 +847,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
CIRGetGlobalOpLowering>(converter, patterns.getContext());
CIRGetGlobalOpLowering, CIRCastOpLowering>(
converter, patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down
147 changes: 147 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/cast.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM

!s32i = !cir.int<s, 32>
!s16i = !cir.int<s, 16>
!u32i = !cir.int<u, 32>
!u16i = !cir.int<u, 16>
!u8i = !cir.int<u, 8>
module {
// MLIR-LABEL: func.func @cast_int_to_bool(%arg0: i32) -> i8
// LLVM-LABEL: define i8 @cast_int_to_bool(i32 %0)
cir.func @cast_int_to_bool(%i : !u32i) -> !cir.bool {
// MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0 : i32
// MLIR-NEXT: arith.cmpi ne, %arg0, %[[ZERO]]
// LLVM-NEXT: icmp ne i32 %0, 0

%1 = cir.cast(int_to_bool, %i : !u32i), !cir.bool
cir.return %1 : !cir.bool
}
// MLIR-LABEL: func.func @cast_integral_trunc(%arg0: i32) -> i16
// LLVM-LABEL: define i16 @cast_integral_trunc(i32 %0)
cir.func @cast_integral_trunc(%i : !u32i) -> !u16i {
// MLIR-NEXT: arith.trunci %arg0 : i32 to i16
// LLVM-NEXT: trunc i32 %0 to i16

%1 = cir.cast(integral, %i : !u32i), !u16i
cir.return %1 : !u16i
}
// MLIR-LABEL: func.func @cast_integral_extu(%arg0: i16) -> i32
// LLVM-LABEL: define i32 @cast_integral_extu(i16 %0)
cir.func @cast_integral_extu(%i : !u16i) -> !u32i {
// MLIR-NEXT: arith.extui %arg0 : i16 to i32
// LLVM-NEXT: zext i16 %0 to i32

%1 = cir.cast(integral, %i : !u16i), !u32i
cir.return %1 : !u32i
}
// MLIR-LABEL: func.func @cast_integral_exts(%arg0: i16) -> i32
// LLVM-LABEL: define i32 @cast_integral_exts(i16 %0)
cir.func @cast_integral_exts(%i : !s16i) -> !s32i {
// MLIR-NEXT: arith.extsi %arg0 : i16 to i32
// LLVM-NEXT: sext i16 %0 to i32

%1 = cir.cast(integral, %i : !s16i), !s32i
cir.return %1 : !s32i
}
// MLIR-LABEL: func.func @cast_integral_same_size(%arg0: i32) -> i32
// LLVM-LABEL: define i32 @cast_integral_same_size(i32 %0)
cir.func @cast_integral_same_size(%i : !u32i) -> !s32i {
// MLIR-NEXT: %0 = arith.bitcast %arg0 : i32 to i32
// LLVM-NEXT: ret i32 %0

%1 = cir.cast(integral, %i : !u32i), !s32i
cir.return %1 : !s32i
}
// MLIR-LABEL: func.func @cast_floating_trunc(%arg0: f64) -> f32
// LLVM-LABEL: define float @cast_floating_trunc(double %0)
cir.func @cast_floating_trunc(%d : !cir.double) -> !cir.float {
// MLIR-NEXT: arith.truncf %arg0 : f64 to f32
// LLVM-NEXT: fptrunc double %0 to float

%1 = cir.cast(floating, %d : !cir.double), !cir.float
cir.return %1 : !cir.float
}
// MLIR-LABEL: func.func @cast_floating_extf(%arg0: f32) -> f64
// LLVM-LABEL: define double @cast_floating_extf(float %0)
cir.func @cast_floating_extf(%f : !cir.float) -> !cir.double {
// MLIR-NEXT: arith.extf %arg0 : f32 to f64
// LLVM-NEXT: fpext float %0 to double

%1 = cir.cast(floating, %f : !cir.float), !cir.double
cir.return %1 : !cir.double
}
// MLIR-LABEL: func.func @cast_float_to_bool(%arg0: f32) -> i8
// LLVM-LABEL: define i8 @cast_float_to_bool(float %0)
cir.func @cast_float_to_bool(%f : !cir.float) -> !cir.bool {
// MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// MLIR-NEXT: arith.cmpf une, %arg0, %[[ZERO]] : f32
// LLVM-NEXT: fcmp une float %0, 0.000000e+00

%1 = cir.cast(float_to_bool, %f : !cir.float), !cir.bool
cir.return %1 : !cir.bool
}
// MLIR-LABEL: func.func @cast_bool_to_int8(%arg0: i8) -> i8
// LLVM-LABEL: define i8 @cast_bool_to_int8(i8 %0)
cir.func @cast_bool_to_int8(%b : !cir.bool) -> !u8i {
// MLIR-NEXT: arith.bitcast %arg0 : i8 to i8
// LLVM-NEXT: ret i8 %0

%1 = cir.cast(bool_to_int, %b : !cir.bool), !u8i
cir.return %1 : !u8i
}
// MLIR-LABEL: func.func @cast_bool_to_int(%arg0: i8) -> i32
// LLVM-LABEL: define i32 @cast_bool_to_int(i8 %0)
cir.func @cast_bool_to_int(%b : !cir.bool) -> !u32i {
// MLIR-NEXT: arith.extui %arg0 : i8 to i32
// LLVM-NEXT: zext i8 %0 to i32

%1 = cir.cast(bool_to_int, %b : !cir.bool), !u32i
cir.return %1 : !u32i
}
// MLIR-LABEL: func.func @cast_bool_to_float(%arg0: i8) -> f32
// LLVM-LABEL: define float @cast_bool_to_float(i8 %0)
cir.func @cast_bool_to_float(%b : !cir.bool) -> !cir.float {
// MLIR-NEXT: arith.uitofp %arg0 : i8 to f32
// LLVM-NEXT: uitofp i8 %0 to float

%1 = cir.cast(bool_to_float, %b : !cir.bool), !cir.float
cir.return %1 : !cir.float
}
// MLIR-LABEL: func.func @cast_signed_int_to_float(%arg0: i32) -> f32
// LLVM-LABEL: define float @cast_signed_int_to_float(i32 %0)
cir.func @cast_signed_int_to_float(%i : !s32i) -> !cir.float {
// MLIR-NEXT: arith.sitofp %arg0 : i32 to f32
// LLVM-NEXT: sitofp i32 %0 to float

%1 = cir.cast(int_to_float, %i : !s32i), !cir.float
cir.return %1 : !cir.float
}
// MLIR-LABEL: func.func @cast_unsigned_int_to_float(%arg0: i32) -> f32
// LLVM-LABEL: define float @cast_unsigned_int_to_float(i32 %0)
cir.func @cast_unsigned_int_to_float(%i : !u32i) -> !cir.float {
// MLIR-NEXT: arith.uitofp %arg0 : i32 to f32
// LLVM-NEXT: uitofp i32 %0 to float

%1 = cir.cast(int_to_float, %i : !u32i), !cir.float
cir.return %1 : !cir.float
}
// MLIR-LABEL: func.func @cast_float_to_int_signed(%arg0: f32) -> i32
// LLVM-LABEL: define i32 @cast_float_to_int_signed(float %0)
cir.func @cast_float_to_int_signed(%f : !cir.float) -> !s32i {
// MLIR-NEXT: arith.fptosi %arg0 : f32 to i32
// LLVM-NEXT: fptosi float %0 to i32

%1 = cir.cast(float_to_int, %f : !cir.float), !s32i
cir.return %1 : !s32i
}
// MLIR-LABEL: func.func @cast_float_to_int_unsigned(%arg0: f32) -> i32
// LLVM-LABEL: define i32 @cast_float_to_int_unsigned(float %0)
cir.func @cast_float_to_int_unsigned(%f : !cir.float) -> !u32i {
// MLIR-NEXT: arith.fptoui %arg0 : f32 to i32
// LLVM-NEXT: fptoui float %0 to i32

%1 = cir.cast(float_to_int, %f : !cir.float), !u32i
cir.return %1 : !u32i
}
}
Loading