Skip to content

Commit

Permalink
[CIR][CUDA] Skeleton of NVPTX target lowering info
Browse files Browse the repository at this point in the history
  • Loading branch information
AdUhTkJm committed Feb 16, 2025
1 parent a0091e3 commit 18c8ec1
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_clang_library(TargetLowering
TargetInfo.cpp
TargetLoweringInfo.cpp
Targets/AArch64.cpp
Targets/NVPTX.cpp
Targets/SPIR.cpp
Targets/X86.cpp
Targets/LoweringPrepareAArch64CXXABI.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ createTargetLoweringInfo(LowerModule &LM) {
}
case llvm::Triple::spirv64:
return createSPIRVTargetLoweringInfo(LM);
case llvm::Triple::nvptx64:
return createNVPTXTargetLoweringInfo(LM);
default:
cir_cconv_unreachable("ABI NYI");
}
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ createAArch64TargetLoweringInfo(LowerModule &CGM, cir::AArch64ABIKind AVXLevel);
std::unique_ptr<TargetLoweringInfo>
createSPIRVTargetLoweringInfo(LowerModule &CGM);

/// Creates the TargetLoweringInfo implementation used for NVPTX targets.
std::unique_ptr<TargetLoweringInfo>
createNVPTXTargetLoweringInfo(LowerModule &CGM);

} // namespace cir

#endif // LLVM_CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_TARGETINFO_H
71 changes: 71 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/NVPTX.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//===- NVPTX.cpp - TargetInfo for NVPTX -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "ABIInfoImpl.h"
#include "LowerFunctionInfo.h"
#include "LowerTypes.h"
#include "TargetInfo.h"
#include "TargetLoweringInfo.h"
#include "clang/CIR/ABIArgInfo.h"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/Support/ErrorHandling.h"

using ABIArgInfo = cir::ABIArgInfo;
using MissingFeature = cir::MissingFeatures;

namespace cir {

//===----------------------------------------------------------------------===//
// NVPTX ABI Implementation
//===----------------------------------------------------------------------===//

namespace {

/// ABI information for the NVPTX target.
///
/// This is only a skeleton: it exists so that a TargetLoweringInfo can be
/// constructed for NVPTX, but no calling-convention classification is
/// implemented yet.
class NVPTXABIInfo : public ABIInfo {
public:
  NVPTXABIInfo(LowerTypes &lowerTypes) : ABIInfo(lowerTypes) {}

private:
  /// Not yet implemented for NVPTX; reaching this hook trips
  /// llvm_unreachable.
  void computeInfo(LowerFunctionInfo & /*funcInfo*/) const override {
    llvm_unreachable("NYI");
  }
};

/// Target lowering hooks for the NVPTX target.
///
/// Owns an NVPTXABIInfo and provides the mapping from CIR offload address
/// spaces to the numeric address spaces of the NVPTX backend.
class NVPTXTargetLoweringInfo : public TargetLoweringInfo {
public:
  NVPTXTargetLoweringInfo(LowerTypes &LT)
      : TargetLoweringInfo(std::make_unique<NVPTXABIInfo>(LT)) {}

