Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CIRGen][LowerToLLVM] Support address space casting #652

Merged
merged 14 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Dialect/IR/FPEnv.h"
#include "clang/CIR/MissingFeatures.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -328,6 +329,15 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return createBitcast(src, getPointerTo(newPointeeTy));
}

mlir::Value createAddrSpaceCast(mlir::Location loc, mlir::Value src,
mlir::Type newTy) {
return createCast(loc, mlir::cir::CastKind::address_space, src, newTy);
}

mlir::Value createAddrSpaceCast(mlir::Value src, mlir::Type newTy) {
return createAddrSpaceCast(src.getLoc(), src, newTy);
}

mlir::Value createPtrIsNull(mlir::Value ptr) {
return createNot(createPtrToBoolCast(ptr));
}
Expand Down Expand Up @@ -391,6 +401,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {

// Creates constant nullptr for pointer type ty.
mlir::cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc) {
assert(!MissingFeatures::targetCodeGenInfoGetNullPointer());
return create<mlir::cir::ConstantOp>(loc, ty, getConstPtrAttr(ty, 0));
}

Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,16 @@ def CK_FloatToBoolean : I32EnumAttrCase<"float_to_bool", 10>;
def CK_BooleanToIntegral : I32EnumAttrCase<"bool_to_int", 11>;
def CK_IntegralToFloat : I32EnumAttrCase<"int_to_float", 12>;
def CK_BooleanToFloat : I32EnumAttrCase<"bool_to_float", 13>;
def CK_AddressSpaceConversion : I32EnumAttrCase<"address_space", 14>;

def CastKind : I32EnumAttr<
"CastKind",
"cast kind",
[CK_IntegralToBoolean, CK_ArrayToPointerDecay, CK_IntegralCast,
CK_BitCast, CK_FloatingCast, CK_PtrToBoolean, CK_FloatToIntegral,
CK_IntegralToPointer, CK_PointerToIntegral, CK_FloatToBoolean,
CK_BooleanToIntegral, CK_IntegralToFloat, CK_BooleanToFloat]> {
CK_BooleanToIntegral, CK_IntegralToFloat, CK_BooleanToFloat,
CK_AddressSpaceConversion]> {
let cppNamespace = "::mlir::cir";
}

