[CIR][CUDA] Generate device stubs
AdUhTkJm committed Feb 11, 2025
1 parent 637f2f3 commit 3ad22db
Showing 5 changed files with 170 additions and 13 deletions.
151 changes: 151 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenCUDA.cpp
@@ -0,0 +1,151 @@
//===---- CIRGenCUDA.cpp - CUDA-specific logic for CIR generation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This contains code dealing with CUDA-specific logic of CIR generation.
//
//===----------------------------------------------------------------------===//

#include "CIRGenFunction.h"
#include "clang/Basic/Cuda.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"

using namespace clang;
using namespace clang::CIRGen;

void CIRGenFunction::emitCUDADeviceStubBody(cir::FuncOp fn,
                                            FunctionArgList &args) {
  // CUDA 9.0 changed the way to launch kernels.
  if (!CudaFeatureEnabled(CGM.getTarget().getSDKVersion(),
                          CudaFeature::CUDA_USES_NEW_LAUNCH))
    llvm_unreachable("NYI");

  // This requires arguments to be sent to kernels in a different way.
  if (getLangOpts().OffloadViaLLVM)
    llvm_unreachable("NYI");

  if (getLangOpts().HIP)
    llvm_unreachable("NYI");

  auto &builder = CGM.getBuilder();

  // For cudaLaunchKernel, we must add another layer of indirection
  // to arguments. For example, for function `add(int a, float b)`,
  // we need to pass it as `void *args[2] = { &a, &b }`.

  auto loc = fn.getLoc();
  auto voidPtrArrayTy =
      cir::ArrayType::get(&getMLIRContext(), CGM.VoidPtrTy, args.size());
  mlir::Value kernelArgs = builder.createAlloca(
      loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args",
      CharUnits::fromQuantity(16));

  // Store arguments into kernelArgs
  for (auto [i, arg] : llvm::enumerate(args)) {
    mlir::Value index =
        builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, i));
    mlir::Value storePos = builder.createPtrStride(loc, kernelArgs, index);
    builder.CIRBaseBuilderTy::createStore(
        loc, GetAddrOfLocalVar(arg).getPointer(), storePos);
  }

  // We retrieve dim3 type by looking into the second argument of
  // cudaLaunchKernel, as is done in OG.
  TranslationUnitDecl *tuDecl = getContext().getTranslationUnitDecl();
  DeclContext *dc = TranslationUnitDecl::castToDeclContext(tuDecl);

  // The default stream is usually stream 0 (the legacy default stream).
  // For per-thread default stream, we need a different LaunchKernel function.
  if (getLangOpts().GPUDefaultStream ==
      LangOptions::GPUDefaultStreamKind::PerThread)
    llvm_unreachable("NYI");

  std::string launchAPI = "cudaLaunchKernel";
  const IdentifierInfo &launchII = getContext().Idents.get(launchAPI);
  FunctionDecl *launchFD = nullptr;
  for (auto *result : dc->lookup(&launchII)) {
    if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result))
      launchFD = fd;
  }

  if (launchFD == nullptr) {
    CGM.Error(CurFuncDecl->getLocation(),
              "Can't find declaration for " + launchAPI);
    return;
  }

  // Use this function to retrieve arguments for cudaLaunchKernel:
  // int __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim, size_t
  //                                *sharedMem, cudaStream_t *stream)
  //
  // Here cudaStream_t, while also being the 6th argument of cudaLaunchKernel,
  // is a pointer to some opaque struct.

  mlir::Type dim3Ty =
      getTypes().convertType(launchFD->getParamDecl(1)->getType());
  mlir::Type streamTy =
      getTypes().convertType(launchFD->getParamDecl(5)->getType());

  mlir::Value gridDim =
      builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
                           "grid_dim", CharUnits::fromQuantity(8));
  mlir::Value blockDim =
      builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
                           "block_dim", CharUnits::fromQuantity(8));
  mlir::Value sharedMem =
      builder.createAlloca(loc, cir::PointerType::get(CGM.SizeTy), CGM.SizeTy,
                           "shared_mem", CGM.getSizeAlign());
  mlir::Value stream =
      builder.createAlloca(loc, cir::PointerType::get(streamTy), streamTy,
                           "stream", CGM.getPointerAlign());

  cir::FuncOp popConfig = CGM.createRuntimeFunction(
      cir::FuncType::get({gridDim.getType(), blockDim.getType(),
                          sharedMem.getType(), stream.getType()},
                         CGM.SInt32Ty),
      "__cudaPopCallConfiguration");
  emitRuntimeCall(loc, popConfig, {gridDim, blockDim, sharedMem, stream});

  // Now emit the call to cudaLaunchKernel
  // cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
  //                              void **args, size_t sharedMem,
  //                              cudaStream_t stream);
  auto kernelTy =
      cir::PointerType::get(&getMLIRContext(), fn.getFunctionType());

  mlir::Value kernel =
      builder.create<cir::GetGlobalOp>(loc, kernelTy, fn.getSymName());
  mlir::Value func = builder.createBitcast(kernel, CGM.VoidPtrTy);
  CallArgList launchArgs;

  mlir::Value kernelArgsDecayed =
      builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
                         cir::PointerType::get(CGM.VoidPtrTy));

  launchArgs.add(RValue::get(func), launchFD->getParamDecl(0)->getType());
  launchArgs.add(
      RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))),
      launchFD->getParamDecl(1)->getType());
  launchArgs.add(
      RValue::getAggregate(Address(blockDim, CharUnits::fromQuantity(8))),
      launchFD->getParamDecl(2)->getType());
  launchArgs.add(RValue::get(kernelArgsDecayed),
                 launchFD->getParamDecl(3)->getType());
  launchArgs.add(
      RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
      launchFD->getParamDecl(4)->getType());
  launchArgs.add(RValue::get(stream), launchFD->getParamDecl(5)->getType());

  mlir::Type launchTy = CGM.getTypes().convertType(launchFD->getType());
  mlir::Operation *launchFn =
      CGM.createRuntimeFunction(cast<cir::FuncType>(launchTy), launchAPI);
  const auto &callInfo = CGM.getTypes().arrangeFunctionDeclaration(launchFD);
  emitCall(callInfo, CIRGenCallee::forDirect(launchFn), ReturnValueSlot(),
           launchArgs);
}
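
For intuition, the stub body built above corresponds roughly to the host-side source below. This is a sketch only, written for the hypothetical kernel `add(int a, float b)` mentioned in the comment in the code; the real stub is emitted directly as CIR, and `__device_stub__add` is an illustrative name rather than anything this commit defines.

// Sketch (assumed names): what the generated stub does, expressed in CUDA C++.
void __device_stub__add(int a, float b) {
  // One extra layer of indirection for the kernel arguments.
  void *kernel_args[2] = {&a, &b};
  dim3 grid_dim, block_dim;
  size_t shared_mem;
  cudaStream_t stream;
  // Retrieve the launch configuration recorded by the <<<...>>> lowering.
  __cudaPopCallConfiguration(&grid_dim, &block_dim, &shared_mem, &stream);
  // The stub passes its own address as the kernel handle, as OG does.
  cudaLaunchKernel((void *)__device_stub__add, grid_dim, block_dim,
                   kernel_args, shared_mem, stream);
}
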
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenFunction.cpp
@@ -753,7 +753,7 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl GD, cir::FuncOp Fn,
     emitConstructorBody(Args);
   else if (getLangOpts().CUDA && !getLangOpts().CUDAIsDevice &&
            FD->hasAttr<CUDAGlobalAttr>())
-    llvm_unreachable("NYI");
+    emitCUDADeviceStubBody(Fn, Args);
   else if (isa<CXXMethodDecl>(FD) &&
            cast<CXXMethodDecl>(FD)->isLambdaStaticInvoker()) {
     // The lambda static invoker function is special, because it forwards or
3 changes: 3 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -1156,6 +1156,9 @@ class CIRGenFunction : public CIRGenTypeCache {
   mlir::LogicalResult emitOMPTaskyieldDirective(const OMPTaskyieldDirective &S);
   mlir::LogicalResult emitOMPBarrierDirective(const OMPBarrierDirective &S);
 
+  // CUDA gen functions:
+  void emitCUDADeviceStubBody(cir::FuncOp fn, FunctionArgList &args);
+
   LValue emitOpaqueValueLValue(const OpaqueValueExpr *e);
 
   /// Emit code to compute a designator that specifies the location
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/CMakeLists.txt
@@ -19,6 +19,7 @@ add_clang_library(clangCIR
   CIRGenClass.cpp
   CIRGenCleanup.cpp
   CIRGenCoroutine.cpp
+  CIRGenCUDA.cpp
   CIRGenDecl.cpp
   CIRGenDeclCXX.cpp
   CIRGenException.cpp
26 changes: 14 additions & 12 deletions clang/test/CIR/CodeGen/CUDA/simple.cu
@@ -1,15 +1,17 @@
 #include "../Inputs/cuda.h"
 
 // RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
-// RUN:            -x cuda -emit-cir %s -o %t.cir
+// RUN:            -x cuda -emit-cir -target-sdk-version=12.3 \
+// RUN:            %s -o %t.cir
 // RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
 
 // RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
-// RUN:            -fcuda-is-device -emit-cir %s -o %t.cir
+// RUN:            -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
+// RUN:            %s -o %t.cir
 // RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
 
 // Attribute for global_fn
-// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fnv>{{.*}}
+// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}
 
 __host__ void host_fn(int *a, int *b, int *c) {}
 // CIR-HOST: cir.func @_Z7host_fnPiS_S_
@@ -19,13 +21,13 @@ __device__ void device_fn(int* a, double b, float c) {}
 // CIR-HOST-NOT: cir.func @_Z9device_fnPidf
 // CIR-DEVICE: cir.func @_Z9device_fnPidf
 
-#ifdef __CUDA_ARCH__
-__global__ void global_fn() {}
-#else
-__global__ void global_fn();
-#endif
-// CIR-HOST: @_Z24__device_stub__global_fnv(){{.*}}extra([[Kernel]])
-// CIR-DEVICE: @_Z9global_fnv
+__global__ void global_fn(int a) {}
+// CIR-HOST: @_Z24__device_stub__global_fni{{.*}}extra([[Kernel]])
+// CIR-DEVICE: @_Z9global_fni
 
-// Make sure `global_fn` indeed gets emitted
-__host__ void x() { auto v = global_fn; }
+// Check for device stub emission.
+
+// CIR-HOST: cir.alloca {{.*}}"kernel_args"
+// CIR-HOST: cir.call @__cudaPopCallConfiguration
+// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
+// CIR-HOST: cir.call @cudaLaunchKernel
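
For context (an illustration, not part of this test or commit): under the CUDA >= 9.0 launch scheme required at the top of emitCUDADeviceStubBody, a triple-chevron launch is what reaches the stub at run time. A hypothetical caller would look like this, with `launch_global_fn` being an assumed name:

// Hypothetical host code (not in this commit). The compiler lowers the
// <<<...>>> launch into a call to __cudaPushCallConfiguration followed by a
// plain call to the device stub; the stub then pops that configuration and
// forwards it, together with the packed arguments, to cudaLaunchKernel.
__host__ void launch_global_fn(int x) {
  global_fn<<<dim3(1), dim3(1)>>>(x);
}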
