Skip to content

Commit

Permalink
[CIR][CIRGen][LowerToLLVM] Support address space casting (llvm#652)
Browse files Browse the repository at this point in the history
* New `CastKind::addrspace_cast` for `cir.cast`
* `TargetCIRGenInfo::performAddrSpaceCast` helper for non-constant
values only
* CIRGen for address space casting of pointers and references
* Lowering to LLVM
  • Loading branch information
seven-mile authored and lanza committed Oct 1, 2024
1 parent 5c0c6f4 commit b13b0bc
Show file tree
Hide file tree
Showing 13 changed files with 188 additions and 5 deletions.
11 changes: 11 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Dialect/IR/FPEnv.h"
#include "clang/CIR/MissingFeatures.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -328,6 +329,15 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return createBitcast(src, getPointerTo(newPointeeTy));
}

mlir::Value createAddrSpaceCast(mlir::Location loc, mlir::Value src,
mlir::Type newTy) {
return createCast(loc, mlir::cir::CastKind::address_space, src, newTy);
}

mlir::Value createAddrSpaceCast(mlir::Value src, mlir::Type newTy) {
return createAddrSpaceCast(src.getLoc(), src, newTy);
}

mlir::Value createPtrIsNull(mlir::Value ptr) {
return createNot(createPtrToBoolCast(ptr));
}
Expand Down Expand Up @@ -391,6 +401,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {

// Creates constant nullptr for pointer type ty.
mlir::cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc) {
assert(!MissingFeatures::targetCodeGenInfoGetNullPointer());
return create<mlir::cir::ConstantOp>(loc, ty, getConstPtrAttr(ty, 0));
}

Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,16 @@ def CK_FloatToBoolean : I32EnumAttrCase<"float_to_bool", 10>;
def CK_BooleanToIntegral : I32EnumAttrCase<"bool_to_int", 11>;
def CK_IntegralToFloat : I32EnumAttrCase<"int_to_float", 12>;
def CK_BooleanToFloat : I32EnumAttrCase<"bool_to_float", 13>;
def CK_AddressSpaceConversion : I32EnumAttrCase<"address_space", 14>;

def CastKind : I32EnumAttr<
"CastKind",
"cast kind",
[CK_IntegralToBoolean, CK_ArrayToPointerDecay, CK_IntegralCast,
CK_BitCast, CK_FloatingCast, CK_PtrToBoolean, CK_FloatToIntegral,
CK_IntegralToPointer, CK_PointerToIntegral, CK_FloatToBoolean,
CK_BooleanToIntegral, CK_IntegralToFloat, CK_BooleanToFloat]> {
CK_BooleanToIntegral, CK_IntegralToFloat, CK_BooleanToFloat,
CK_AddressSpaceConversion]> {
let cppNamespace = "::mlir::cir";
}