Expand All @@ -98,6 +100,7 @@ def CastOp : CIR_Op<"cast", [Pure]> {
- `ptr_to_bool`
- `bool_to_int`
- `bool_to_float`
- `address_space`

This is effectively a subset of the rules from
`llvm-project/clang/include/clang/AST/OperationKinds.def`; but note that some
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ struct MissingFeatures {
static bool checkFunctionCallABI() { return false; }
static bool zeroInitializer() { return false; }
static bool targetCodeGenInfoIsProtoCallVariadic() { return false; }
static bool targetCodeGenInfoGetNullPointer() { return false; }
static bool chainCalls() { return false; }
static bool operandBundles() { return false; }
static bool exceptions() { return false; }
Expand Down
11 changes: 10 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "CIRGenModule.h"
#include "CIRGenOpenMPRuntime.h"
#include "CIRGenValue.h"
#include "TargetInfo.h"

#include "clang/AST/ExprCXX.h"
#include "clang/AST/GlobalDecl.h"
Expand Down Expand Up @@ -1764,7 +1765,15 @@ LValue CIRGenFunction::buildCastLValue(const CastExpr *E) {
assert(0 && "NYI");
}
case CK_AddressSpaceConversion: {
assert(0 && "NYI");
LValue LV = buildLValue(E->getSubExpr());
QualType DestTy = getContext().getPointerType(E->getType());
mlir::Value V = getTargetHooks().performAddrSpaceCast(
*this, LV.getPointer(), E->getSubExpr()->getType().getAddressSpace(),
E->getType().getAddressSpace(), ConvertType(DestTy));
assert(!MissingFeatures::tbaa());
return makeAddrLValue(Address(V, getTypes().convertTypeForMem(E->getType()),
LV.getAddress().getAlignment()),
E->getType(), LV.getBaseInfo());
}
case CK_ObjCObjectLValueCast: {
assert(0 && "NYI");
Expand Down
21 changes: 19 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "CIRGenFunction.h"
#include "CIRGenModule.h"
#include "CIRGenOpenMPRuntime.h"
#include "TargetInfo.h"
#include "clang/CIR/MissingFeatures.h"

#include "clang/AST/StmtVisitor.h"
Expand Down Expand Up @@ -1510,8 +1511,24 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
return CGF.getBuilder().createBitcast(CGF.getLoc(E->getSourceRange()), Src,
DstTy);
}
case CK_AddressSpaceConversion:
llvm_unreachable("NYI");
case CK_AddressSpaceConversion: {
Expr::EvalResult Result;
if (E->EvaluateAsRValue(Result, CGF.getContext()) &&
Result.Val.isNullPointer()) {
// If E has side effect, it is emitted even if its final result is a
// null pointer. In that case, a DCE pass should be able to
// eliminate the useless instructions emitted during translating E.
if (Result.HasSideEffects) {
llvm_unreachable("NYI");
}
return CGF.CGM.buildNullConstant(DestTy, CGF.getLoc(E->getExprLoc()));
}
// Since target may map different address spaces in AST to the same address
// space, an address space conversion may end up as a bitcast.
return CGF.CGM.getTargetCIRGenInfo().performAddrSpaceCast(
CGF, Visit(E), E->getType()->getPointeeType().getAddressSpace(),
DestTy->getPointeeType().getAddressSpace(), ConvertType(DestTy));
}
case CK_AtomicToNonAtomic:
llvm_unreachable("NYI");
case CK_NonAtomicToAtomic:
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/CIR/CodeGen/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,17 @@ ABIArgInfo X86_64ABIInfo::classifyReturnType(QualType RetTy) const {
return ABIArgInfo::getDirect(ResType);
}

mlir::Value TargetCIRGenInfo::performAddrSpaceCast(
CIRGenFunction &CGF, mlir::Value Src, clang::LangAS SrcAddr,
clang::LangAS DestAddr, mlir::Type DestTy, bool IsNonNull) const {
// Since target may map different address spaces in AST to the same address
// space, an address space conversion may end up as a bitcast.
if (auto globalOp = Src.getDefiningOp<mlir::cir::GlobalOp>())
llvm_unreachable("Global ops addrspace cast NYI");
// Try to preserve the source's name to make IR more readable.
return CGF.getBuilder().createAddrSpaceCast(Src, DestTy);
}

const TargetCIRGenInfo &CIRGenModule::getTargetCIRGenInfo() {
if (TheTargetCIRGenInfo)
return *TheTargetCIRGenInfo;
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CIR/CodeGen/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
namespace cir {

class CIRGenFunction;
class CIRGenModule;

/// This class organizes various target-specific codegeneration issues, like
/// target-specific attributes, builtins and so on.
Expand Down Expand Up @@ -65,6 +66,18 @@ class TargetCIRGenInfo {
return clang::LangAS::Default;
}

/// Perform address space cast of an expression of pointer type.
/// \param V is the value to be casted to another address space.
/// \param SrcAddr is the language address space of \p V.
/// \param DestAddr is the targeted language address space.
/// \param DestTy is the destination pointer type.
/// \param IsNonNull is the flag indicating \p V is known to be non null.
virtual mlir::Value performAddrSpaceCast(CIRGenFunction &CGF, mlir::Value V,
clang::LangAS SrcAddr,
clang::LangAS DestAddr,
mlir::Type DestTy,
bool IsNonNull = false) const;

virtual ~TargetCIRGenInfo() {}
};

Expand Down
12 changes: 11 additions & 1 deletion clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,15 @@ LogicalResult CastOp::verify() {
return emitOpError() << "requires !cir.float type for result";
return success();
}
case cir::CastKind::address_space: {
auto srcPtrTy = srcType.dyn_cast<mlir::cir::PointerType>();
auto resPtrTy = resType.dyn_cast<mlir::cir::PointerType>();
if (!srcPtrTy || !resPtrTy)
return emitOpError() << "requires !cir.ptr type for source and result";
if (srcPtrTy.getPointee() != resPtrTy.getPointee())
return emitOpError() << "requires two types differ in addrspace only";
return success();
}
}

llvm_unreachable("Unknown CastOp kind?");
Expand All @@ -514,7 +523,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
return foldResults[0].get<mlir::Attribute>();
return {};
}
case mlir::cir::CastKind::bitcast: {
case mlir::cir::CastKind::bitcast:
case mlir::cir::CastKind::address_space: {
return getSrc();
}
default:
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,14 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
mlir::cir::CmpOpKind::ne, castOp.getSrc(), null);
break;
}
case mlir::cir::CastKind::address_space: {
auto dstTy = castOp.getType();
auto llvmSrcVal = adaptor.getOperands().front();
auto llvmDstTy = getTypeConverter()->convertType(dstTy);
rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
castOp, llvmDstTy, llvmSrcVal);
break;
}
}

