From 8ab1895706c4be8897722db149667946ca8f3d38 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Sun, 23 Feb 2025 13:03:53 +0800 Subject: [PATCH] [CIR] Add support for comparisons between pointers to member functions The CIRGen support is already there. This patch adds LLVM lowering support for comparisons between pointers to member functions. Note that pointers to member functions could only be compared for equality. --- .../Transforms/TargetLowering/CIRCXXABI.h | 4 ++ .../TargetLowering/ItaniumCXXABI.cpp | 59 +++++++++++++++++++ .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 13 +++- .../CIR/CodeGen/pointer-to-member-func.cpp | 40 +++++++++++++ 4 files changed, 113 insertions(+), 3 deletions(-) diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h index 995fcd027919..c36923a41e35 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h @@ -122,6 +122,10 @@ class CIRCXXABI { mlir::Value loweredRhs, mlir::OpBuilder &builder) const = 0; + virtual mlir::Value lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const = 0; + virtual mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy, mlir::Value loweredSrc, diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp index 992cf88efaea..2819adb25b8c 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp @@ -103,6 +103,10 @@ class ItaniumCXXABI : public CIRCXXABI { mlir::Value loweredRhs, mlir::OpBuilder &builder) const override; + mlir::Value lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const override; + mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy, mlir::Value loweredSrc, mlir::OpBuilder &builder) const override; @@ -478,6 +482,61 @@ mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op, loweredRhs); } +mlir::Value ItaniumCXXABI::lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const { + assert(op.getKind() == cir::CmpOpKind::eq || + op.getKind() == cir::CmpOpKind::ne); + + cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(LM); + mlir::Value ptrdiffZero = builder.create( + op.getLoc(), ptrdiffCIRTy, cir::IntAttr::get(ptrdiffCIRTy, 0)); + + mlir::Value lhsPtrField = builder.create( + op.getLoc(), ptrdiffCIRTy, loweredLhs, 0); + mlir::Value rhsPtrField = builder.create( + op.getLoc(), ptrdiffCIRTy, loweredRhs, 0); + mlir::Value ptrCmp = builder.create(op.getLoc(), op.getKind(), + lhsPtrField, rhsPtrField); + mlir::Value ptrCmpToNull = builder.create( + op.getLoc(), op.getKind(), lhsPtrField, ptrdiffZero); + + mlir::Value lhsAdjField = builder.create( + op.getLoc(), ptrdiffCIRTy, loweredLhs, 1); + mlir::Value rhsAdjField = builder.create( + op.getLoc(), ptrdiffCIRTy, loweredRhs, 1); + mlir::Value adjCmp = builder.create(op.getLoc(), op.getKind(), + lhsAdjField, rhsAdjField); + + // We use cir.select to represent "||" and "&&" operations below: + // - cir.select if %a then %b else false => %a && %b + // - cir.select if %a then true else %b => %a || %b + // TODO: Do we need to invent dedicated "cir.logical_or" and "cir.logical_and" + // operations for this? + auto boolTy = cir::BoolType::get(op.getContext()); + mlir::Value trueValue = builder.create( + op.getLoc(), boolTy, cir::BoolAttr::get(op.getContext(), boolTy, true)); + mlir::Value falseValue = builder.create( + op.getLoc(), boolTy, cir::BoolAttr::get(op.getContext(), boolTy, false)); + auto create_and = [&](mlir::Value lhs, mlir::Value rhs) { + return builder.create(op.getLoc(), lhs, rhs, falseValue); + }; + auto create_or = [&](mlir::Value lhs, mlir::Value rhs) { + return builder.create(op.getLoc(), lhs, trueValue, rhs); + }; + + mlir::Value result; + if (op.getKind() == cir::CmpOpKind::eq) { + // (lhs.ptr == null || lhs.adj == rhs.adj) && lhs.ptr == rhs.ptr + result = create_and(create_or(ptrCmpToNull, adjCmp), ptrCmp); + } else { + // (lhs.ptr != null && lhs.adj != rhs.adj) || lhs.ptr != rhs.ptr + result = create_or(create_and(ptrCmpToNull, adjCmp), ptrCmp); + } + + return result; +} + mlir::Value ItaniumCXXABI::lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy, mlir::Value loweredSrc, diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index e1a2d49fc5d2..12ce2fb6eda6 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -2893,10 +2893,17 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( mlir::ConversionPatternRewriter &rewriter) const { auto type = cmpOp.getLhs().getType(); - if (mlir::isa(type)) { + if (mlir::isa(type)) { assert(lowerMod && "lowering module is not available"); - mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp( - cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter); + + mlir::Value loweredResult; + if (mlir::isa(type)) + loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp( + cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter); + else + loweredResult = lowerMod->getCXXABI().lowerMethodCmp( + cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter); + rewriter.replaceOp(cmpOp, loweredResult); return mlir::success(); } diff --git a/clang/test/CIR/CodeGen/pointer-to-member-func.cpp b/clang/test/CIR/CodeGen/pointer-to-member-func.cpp index a1a42f4d494c..5baf9c9bd23a 100644 --- a/clang/test/CIR/CodeGen/pointer-to-member-func.cpp +++ b/clang/test/CIR/CodeGen/pointer-to-member-func.cpp @@ -78,3 +78,43 @@ void call(Foo *obj, void (Foo::*func)(int), int arg) { // LLVM-NEXT: %[[#arg:]] = load i32, ptr %{{.+}} // LLVM-NEXT: call void %[[#callee_ptr]](ptr %[[#adjusted_this]], i32 %[[#arg]]) // LLVM: } + +bool cmp_eq(void (Foo::*lhs)(int), void (Foo::*rhs)(int)) { + return lhs == rhs; +} + +// CHECK-LABEL: @_Z6cmp_eqM3FooFviES1_ +// CHECK: %{{.+}} = cir.cmp(eq, %{{.+}}, %{{.+}}) : !cir.method in !ty_Foo>, !cir.bool + +// LLVM-LABEL: @_Z6cmp_eqM3FooFviES1_ +// LLVM: %[[#lhs:]] = load { i64, i64 }, ptr %{{.+}} +// LLVM-NEXT: %[[#rhs:]] = load { i64, i64 }, ptr %{{.+}} +// LLVM-NEXT: %[[#lhs_ptr:]] = extractvalue { i64, i64 } %[[#lhs]], 0 +// LLVM-NEXT: %[[#rhs_ptr:]] = extractvalue { i64, i64 } %[[#rhs]], 0 +// LLVM-NEXT: %[[#ptr_cmp:]] = icmp eq i64 %[[#lhs_ptr]], %[[#rhs_ptr]] +// LLVM-NEXT: %[[#ptr_null:]] = icmp eq i64 %[[#lhs_ptr]], 0 +// LLVM-NEXT: %[[#lhs_adj:]] = extractvalue { i64, i64 } %[[#lhs]], 1 +// LLVM-NEXT: %[[#rhs_adj:]] = extractvalue { i64, i64 } %[[#rhs]], 1 +// LLVM-NEXT: %[[#adj_cmp:]] = icmp eq i64 %[[#lhs_adj]], %[[#rhs_adj]] +// LLVM-NEXT: %[[#tmp:]] = or i1 %[[#ptr_null]], %[[#adj_cmp]] +// LLVM-NEXT: %{{.+}} = and i1 %[[#tmp]], %[[#ptr_cmp]] + +bool cmp_ne(void (Foo::*lhs)(int), void (Foo::*rhs)(int)) { + return lhs != rhs; +} + +// CHECK-LABEL: @_Z6cmp_neM3FooFviES1_ +// CHECK: %{{.+}} = cir.cmp(ne, %{{.+}}, %{{.+}}) : !cir.method in !ty_Foo>, !cir.bool + +// LLVM-LABEL: @_Z6cmp_neM3FooFviES1_ +// LLVM: %[[#lhs:]] = load { i64, i64 }, ptr %{{.+}} +// LLVM-NEXT: %[[#rhs:]] = load { i64, i64 }, ptr %{{.+}} +// LLVM-NEXT: %[[#lhs_ptr:]] = extractvalue { i64, i64 } %[[#lhs]], 0 +// LLVM-NEXT: %[[#rhs_ptr:]] = extractvalue { i64, i64 } %[[#rhs]], 0 +// LLVM-NEXT: %[[#ptr_cmp:]] = icmp ne i64 %[[#lhs_ptr]], %[[#rhs_ptr]] +// LLVM-NEXT: %[[#ptr_null:]] = icmp ne i64 %[[#lhs_ptr]], 0 +// LLVM-NEXT: %[[#lhs_adj:]] = extractvalue { i64, i64 } %[[#lhs]], 1 +// LLVM-NEXT: %[[#rhs_adj:]] = extractvalue { i64, i64 } %[[#rhs]], 1 +// LLVM-NEXT: %[[#adj_cmp:]] = icmp ne i64 %[[#lhs_adj]], %[[#rhs_adj]] +// LLVM-NEXT: %[[#tmp:]] = and i1 %[[#ptr_null]], %[[#adj_cmp]] +// LLVM-NEXT: %{{.+}} = or i1 %[[#tmp]], %[[#ptr_cmp]]