Skip to content

Commit

Permalink
[CIR] Data member pointer comparison and casts (llvm#1268)
Browse files Browse the repository at this point in the history
This PR adds CIRGen and LLVM lowering support for the following language
features related to pointers to data members:

  - Comparisons between pointers to data members.
  - Casting from pointers to data members to boolean.
  - Reinterpret casts between pointers to data members.
  • Loading branch information
Lancern authored Jan 15, 2025
1 parent 1029b19 commit 0ae5000
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 16 deletions.
3 changes: 2 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def CK_FloatComplexToIntegralComplex
def CK_IntegralComplexCast : I32EnumAttrCase<"int_complex", 23>;
def CK_IntegralComplexToFloatComplex
: I32EnumAttrCase<"int_complex_to_float_complex", 24>;
def CK_MemberPtrToBoolean : I32EnumAttrCase<"member_ptr_to_bool", 25>;

def CastKind : I32EnumAttr<
"CastKind",
Expand All @@ -135,7 +136,7 @@ def CastKind : I32EnumAttr<
CK_FloatComplexToReal, CK_IntegralComplexToReal, CK_FloatComplexToBoolean,
CK_IntegralComplexToBoolean, CK_FloatComplexCast,
CK_FloatComplexToIntegralComplex, CK_IntegralComplexCast,
CK_IntegralComplexToFloatComplex]> {
CK_IntegralComplexToFloatComplex, CK_MemberPtrToBoolean]> {
let cppNamespace = "::cir";
}

Expand Down
22 changes: 17 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,12 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
};

if (const MemberPointerType *MPT = LHSTy->getAs<MemberPointerType>()) {
assert(0 && "not implemented");
assert(E->getOpcode() == BO_EQ || E->getOpcode() == BO_NE);
mlir::Value lhs = CGF.emitScalarExpr(E->getLHS());
mlir::Value rhs = CGF.emitScalarExpr(E->getRHS());
cir::CmpOpKind kind = ClangCmpToCIRCmp(E->getOpcode());
Result =
Builder.createCompare(CGF.getLoc(E->getExprLoc()), kind, lhs, rhs);
} else if (!LHSTy->isAnyComplexType() && !RHSTy->isAnyComplexType()) {
BinOpInfo BOInfo = emitBinOps(E);
mlir::Value LHS = BOInfo.LHS;
Expand Down Expand Up @@ -1741,8 +1746,11 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
auto Ty = mlir::cast<cir::DataMemberType>(CGF.convertType(DestTy));
return Builder.getNullDataMemberPtr(Ty, CGF.getLoc(E->getExprLoc()));
}
case CK_ReinterpretMemberPointer:
llvm_unreachable("NYI");
case CK_ReinterpretMemberPointer: {
mlir::Value src = Visit(E);
return Builder.createBitcast(CGF.getLoc(E->getExprLoc()), src,
CGF.convertType(DestTy));
}
case CK_BaseToDerivedMemberPointer:
case CK_DerivedToBaseMemberPointer: {
mlir::Value src = Visit(E);
Expand Down Expand Up @@ -1875,8 +1883,12 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
return emitPointerToBoolConversion(Visit(E), E->getType());
case CK_FloatingToBoolean:
return emitFloatToBoolConversion(Visit(E), CGF.getLoc(E->getExprLoc()));
case CK_MemberPointerToBoolean:
llvm_unreachable("NYI");
case CK_MemberPointerToBoolean: {
mlir::Value memPtr = Visit(E);
return Builder.createCast(CGF.getLoc(CE->getSourceRange()),
cir::CastKind::member_ptr_to_bool, memPtr,
CGF.convertType(DestTy));
}
case CK_FloatingComplexToReal:
case CK_IntegralComplexToReal:
case CK_FloatingComplexToBoolean:
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,11 @@ LogicalResult cir::CastOp::verify() {
return success();
}

// Handle the data member pointer types.
if (mlir::isa<cir::DataMemberType>(srcType) &&
mlir::isa<cir::DataMemberType>(resType))
return success();

// This is the only cast kind where we don't want vector types to decay
// into the element type.
if ((!mlir::isa<cir::VectorType>(getSrc().getType()) ||
Expand Down Expand Up @@ -705,6 +710,13 @@ LogicalResult cir::CastOp::verify() {
<< "requires !cir.complex<!cir.float> type for result";
return success();
}
case cir::CastKind::member_ptr_to_bool: {
if (!mlir::isa<cir::DataMemberType>(srcType))
return emitOpError() << "requires !cir.data_member type for source";
if (!mlir::isa<cir::BoolType>(resType))
return emitOpError() << "requires !cir.bool type for result";
return success();
}
}

llvm_unreachable("Unknown CastOp kind?");
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ class CIRCXXABI {
virtual mlir::Value
lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;

virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
mlir::Value loweredRhs,
mlir::OpBuilder &builder) const = 0;

virtual mlir::Value
lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;

virtual mlir::Value
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;
};

/// Creates an Itanium-family ABI.
Expand Down
52 changes: 48 additions & 4 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ class ItaniumCXXABI : public CIRCXXABI {
mlir::Value lowerDerivedDataMember(cir::DerivedDataMemberOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;

mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
mlir::Value loweredRhs,
mlir::OpBuilder &builder) const override;

mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;

mlir::Value
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;
};

} // namespace
Expand All @@ -89,18 +101,23 @@ bool ItaniumCXXABI::classifyReturnType(LowerFunctionInfo &FI) const {
return false;
}

