From 324e4c2c049f36f2ad9d9ac6a8a14a32999675d1 Mon Sep 17 00:00:00 2001 From: AdUhTkJm <2292398666@qq.com> Date: Wed, 12 Feb 2025 11:56:39 +0000 Subject: [PATCH] [CIR][CUDA] Generate device stubs --- clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp | 171 ++++++++++++++++++++ clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h | 47 ++++++ clang/lib/CIR/CodeGen/CIRGenFunction.cpp | 2 +- clang/lib/CIR/CodeGen/CIRGenModule.cpp | 4 +- clang/lib/CIR/CodeGen/CIRGenModule.h | 12 +- clang/lib/CIR/CodeGen/CMakeLists.txt | 1 + clang/test/CIR/CodeGen/CUDA/simple.cu | 26 +-- 7 files changed, 248 insertions(+), 15 deletions(-) create mode 100644 clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp create mode 100644 clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h diff --git a/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp new file mode 100644 index 000000000000..400c41cbb0d4 --- /dev/null +++ b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp @@ -0,0 +1,171 @@ +//===--- CIRGenCUDARuntime.cpp - Interface to CUDA Runtimes ----*- C++ -*--===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides an abstract class for CUDA CIR generation. Concrete +// subclasses of this implement code generation for specific CUDA +// runtime libraries. 
+// +//===----------------------------------------------------------------------===// + +#include "CIRGenCUDARuntime.h" +#include "CIRGenFunction.h" +#include "clang/Basic/Cuda.h" +#include "clang/CIR/Dialect/IR/CIRTypes.h" + +using namespace clang; +using namespace clang::CIRGen; + +CIRGenCUDARuntime::~CIRGenCUDARuntime() {} + +void CIRGenCUDARuntime::emitDeviceStubBodyLegacy(CIRGenFunction &cgf, + cir::FuncOp fn, + FunctionArgList &args) { + llvm_unreachable("NYI"); +} + +void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf, + cir::FuncOp fn, + FunctionArgList &args) { + if (cgm.getLangOpts().HIP) + llvm_unreachable("NYI"); + + // This requires arguments to be sent to kernels in a different way. + if (cgm.getLangOpts().OffloadViaLLVM) + llvm_unreachable("NYI"); + + auto &builder = cgm.getBuilder(); + + // For cudaLaunchKernel, we must add another layer of indirection + // to arguments. For example, for function `add(int a, float b)`, + // we need to pass it as `void *args[2] = { &a, &b }`. + + auto loc = fn.getLoc(); + auto voidPtrArrayTy = + cir::ArrayType::get(&cgm.getMLIRContext(), cgm.VoidPtrTy, args.size()); + mlir::Value kernelArgs = builder.createAlloca( + loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args", + CharUnits::fromQuantity(16)); + + // Store arguments into kernelArgs + for (auto [i, arg] : llvm::enumerate(args)) { + mlir::Value index = + builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, i)); + mlir::Value storePos = builder.createPtrStride(loc, kernelArgs, index); + builder.CIRBaseBuilderTy::createStore( + loc, cgf.GetAddrOfLocalVar(arg).getPointer(), storePos); + } + + // We retrieve dim3 type by looking into the second argument of + // cudaLaunchKernel, as is done in OG. + TranslationUnitDecl *tuDecl = cgm.getASTContext().getTranslationUnitDecl(); + DeclContext *dc = TranslationUnitDecl::castToDeclContext(tuDecl); + + // The default stream is usually stream 0 (the legacy default stream). 
+// For per-thread default stream, we need a different LaunchKernel function. + if (cgm.getLangOpts().GPUDefaultStream == + LangOptions::GPUDefaultStreamKind::PerThread) + llvm_unreachable("NYI"); + + std::string launchAPI = "cudaLaunchKernel"; + const IdentifierInfo &launchII = cgm.getASTContext().Idents.get(launchAPI); + FunctionDecl *launchFD = nullptr; + for (auto *result : dc->lookup(&launchII)) { + if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result)) + launchFD = fd; + } + + if (launchFD == nullptr) { + cgm.Error(cgf.CurFuncDecl->getLocation(), + "Can't find declaration for " + launchAPI); + return; + } + + // Use this function to retrieve arguments for cudaLaunchKernel: + // int __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim, size_t + // *sharedMem, cudaStream_t *stream) + // + // Here cudaStream_t, while also being the 6th argument of cudaLaunchKernel, + // is a pointer to some opaque struct. + + mlir::Type dim3Ty = + cgf.getTypes().convertType(launchFD->getParamDecl(1)->getType()); + mlir::Type streamTy = + cgf.getTypes().convertType(launchFD->getParamDecl(5)->getType()); + + mlir::Value gridDim = + builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty, + "grid_dim", CharUnits::fromQuantity(8)); + mlir::Value blockDim = + builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty, + "block_dim", CharUnits::fromQuantity(8)); + mlir::Value sharedMem = + builder.createAlloca(loc, cir::PointerType::get(cgm.SizeTy), cgm.SizeTy, + "shared_mem", cgm.getSizeAlign()); + mlir::Value stream = + builder.createAlloca(loc, cir::PointerType::get(streamTy), streamTy, + "stream", cgm.getPointerAlign()); + + cir::FuncOp popConfig = cgm.createRuntimeFunction( + cir::FuncType::get({gridDim.getType(), blockDim.getType(), + sharedMem.getType(), stream.getType()}, + cgm.SInt32Ty), + "__cudaPopCallConfiguration"); + cgf.emitRuntimeCall(loc, popConfig, {gridDim, blockDim, sharedMem, stream}); + + // Now emit the call to cudaLaunchKernel + // cudaError_t 
cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim, + // void **args, size_t sharedMem, + // cudaStream_t stream); + auto kernelTy = + cir::PointerType::get(&cgm.getMLIRContext(), fn.getFunctionType()); + + mlir::Value kernel = + builder.create<cir::GetGlobalOp>(loc, kernelTy, fn.getSymName()); + mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy); + CallArgList launchArgs; + + mlir::Value kernelArgsDecayed = + builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs, + cir::PointerType::get(cgm.VoidPtrTy)); + + launchArgs.add(RValue::get(func), launchFD->getParamDecl(0)->getType()); + launchArgs.add( + RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))), + launchFD->getParamDecl(1)->getType()); + launchArgs.add( + RValue::getAggregate(Address(blockDim, CharUnits::fromQuantity(8))), + launchFD->getParamDecl(2)->getType()); + launchArgs.add(RValue::get(kernelArgsDecayed), + launchFD->getParamDecl(3)->getType()); + launchArgs.add( + RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)), + launchFD->getParamDecl(4)->getType()); + launchArgs.add(RValue::get(stream), launchFD->getParamDecl(5)->getType()); + + mlir::Type launchTy = cgm.getTypes().convertType(launchFD->getType()); + mlir::Operation *launchFn = + cgm.createRuntimeFunction(cast<cir::FuncType>(launchTy), launchAPI); + const auto &callInfo = cgm.getTypes().arrangeFunctionDeclaration(launchFD); + cgf.emitCall(callInfo, CIRGenCallee::forDirect(launchFn), ReturnValueSlot(), + launchArgs); +} + +void CIRGenCUDARuntime::emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn, + FunctionArgList &args) { + // Device stub and its handle might be different. + if (cgm.getLangOpts().HIP) + llvm_unreachable("NYI"); + + // CUDA 9.0 changed the way to launch kernels. 
+ if (CudaFeatureEnabled(cgm.getTarget().getSDKVersion(), + CudaFeature::CUDA_USES_NEW_LAUNCH) || + cgm.getLangOpts().OffloadViaLLVM) + emitDeviceStubBodyNew(cgf, fn, args); + else + emitDeviceStubBodyLegacy(cgf, fn, args); +} diff --git a/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h new file mode 100644 index 000000000000..a3145a0baeb3 --- /dev/null +++ b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h @@ -0,0 +1,47 @@ +//===------ CIRGenCUDARuntime.h - Interface to CUDA Runtimes -----*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides an abstract class for CUDA CIR generation. Concrete +// subclasses of this implement code generation for specific CUDA +// runtime libraries. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H +#define LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H + +#include "clang/CIR/Dialect/IR/CIRDialect.h" +#include "clang/CIR/Dialect/IR/CIROpsEnums.h" + +namespace clang::CIRGen { + +class CIRGenFunction; +class CIRGenModule; +class FunctionArgList; + +class CIRGenCUDARuntime { +protected: + CIRGenModule &cgm; + +private: + void emitDeviceStubBodyLegacy(CIRGenFunction &cgf, cir::FuncOp fn, + FunctionArgList &args); + void emitDeviceStubBodyNew(CIRGenFunction &cgf, cir::FuncOp fn, + FunctionArgList &args); + +public: + CIRGenCUDARuntime(CIRGenModule &cgm) : cgm(cgm) {} + virtual ~CIRGenCUDARuntime(); + + virtual void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn, + FunctionArgList &args); +}; + +} // namespace clang::CIRGen + +#endif // LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp 
b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp index 7de4866cd004..ee9ebaa61b32 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp @@ -753,7 +753,7 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl GD, cir::FuncOp Fn, emitConstructorBody(Args); else if (getLangOpts().CUDA && !getLangOpts().CUDAIsDevice && FD->hasAttr<CUDAGlobalAttr>()) - llvm_unreachable("NYI"); + CGM.getCUDARuntime().emitDeviceStub(*this, Fn, Args); else if (isa<CXXMethodDecl>(FD) && cast<CXXMethodDecl>(FD)->isLambdaStaticInvoker()) { // The lambda static invoker function is special, because it forwards or diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp index 76221be12319..74eddeaf6af9 100644 --- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp @@ -9,6 +9,7 @@ // This is the internal per-translation-unit state used for CIR translation. // //===----------------------------------------------------------------------===// +#include "CIRGenCUDARuntime.h" #include "CIRGenCXXABI.h" #include "CIRGenCstEmitter.h" #include "CIRGenFunction.h" @@ -108,7 +109,8 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &mlirContext, theModule{mlir::ModuleOp::create(builder.getUnknownLoc())}, Diags(Diags), target(astContext.getTargetInfo()), ABI(createCXXABI(*this)), genTypes{*this}, VTables{*this}, - openMPRuntime(new CIRGenOpenMPRuntime(*this)) { + openMPRuntime(new CIRGenOpenMPRuntime(*this)), + cudaRuntime(new CIRGenCUDARuntime(*this)) { // Initialize CIR signed integer types cache. 
SInt8Ty = cir::IntType::get(&getMLIRContext(), 8, /*isSigned=*/true); diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.h b/clang/lib/CIR/CodeGen/CIRGenModule.h index 867dee754862..eee92e290e53 100644 --- a/clang/lib/CIR/CodeGen/CIRGenModule.h +++ b/clang/lib/CIR/CodeGen/CIRGenModule.h @@ -15,6 +15,7 @@ #include "Address.h" #include "CIRGenBuilder.h" +#include "CIRGenCUDARuntime.h" #include "CIRGenCall.h" #include "CIRGenOpenCLRuntime.h" #include "CIRGenTBAA.h" @@ -113,6 +114,9 @@ class CIRGenModule : public CIRGenTypeCache { /// Holds the OpenMP runtime std::unique_ptr<CIRGenOpenMPRuntime> openMPRuntime; + /// Holds the CUDA runtime + std::unique_ptr<CIRGenCUDARuntime> cudaRuntime; + /// Per-function codegen information. Updated everytime emitCIR is called /// for FunctionDecls's. CIRGenFunction *CurCGF = nullptr; @@ -862,12 +866,18 @@ class CIRGenModule : public CIRGenTypeCache { /// Print out an error that codegen doesn't support the specified decl yet. void ErrorUnsupported(const Decl *D, const char *Type); - /// Return a reference to the configured OpenMP runtime. + /// Return a reference to the configured OpenCL runtime. CIRGenOpenCLRuntime &getOpenCLRuntime() { assert(openCLRuntime != nullptr); return *openCLRuntime; } + /// Return a reference to the configured CUDA runtime. 
+ CIRGenCUDARuntime &getCUDARuntime() { + assert(cudaRuntime != nullptr); + return *cudaRuntime; + } + void createOpenCLRuntime() { openCLRuntime.reset(new CIRGenOpenCLRuntime(*this)); } diff --git a/clang/lib/CIR/CodeGen/CMakeLists.txt b/clang/lib/CIR/CodeGen/CMakeLists.txt index 02ac813ef732..8a065191f4d1 100644 --- a/clang/lib/CIR/CodeGen/CMakeLists.txt +++ b/clang/lib/CIR/CodeGen/CMakeLists.txt @@ -19,6 +19,7 @@ add_clang_library(clangCIR CIRGenClass.cpp CIRGenCleanup.cpp CIRGenCoroutine.cpp + CIRGenCUDARuntime.cpp CIRGenDecl.cpp CIRGenDeclCXX.cpp CIRGenException.cpp diff --git a/clang/test/CIR/CodeGen/CUDA/simple.cu b/clang/test/CIR/CodeGen/CUDA/simple.cu index 1a822d9bcc88..9675de3fe61a 100644 --- a/clang/test/CIR/CodeGen/CUDA/simple.cu +++ b/clang/test/CIR/CodeGen/CUDA/simple.cu @@ -1,15 +1,17 @@ #include "../Inputs/cuda.h" // RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \ -// RUN: -x cuda -emit-cir %s -o %t.cir +// RUN: -x cuda -emit-cir -target-sdk-version=12.3 \ +// RUN: %s -o %t.cir // RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s // RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \ -// RUN: -fcuda-is-device -emit-cir %s -o %t.cir +// RUN: -fcuda-is-device -emit-cir -target-sdk-version=12.3 \ +// RUN: %s -o %t.cir // RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s // Attribute for global_fn -// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fnv>{{.*}} +// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}} __host__ void host_fn(int *a, int *b, int *c) {} // CIR-HOST: cir.func @_Z7host_fnPiS_S_ @@ -19,13 +21,13 @@ __device__ void device_fn(int* a, double b, float c) {} // CIR-HOST-NOT: cir.func @_Z9device_fnPidf // CIR-DEVICE: cir.func @_Z9device_fnPidf -#ifdef __CUDA_ARCH__ -__global__ void global_fn() {} -#else -__global__ void global_fn(); -#endif -// CIR-HOST: @_Z24__device_stub__global_fnv(){{.*}}extra([[Kernel]]) -// CIR-DEVICE: 
@_Z9global_fnv +__global__ void global_fn(int a) {} +// CIR-DEVICE: @_Z9global_fni -// Make sure `global_fn` indeed gets emitted -__host__ void x() { auto v = global_fn; } +// Check for device stub emission. + +// CIR-HOST: @_Z24__device_stub__global_fni{{.*}}extra([[Kernel]]) +// CIR-HOST: cir.alloca {{.*}}"kernel_args" +// CIR-HOST: cir.call @__cudaPopCallConfiguration +// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni +// CIR-HOST: cir.call @cudaLaunchKernel