return mlir::success();
Expand Down
57 changes: 57 additions & 0 deletions clang/test/CIR/CodeGen/address-space-conversion.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -S -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM

using pi1_t = int __attribute__((address_space(1))) *;
using pi2_t = int __attribute__((address_space(2))) *;

using ri1_t = int __attribute__((address_space(1))) &;
using ri2_t = int __attribute__((address_space(2))) &;

// CIR: cir.func @{{.*test_ptr.*}}
// LLVM: define void @{{.*test_ptr.*}}
void test_ptr() {
pi1_t ptr1;
pi2_t ptr2 = (pi2_t)ptr1;
// CIR: %[[#PTR1:]] = cir.load %{{[0-9]+}} : !cir.ptr<!cir.ptr<!s32i, addrspace(1)>>, !cir.ptr<!s32i, addrspace(1)>
// CIR-NEXT: %[[#CAST:]] = cir.cast(address_space, %[[#PTR1]] : !cir.ptr<!s32i, addrspace(1)>), !cir.ptr<!s32i, addrspace(2)>
// CIR-NEXT: cir.store %[[#CAST]], %{{[0-9]+}} : !cir.ptr<!s32i, addrspace(2)>, !cir.ptr<!cir.ptr<!s32i, addrspace(2)>>

// LLVM: %[[#PTR1:]] = load ptr addrspace(1), ptr %{{[0-9]+}}, align 8
// LLVM-NEXT: %[[#CAST:]] = addrspacecast ptr addrspace(1) %[[#PTR1]] to ptr addrspace(2)
// LLVM-NEXT: store ptr addrspace(2) %[[#CAST]], ptr %{{[0-9]+}}, align 8
}

// CIR: cir.func @{{.*test_ref.*}}
// LLVM: define void @{{.*test_ref.*}}
void test_ref() {
pi1_t ptr;
ri1_t ref1 = *ptr;
ri2_t ref2 = (ri2_t)ref1;
// CIR: %[[#DEREF:]] = cir.load deref %{{[0-9]+}} : !cir.ptr<!cir.ptr<!s32i, addrspace(1)>>, !cir.ptr<!s32i, addrspace(1)>
// CIR-NEXT: cir.store %[[#DEREF]], %[[#ALLOCAREF1:]] : !cir.ptr<!s32i, addrspace(1)>, !cir.ptr<!cir.ptr<!s32i, addrspace(1)>>
// CIR-NEXT: %[[#REF1:]] = cir.load %[[#ALLOCAREF1]] : !cir.ptr<!cir.ptr<!s32i, addrspace(1)>>, !cir.ptr<!s32i, addrspace(1)>
// CIR-NEXT: %[[#CAST:]] = cir.cast(address_space, %[[#REF1]] : !cir.ptr<!s32i, addrspace(1)>), !cir.ptr<!s32i, addrspace(2)>
// CIR-NEXT: cir.store %[[#CAST]], %{{[0-9]+}} : !cir.ptr<!s32i, addrspace(2)>, !cir.ptr<!cir.ptr<!s32i, addrspace(2)>>

// LLVM: %[[#DEREF:]] = load ptr addrspace(1), ptr %{{[0-9]+}}, align 8
// LLVM-NEXT: store ptr addrspace(1) %[[#DEREF]], ptr %[[#ALLOCAREF1:]], align 8
// LLVM-NEXT: %[[#REF1:]] = load ptr addrspace(1), ptr %[[#ALLOCAREF1]], align 8
// LLVM-NEXT: %[[#CAST:]] = addrspacecast ptr addrspace(1) %[[#REF1]] to ptr addrspace(2)
// LLVM-NEXT: store ptr addrspace(2) %[[#CAST]], ptr %{{[0-9]+}}, align 8
}