Expand All @@ -98,6 +100,7 @@ def CastOp : CIR_Op<"cast", [Pure]> {
- `ptr_to_bool`
- `bool_to_int`
- `bool_to_float`
- `address_space`

This is effectively a subset of the rules from
`llvm-project/clang/include/clang/AST/OperationKinds.def`; but note that some
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ struct MissingFeatures {
static bool checkFunctionCallABI() { return false; }
static bool zeroInitializer() { return false; }
static bool targetCodeGenInfoIsProtoCallVariadic() { return false; }
static bool targetCodeGenInfoGetNullPointer() { return false; }
static bool chainCalls() { return false; }
static bool operandBundles() { return false; }
static bool exceptions() { return false; }
Expand Down
11 changes: 10 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "CIRGenModule.h"
#include "CIRGenOpenMPRuntime.h"
#include "CIRGenValue.h"
#include "TargetInfo.h"

#include "clang/AST/ExprCXX.h"
#include "clang/AST/GlobalDecl.h"
Expand Down Expand Up @@ -1953,7 +1954,15 @@ LValue CIRGenFunction::buildCastLValue(const CastExpr *E) {
assert(0 && "NYI");
}
case CK_AddressSpaceConversion: {
assert(0 && "NYI");
LValue LV = buildLValue(E->getSubExpr());
QualType DestTy = getContext().getPointerType(E->getType());
mlir::Value V = getTargetHooks().performAddrSpaceCast(
*this, LV.getPointer(), E->getSubExpr()->getType().getAddressSpace(),
E->getType().getAddressSpace(), ConvertType(DestTy));
assert(!MissingFeatures::tbaa());
return makeAddrLValue(Address(V, getTypes().convertTypeForMem(E->getType()),
LV.getAddress().getAlignment()),
E->getType(), LV.getBaseInfo());
}
case CK_ObjCObjectLValueCast: {
assert(0 && "NYI");
Expand Down
21 changes: 19 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "CIRGenFunction.h"
#include "CIRGenModule.h"
#include "CIRGenOpenMPRuntime.h"
#include "TargetInfo.h"
#include "clang/CIR/MissingFeatures.h"

#include "clang/AST/StmtVisitor.h"
Expand Down Expand Up @@ -1514,8 +1515,24 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
return CGF.getBuilder().createBitcast(CGF.getLoc(E->getSourceRange()), Src,
DstTy);
}
case CK_AddressSpaceConversion:
llvm_unreachable("NYI");
case CK_AddressSpaceConversion: {
Expr::EvalResult Result;
if (E->EvaluateAsRValue(Result, CGF.getContext()) &&
Result.Val.isNullPointer()) {
// If E has side effect, it is emitted even if its final result is a
// null pointer. In that case, a DCE pass should be able to
// eliminate the useless instructions emitted during translating E.
if (Result.HasSideEffects) {
llvm_unreachable("NYI");
}
return CGF.CGM.buildNullConstant(DestTy, CGF.getLoc(E->getExprLoc()));
}
// Since target may map different address spaces in AST to the same address
// space, an address space conversion may end up as a bitcast.
return CGF.CGM.getTargetCIRGenInfo().performAddrSpaceCast(
CGF, Visit(E), E->getType()->getPointeeType().getAddressSpace(),
DestTy->getPointeeType().getAddressSpace(), ConvertType(DestTy));
}
case CK_AtomicToNonAtomic:
llvm_unreachable("NYI");
case CK_NonAtomicToAtomic:
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/CIR/CodeGen/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,17 @@ ABIArgInfo X86_64ABIInfo::classifyReturnType(QualType RetTy) const {
return ABIArgInfo::getDirect(ResType);
}

mlir::Value TargetCIRGenInfo::performAddrSpaceCast(
CIRGenFunction &CGF, mlir::Value Src, clang::LangAS SrcAddr,
clang::LangAS DestAddr, mlir::Type DestTy, bool IsNonNull) const {
// Since target may map different address spaces in AST to the same address
// space, an address space conversion may end up as a bitcast.
if (auto globalOp = Src.getDefiningOp<mlir::cir::GlobalOp>())
llvm_unreachable("Global ops addrspace cast NYI");
// Try to preserve the source's name to make IR more readable.
return CGF.getBuilder().createAddrSpaceCast(Src, DestTy);
}

const TargetCIRGenInfo &CIRGenModule::getTargetCIRGenInfo() {
if (TheTargetCIRGenInfo)
return *TheTargetCIRGenInfo;
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CIR/CodeGen/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
namespace cir {

class CIRGenFunction;
class CIRGenModule;

/// This class organizes various target-specific codegeneration issues, like
/// target-specific attributes, builtins and so on.
Expand Down Expand Up @@ -65,6 +66,18 @@ class TargetCIRGenInfo {
return clang::LangAS::Default;
}

/// Perform address space cast of an expression of pointer type.
/// \param V is the value to be casted to another address space.
/// \param SrcAddr is the language address space of \p V.
/// \param DestAddr is the targeted language address space.
/// \param DestTy is the destination pointer type.
/// \param IsNonNull is the flag indicating \p V is known to be non null.
virtual mlir::Value performAddrSpaceCast(CIRGenFunction &CGF, mlir::Value V,
clang::LangAS SrcAddr,
clang::LangAS DestAddr,
mlir::Type DestTy,
bool IsNonNull = false) const;

virtual ~TargetCIRGenInfo() {}
};

Expand Down
12 changes: 11 additions & 1 deletion clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,15 @@ LogicalResult CastOp::verify() {
return emitOpError() << "requires !cir.float type for result";
return success();
}
case cir::CastKind::address_space: {
auto srcPtrTy = srcType.dyn_cast<mlir::cir::PointerType>();
auto resPtrTy = resType.dyn_cast<mlir::cir::PointerType>();
if (!srcPtrTy || !resPtrTy)
return emitOpError() << "requires !cir.ptr type for source and result";
if (srcPtrTy.getPointee() != resPtrTy.getPointee())
return emitOpError() << "requires two types differ in addrspace only";
return success();
}
}

llvm_unreachable("Unknown CastOp kind?");
Expand All @@ -514,7 +523,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
return foldResults[0].get<mlir::Attribute>();
return {};
}
case mlir::cir::CastKind::bitcast: {
case mlir::cir::CastKind::bitcast:
case mlir::cir::CastKind::address_space: {
return getSrc();
}
default:
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,14 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
mlir::cir::CmpOpKind::ne, castOp.getSrc(), null);
break;
}
case mlir::cir::CastKind::address_space: {
auto dstTy = castOp.getType();
auto llvmSrcVal = adaptor.getOperands().front();
auto llvmDstTy = getTypeConverter()->convertType(dstTy);
rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
castOp, llvmDstTy, llvmSrcVal);
break;
}
}