mlir::Type ItaniumCXXABI::lowerDataMemberType(
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
static mlir::Type getABITypeForDataMember(LowerModule &lowerMod) {
// Itanium C++ ABI 2.3:
// A pointer to data member is an offset from the base address of
// the class object containing it, represented as a ptrdiff_t
const clang::TargetInfo &target = LM.getTarget();
const clang::TargetInfo &target = lowerMod.getTarget();
clang::TargetInfo::IntType ptrdiffTy =
target.getPtrDiffType(clang::LangAS::Default);
return cir::IntType::get(type.getContext(), target.getTypeWidth(ptrdiffTy),
return cir::IntType::get(lowerMod.getMLIRContext(),
target.getTypeWidth(ptrdiffTy),
target.isTypeSigned(ptrdiffTy));
}

mlir::Type ItaniumCXXABI::lowerDataMemberType(
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
return getABITypeForDataMember(LM);
}

mlir::TypedAttr ItaniumCXXABI::lowerDataMemberConstant(
cir::DataMemberAttr attr, const mlir::DataLayout &layout,
const mlir::TypeConverter &typeConverter) const {
Expand Down Expand Up @@ -175,6 +192,33 @@ ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op,
/*isDerivedToBase=*/false, builder);
}

mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op,
mlir::Value loweredLhs,
mlir::Value loweredRhs,
mlir::OpBuilder &builder) const {
return builder.create<cir::CmpOp>(op.getLoc(), op.getKind(), loweredLhs,
loweredRhs);
}

mlir::Value
ItaniumCXXABI::lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const {
return builder.create<cir::CastOp>(op.getLoc(), loweredDstTy,
cir::CastKind::bitcast, loweredSrc);
}

mlir::Value
ItaniumCXXABI::lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const {
// Itanium C++ ABI 2.3:
// A NULL pointer is represented as -1.
auto nullAttr = cir::IntAttr::get(getABITypeForDataMember(LM), -1);
auto nullValue = builder.create<cir::ConstantOp>(op.getLoc(), nullAttr);
return builder.create<cir::CmpOp>(op.getLoc(), cir::CmpOpKind::ne, loweredSrc,
nullValue);
}

CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) {
switch (LM.getCXXABIKind()) {
// Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't
Expand Down
34 changes: 31 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1299,8 +1299,18 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
}
case cir::CastKind::bitcast: {
auto dstTy = castOp.getType();
auto llvmSrcVal = adaptor.getOperands().front();
auto llvmDstTy = getTypeConverter()->convertType(dstTy);

if (mlir::isa<cir::DataMemberType>(castOp.getSrc().getType())) {
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberBitcast(
castOp, llvmDstTy, src, rewriter);
rewriter.replaceOp(castOp, loweredResult);
return mlir::success();
}
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
llvm_unreachable("NYI");

auto llvmSrcVal = adaptor.getOperands().front();
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
llvmSrcVal);
return mlir::success();
Expand All @@ -1324,6 +1334,16 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
llvmSrcVal);
break;
}
case cir::CastKind::member_ptr_to_bool: {
mlir::Value loweredResult;
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
llvm_unreachable("NYI");
else
loweredResult = lowerMod->getCXXABI().lowerDataMemberToBoolCast(
castOp, src, rewriter);
rewriter.replaceOp(castOp, loweredResult);
break;
}
default: {
return castOp.emitError("Unhandled cast kind: ")
<< castOp.getKindAttrName();
Expand Down Expand Up @@ -2902,6 +2922,14 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
mlir::ConversionPatternRewriter &rewriter) const {
auto type = cmpOp.getLhs().getType();

if (mlir::isa<cir::DataMemberType>(type)) {
assert(lowerMod && "lowering module is not available");
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp(
cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter);
rewriter.replaceOp(cmpOp, loweredResult);
return mlir::success();
}

// Lower to LLVM comparison op.
// if (auto intTy = mlir::dyn_cast<cir::IntType>(type)) {
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
Expand Down Expand Up @@ -4087,6 +4115,7 @@ void populateCIRToLLVMConversionPatterns(
argsVarMap, patterns.getContext());
patterns.add<
// clang-format off
CIRToLLVMCastOpLowering,
CIRToLLVMLoadOpLowering,
CIRToLLVMStoreOpLowering,
CIRToLLVMGlobalOpLowering,
Expand All @@ -4096,14 +4125,14 @@ void populateCIRToLLVMConversionPatterns(
patterns.add<
// clang-format off
CIRToLLVMBaseDataMemberOpLowering,
CIRToLLVMCmpOpLowering,
CIRToLLVMDerivedDataMemberOpLowering,
CIRToLLVMGetRuntimeMemberOpLowering
// clang-format on
>(converter, patterns.getContext(), lowerModule);
patterns.add<
// clang-format off
CIRToLLVMPtrStrideOpLowering,
CIRToLLVMCastOpLowering,
CIRToLLVMInlineAsmOpLowering
// clang-format on
>(converter, patterns.getContext(), dataLayout);
Expand Down Expand Up @@ -4132,7 +4161,6 @@ void populateCIRToLLVMConversionPatterns(
CIRToLLVMCallOpLowering,
CIRToLLVMCatchParamOpLowering,
CIRToLLVMClearCacheOpLowering,
CIRToLLVMCmpOpLowering,
CIRToLLVMCmpThreeWayOpLowering,
CIRToLLVMComplexCreateOpLowering,
CIRToLLVMComplexImagOpLowering,
Expand Down
15 changes: 12 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,18 @@ class CIRToLLVMBrCondOpLowering
};

class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
cir::LowerModule *lowerMod;
mlir::DataLayout const &dataLayout;

mlir::Type convertTy(mlir::Type ty) const;

public:
CIRToLLVMCastOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule,
mlir::DataLayout const &dataLayout)
: OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {}
using mlir::OpConversionPattern<cir::CastOp>::OpConversionPattern;
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule),
dataLayout(dataLayout) {}

mlir::LogicalResult
matchAndRewrite(cir::CastOp op, OpAdaptor,
Expand Down Expand Up @@ -649,8 +651,15 @@ class CIRToLLVMShiftOpLowering
};

class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
cir::LowerModule *lowerMod;

public:
using mlir::OpConversionPattern<cir::CmpOp>::OpConversionPattern;
CIRToLLVMCmpOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
setHasBoundedRewriteRecursion();
}

mlir::LogicalResult
matchAndRewrite(cir::CmpOp op, OpAdaptor,
Expand Down
26 changes: 26 additions & 0 deletions clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,29 @@ auto derived_to_base_zero_offset(int Derived::*ptr) -> int Base1::* {
// LLVM-NEXT: %[[#ret:]] = load i64, ptr %[[#ret_slot]]
// LLVM-NEXT: ret i64 %[[#ret]]
}

struct Foo {
int a;
};

struct Bar {
int a;
};

bool to_bool(int Foo::*x) {
return x;
}

// CIR-LABEL: @_Z7to_boolM3Fooi
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cast(member_ptr_to_bool, %[[#x]] : !cir.data_member<!s32i in !ty_Foo>), !cir.bool
// CIR: }

auto bitcast(int Foo::*x) {
return reinterpret_cast<int Bar::*>(x);
}

// CIR-LABEL: @_Z7bitcastM3Fooi
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cast(bitcast, %[[#x]] : !cir.data_member<!s32i in !ty_Foo>), !cir.data_member<!s32i in !ty_Bar>
// CIR: }
44 changes: 44 additions & 0 deletions clang/test/CIR/CodeGen/pointer-to-data-member-cmp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir --check-prefix=CIR %s
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s

struct Foo {
int a;
};

struct Bar {
int a;
};

bool eq(int Foo::*x, int Foo::*y) {
return x == y;
}

// CIR-LABEL: @_Z2eqM3FooiS0_
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cmp(eq, %[[#x]], %[[#y]]) : !cir.data_member<!s32i in !ty_Foo>, !cir.bool
// CIR: }

// LLVM-LABEL: @_Z2eqM3FooiS0_
// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %{{.+}} = icmp eq i64 %[[#x]], %[[#y]]
// LLVM: }

bool ne(int Foo::*x, int Foo::*y) {
return x != y;
}

// CIR-LABEL: @_Z2neM3FooiS0_
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cmp(ne, %[[#x]], %[[#y]]) : !cir.data_member<!s32i in !ty_Foo>, !cir.bool
// CIR: }

// LLVM-LABEL: @_Z2neM3FooiS0_
// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %{{.+}} = icmp ne i64 %[[#x]], %[[#y]]
// LLVM: }

0 comments on commit 0ae5000

Please sign in to comment.