// CIR: cir.func @{{.*test_nullptr.*}}
// LLVM: define void @{{.*test_nullptr.*}}
void test_nullptr() {
constexpr pi1_t null1 = nullptr;
pi2_t ptr = (pi2_t)null1;
// CIR: %[[#NULL1:]] = cir.const #cir.ptr<null> : !cir.ptr<!s32i, addrspace(1)>
// CIR-NEXT: cir.store %[[#NULL1]], %{{[0-9]+}} : !cir.ptr<!s32i, addrspace(1)>, !cir.ptr<!cir.ptr<!s32i, addrspace(1)>>
// CIR-NEXT: %[[#NULL2:]] = cir.const #cir.ptr<null> : !cir.ptr<!s32i, addrspace(2)>
// CIR-NEXT: cir.store %[[#NULL2]], %{{[0-9]+}} : !cir.ptr<!s32i, addrspace(2)>, !cir.ptr<!cir.ptr<!s32i, addrspace(2)>>

// LLVM: store ptr addrspace(1) null, ptr %{{[0-9]+}}, align 8
// LLVM-NEXT: store ptr addrspace(2) null, ptr %{{[0-9]+}}, align 8
}
9 changes: 9 additions & 0 deletions clang/test/CIR/IR/cast.cir
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@ module {
%2 = cir.cast(bitcast, %p : !cir.ptr<!s32i>), !cir.ptr<f32>
cir.return
}

cir.func @addrspace_cast(%arg0: !cir.ptr<!s32i>) {
%0 = cir.cast(address_space, %arg0 : !cir.ptr<!s32i>), !cir.ptr<!s32i, addrspace(2)>
cir.return
}
}

// CHECK: cir.func @yolo(%arg0: !s32i)
// CHECK: %1 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool
// CHECK: %2 = cir.cast(array_to_ptrdecay, %0 : !cir.ptr<!cir.array<!s32i x 10>>), !cir.ptr<!s32i>

// CHECK: cir.func @bitcast
// CHECK: %0 = cir.cast(bitcast, %arg0 : !cir.ptr<!s32i>), !cir.ptr<f32>

// CHECK: cir.func @addrspace_cast
// CHECK: %0 = cir.cast(address_space, %arg0 : !cir.ptr<!s32i>), !cir.ptr<!s32i, addrspace(2)>
25 changes: 25 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,31 @@ cir.func @cast24(%p : !u32i) {

// -----

!u32i = !cir.int<u, 32>
!u64i = !cir.int<u, 64>
cir.func @cast25(%p : !cir.ptr<!u32i, addrspace(1)>) {
%0 = cir.cast(address_space, %p : !cir.ptr<!u32i, addrspace(1)>), !cir.ptr<!u64i, addrspace(2)> // expected-error {{requires two types differ in addrspace only}}
cir.return
}

// -----

!u64i = !cir.int<u, 64>
cir.func @cast26(%p : !cir.ptr<!u64i, addrspace(1)>) {
%0 = cir.cast(address_space, %p : !cir.ptr<!u64i, addrspace(1)>), !u64i // expected-error {{requires !cir.ptr type for source and result}}
cir.return
}

// -----

!u64i = !cir.int<u, 64>
cir.func @cast27(%p : !u64i) {
%0 = cir.cast(address_space, %p : !u64i), !cir.ptr<!u64i, addrspace(1)> // expected-error {{requires !cir.ptr type for source and result}}
cir.return
}

// -----

!u32i = !cir.int<u, 32>
!u8i = !cir.int<u, 8>
module {
Expand Down
9 changes: 9 additions & 0 deletions clang/test/CIR/Transforms/merge-cleanups.cir
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,13 @@ module {
cir.return %0 : !cir.ptr<!s32i>
}

// Should remove redundant address space casts.
// CHECK-LABEL: @addrspacecastfold
// CHECK: %[[ARG0:.+]]: !cir.ptr<!s32i, addrspace(2)>
// CHECK: cir.return %[[ARG0]] : !cir.ptr<!s32i, addrspace(2)>
cir.func @addrspacecastfold(%arg0: !cir.ptr<!s32i, addrspace(2)>) -> !cir.ptr<!s32i, addrspace(2)> {
%0 = cir.cast(address_space, %arg0: !cir.ptr<!s32i, addrspace(2)>), !cir.ptr<!s32i, addrspace(2)>
cir.return %0 : !cir.ptr<!s32i, addrspace(2)>
}

}
Loading