return mlir::success();
Expand Down
57 changes: 57 additions & 0 deletions clang/test/CIR/CodeGen/address-space-conversion.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM

using pi1_t = int __attribute__((address_space(1))) *;
using pi2_t = int __attribute__((address_space(2))) *;

using ri1_t = int __attribute__((address_space(1))) &;
using ri2_t = int __attribute__((address_space(2))) &;

// CIR: cir.func @{{.*test_ptr.*}}
// LLVM: define void @{{.*test_ptr.*}}
void test_ptr() {
pi1_t ptr1;
pi2_t ptr2 = (pi2_t)ptr1;
// CIR: %[[#PTR1:]] = cir.load %{{[0-9]+}} : !cir.ptr<!cir.ptr<!s32i, addrspace(1)>>, !cir.ptr<!s32i, addrspace(1)>
// CIR-NEXT: %[[#CAST:]] = cir.cast(address_space, %[[#PTR1]] : !cir.ptr<!s32i, addrspace(1)>), !cir.ptr<!s32i, addrspace(2)>
// CIR-NEXT: cir.store %[[#CAST]], %{{[0-9]+}} : !cir.ptr<!s32i, addrspace(2)>, !cir.ptr<!cir.ptr<!s32i, addrspace(2)>>

// LLVM: %[[#PTR1:]] = load ptr addrspace(1), ptr %{{[0-9]+}}, align 8
// LLVM-NEXT: %[[#CAST:]] = addrspacecast ptr addrspace(1) %[[#PTR1]] to ptr addrspace(2)
// LLVM-NEXT: store ptr addrspace(2) %[[#CAST]], ptr %{{[0-9]+}}, align 8
}

// CIR: cir.func @{{.*test_ref.*}}
// LLVM: define void @{{.*test_ref.*}}
void test_ref() {
pi1_t ptr;
ri1_t ref1 = *ptr;
ri2_t ref2 = (ri2_t)ref1;
// CIR: %[[#DEREF:]] = cir.load deref %{{[0-9]+}} : !cir.ptr<!cir.ptr<!s32i, addrspace(1)>>, !cir.ptr<!s32i, addrspace(1)>
// CIR-NEXT: cir.store %[[#DEREF]], %[[#ALLOCAREF1:]] : !cir.ptr<!s32i, addrspace(1)>, !cir.ptr<!cir.ptr<!s32i, addrspace(1)>>
// CIR-NEXT: %[[#REF1:]] = cir.load %[[#ALLOCAREF1]] : !cir.ptr<!cir.ptr<!s32i, addrspace(1)>>, !cir.ptr<!s32i, addrspace(1)>
// CIR-NEXT: %[[#CAST:]] = cir.cast(address_space, %[[#REF1]] : !cir.ptr<!s32i, addrspace(1)>), !cir.ptr<!s32i, addrspace(2)>
// CIR-NEXT: cir.store %[[#CAST]], %{{[0-9]+}} : !cir.ptr<!s32i, addrspace(2)>, !cir.ptr<!cir.ptr<!s32i, addrspace(2)>>

// LLVM: %[[#DEREF:]] = load ptr addrspace(1), ptr %{{[0-9]+}}, align 8
// LLVM-NEXT: store ptr addrspace(1) %[[#DEREF]], ptr %[[#ALLOCAREF1:]], align 8
// LLVM-NEXT: %[[#REF1:]] = load ptr addrspace(1), ptr %[[#ALLOCAREF1]], align 8
// LLVM-NEXT: %[[#CAST:]] = addrspacecast ptr addrspace(1) %[[#REF1]] to ptr addrspace(2)
// LLVM-NEXT: store ptr addrspace(2) %[[#CAST]], ptr %{{[0-9]+}}, align 8
}