  /// Translates a CIR offload address space into the NVPTX address-space
  /// number.
  ///
  /// NVPTX numbers its address spaces as: 0 = generic, 1 = global,
  /// 2 = reserved for internal use, 3 = shared, 4 = constant, 5 = local.
  /// These differ from the SPIR-V numbering, so the table must not be shared
  /// between the two targets.
  ///
  /// \param addressSpaceAttr the CIR address space attribute to translate.
  /// \returns the NVPTX address-space number. Any kind not listed below hits
  ///          cir_cconv_unreachable.
  unsigned getTargetAddrSpaceFromCIRAddrSpace(
      cir::AddressSpaceAttr addressSpaceAttr) const override {
    using Kind = cir::AddressSpaceAttr::Kind;
    switch (addressSpaceAttr.getValue()) {
    case Kind::offload_private:
      // Kept at 0 rather than NVPTX's local space (5); this matches the
      // address-space map of clang's classic NVPTX target for private data.
      return 0;
    case Kind::offload_local:
      // Work-group local memory is NVPTX "shared".
      return 3;
    case Kind::offload_global:
      return 1;
    case Kind::offload_constant:
      // NVPTX constant memory is 4; 2 is reserved for backend-internal use.
      return 4;
    case Kind::offload_generic:
      // The NVPTX generic address space is 0.
      return 0;
    default:
      cir_cconv_unreachable("Unknown CIR address space for this target");
    }
  }
};

} // namespace

/// Factory for the NVPTX target lowering info.
///
/// Builds an NVPTXTargetLoweringInfo from the LowerTypes instance owned by
/// \p lowerModule and hands ownership of it to the caller.
std::unique_ptr<TargetLoweringInfo>
createNVPTXTargetLoweringInfo(LowerModule &lowerModule) {
  auto &lowerTypes = lowerModule.getTypes();
  return std::make_unique<NVPTXTargetLoweringInfo>(lowerTypes);
}

} // namespace cir
33 changes: 33 additions & 0 deletions clang/test/CIR/CodeGen/CUDA/simple.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,34 @@
// RUN: %s -o %t.cir
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s

// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
// RUN: %s -o %t.cir
// RUN: FileCheck --check-prefix=LLVM-HOST --input-file=%t.cir %s

// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
// RUN: %s -o %t.cir
// RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.cir %s

// Attribute for global_fn
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}

__host__ void host_fn(int *a, int *b, int *c) {}
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
// CIR-DEVICE-NOT: cir.func @_Z7host_fnPiS_S_
// LLVM-HOST: void @_Z7host_fnPiS_S_
// LLVM-DEVICE-NOT: void @_Z7host_fnPiS_S_

__device__ void device_fn(int* a, double b, float c) {}
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
// CIR-DEVICE: cir.func @_Z9device_fnPidf
// LLVM-HOST-NOT: void @_Z9device_fnPidf
// LLVM-DEVICE: void @_Z9device_fnPidf

__global__ void global_fn(int a) {}
// CIR-DEVICE: @_Z9global_fni
// LLVM-DEVICE: @_Z9global_fni

// Check for device stub emission.

Expand All @@ -32,10 +47,16 @@ __global__ void global_fn(int a) {}
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
// CIR-HOST: cir.call @cudaLaunchKernel

// LLVM-HOST: void @_Z24__device_stub__global_fni
// LLVM-HOST: alloca [1 x ptr], i64 1, align 16
// LLVM-HOST: call i32 @__cudaPopCallConfiguration
// LLVM-HOST: call i32 @cudaLaunchKernel(ptr @_Z24__device_stub__global_fni, %struct.dim3 %{{[0-9]+}}, %struct.dim3 %{{[0-9]+}}, ptr %{{[0-9]+}}, i64 %{{[0-9]+}}, ptr %{{[0-9]+}})

int main() {
global_fn<<<1, 1>>>(1);
}
// CIR-DEVICE-NOT: cir.func @main()
// LLVM-DEVICE-NOT: i32 @main()

// CIR-HOST: cir.func @main()
// CIR-HOST: cir.call @_ZN4dim3C1Ejjj
Expand All @@ -46,3 +67,15 @@ int main() {
// CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1>
// CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]])
// CIR-HOST: }

// LLVM-HOST: i32 @main()
// LLVM-HOST: call void @_ZN4dim3C1Ejjj
// LLVM-HOST: call void @_ZN4dim3C1Ejjj
// LLVM-HOST: [[PushLLVM:%[0-9]+]] = call i32 @__cudaPushCallConfiguration
// LLVM-HOST: [[ConfigOKLLVM:%[0-9]+]] = icmp ne i32 [[PushLLVM]], 0
// LLVM-HOST: br i1 [[ConfigOKLLVM]], label %[[Ifso:[0-9]+]], label %[[Ifnot:[0-9]+]]
// LLVM-HOST: [[Ifso]]:
// LLVM-HOST: call void @_Z24__device_stub__global_fni(i32 1)
// LLVM-HOST: br label %[[Ifnot]]
// LLVM-HOST: [[Ifnot]]:
// LLVM-HOST: ret i32

0 comments on commit 18c8ec1

Please sign in to comment.