// CIR: cir.func @{{.*test_nullptr.*}}
// LLVM: define void @{{.*test_nullptr.*}}
void test_nullptr() {
constexpr pi1_t null1 = nullptr;
pi2_t ptr = (pi2_t)null1;
// CIR: %[[#NULL1:]] = cir.const #cir.ptr<null> : !cir.ptr<!s32i, addrspace(1)>
// CIR-NEXT: cir.store %[[#NULL1]], %{{[0-9]+}} : !cir.ptr<!s32i, addrspace(1)>, !cir.ptr<!cir.ptr<!s32i, addrspace(1)>>
// CIR-NEXT: %[[#NULL2:]] = cir.const #cir.ptr<null> : !cir.ptr<!s32i, addrspace(2)>
// CIR-NEXT: cir.store %[[#NULL2]], %{{[0-9]+}} : !cir.ptr<!s32i, addrspace(2)>, !cir.ptr<!cir.ptr<!s32i, addrspace(2)>>

// LLVM: store ptr addrspace(1) null, ptr %{{[0-9]+}}, align 8
// LLVM-NEXT: store ptr addrspace(2) null, ptr %{{[0-9]+}}, align 8
}
9 changes: 9 additions & 0 deletions clang/test/CIR/IR/cast.cir
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@ module {
%2 = cir.cast(bitcast, %p : !cir.ptr<!s32i>), !cir.ptr<f32>
cir.return
}

cir.func @addrspace_cast(%arg0: !cir.ptr<!s32i>) {
%0 = cir.cast(address_space, %arg0 : !cir.ptr<!s32i>), !cir.ptr<!s32i, addrspace(2)>
cir.return
}
}

// CHECK: cir.func @yolo(%arg0: !s32i)
// CHECK: %1 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool
// CHECK: %2 = cir.cast(array_to_ptrdecay, %0 : !cir.ptr<!cir.array<!s32i x 10>>), !cir.ptr<!s32i>

// CHECK: cir.func @bitcast
// CHECK: %0 = cir.cast(bitcast, %arg0 : !cir.ptr<!s32i>), !cir.ptr<f32>

// CHECK: cir.func @addrspace_cast
// CHECK: %0 = cir.cast(address_space, %arg0 : !cir.ptr<!s32i>), !cir.ptr<!s32i, addrspace(2)>
25 changes: 25 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,31 @@ cir.func @cast24(%p : !u32i) {

// -----

!u32i = !cir.int<u, 32>
!u64i = !cir.int<u, 64>
cir.func @cast25(%p : !cir.ptr<!u32i, addrspace(1)>) {
%0 = cir.cast(address_space, %p : !cir.ptr<!u32i, addrspace(1)>), !cir.ptr<!u64i, addrspace(2)> // expected-error {{requires two types differ in addrspace only}}
cir.return
}

// -----

!u64i = !cir.int<u, 64>
cir.func @cast26(%p : !cir.ptr<!u64i, addrspace(1)>) {
%0 = cir.cast(address_space, %p : !cir.ptr<!u64i, addrspace(1)>), !u64i // expected-error {{requires !cir.ptr type for source and result}}
cir.return
}

// -----

!u64i = !cir.int<u, 64>
cir.func @cast27(%p : !u64i) {
%0 = cir.cast(address_space, %p : !u64i), !cir.ptr<!u64i, addrspace(1)> // expected-error {{requires !cir.ptr type for source and result}}
cir.return
}

// -----

!u32i = !cir.int<u, 32>
!u8i = !cir.int<u, 8>
module {
Expand Down
9 changes: 9 additions & 0 deletions clang/test/CIR/Transforms/merge-cleanups.cir
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,13 @@ module {
cir.return %0 : !cir.ptr<!s32i>
}

// Should remove redundant address space casts.
// CHECK-LABEL: @addrspacecastfold
// CHECK: %[[ARG0:.+]]: !cir.ptr<!s32i, addrspace(2)>
// CHECK: cir.return %[[ARG0]] : !cir.ptr<!s32i, addrspace(2)>
cir.func @addrspacecastfold(%arg0: !cir.ptr<!s32i, addrspace(2)>) -> !cir.ptr<!s32i, addrspace(2)> {
%0 = cir.cast(address_space, %arg0: !cir.ptr<!s32i, addrspace(2)>), !cir.ptr<!s32i, addrspace(2)>
cir.return %0 : !cir.ptr<!s32i, addrspace(2)>
}

}

0 comments on commit b13b0bc

Please sign in to comment.