Skip to content

Commit

Permalink
[CIR][CUDA] Add attribute for CUDA fat binary name
Browse files Browse the repository at this point in the history
  • Loading branch information
AdUhTkJm committed Feb 20, 2025
1 parent 2df2022 commit f9247af
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 4 deletions.
19 changes: 18 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRCUDAAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//===----------------------------------------------------------------------===//

def CUDAKernelNameAttr : CIR_Attr<"CUDAKernelName",
"cuda_kernel_name"> {
"cu.kernel_name"> {
let summary = "Device-side function name for this stub.";
let description =
[{
Expand All @@ -35,4 +35,21 @@ def CUDAKernelNameAttr : CIR_Attr<"CUDAKernelName",
let assemblyFormat = "`<` $kernel_name `>`";
}

def CUDABinaryHandleAttr : CIR_Attr<"CUDABinaryHandle",
"cu.binary_handle"> {
let summary = "Fat binary handle for device code.";
let description =
[{
This attribute is attached to the ModuleOp and records the binary file
name passed to host.

CUDA first compiles device-side code into a fat binary file. The file
name is then passed into host-side code, which is used to create a handle
and then generate various registration functions.
}];

let parameters = (ins "std::string":$name);
let assemblyFormat = "`<` $name `>`";
}

#endif // MLIR_CIR_DIALECT_CIR_CUDA_ATTRS
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def CIR_Dialect : Dialect {
static llvm::StringRef getGlobalAnnotationsAttrName() { return "cir.global_annotations"; }

static llvm::StringRef getOpenCLVersionAttrName() { return "cir.cl.version"; }
static llvm::StringRef getCUDABinaryHandleAttrName() { return "cir.cu.binary_handle"; }

void registerAttributes();
void registerTypes();
Expand Down
2 changes: 0 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>

using namespace clang;
using namespace clang::CIRGen;
Expand Down Expand Up @@ -91,7 +90,6 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
llvm_unreachable("NYI");

std::string launchAPI = addPrefixToName("LaunchKernel");
std::cout << "LaunchAPI is " << launchAPI << "\n";
const IdentifierInfo &launchII = cgm.getASTContext().Idents.get(launchAPI);
FunctionDecl *launchFD = nullptr;
for (auto *result : dc->lookup(&launchII)) {
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,17 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &mlirContext,
/*line=*/0,
/*col=*/0));
}

// Set CUDA GPU binary handle.
if (langOpts.CUDA) {
std::string cudaBinaryName = codeGenOpts.CudaGpuBinaryFileName;
if (!cudaBinaryName.empty()) {
theModule->setAttr(
cir::CIRDialect::getCUDABinaryHandleAttrName(),
cir::CUDABinaryHandleAttr::get(&mlirContext, cudaBinaryName));
}
}

if (langOpts.Sanitize.has(SanitizerKind::Thread) ||
(!codeGenOpts.RelaxedAliasing && codeGenOpts.OptimizationLevel > 0)) {
tbaa.reset(new CIRGenTBAA(&mlirContext, astContext, genTypes, theModule,
Expand Down
9 changes: 9 additions & 0 deletions clang/test/CIR/CodeGen/CUDA/registration.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#include "../Inputs/cuda.h"

// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
// RUN: -x cuda -emit-cir -target-sdk-version=12.3 \
// RUN: -fcuda-include-gpubinary fatbin.o\
// RUN: %s -o %t.cir
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s

// CIR-HOST: module @"{{.*}}" attributes{{.*}}cir.cu.binary_handle = #cir.cu.binary_handle<fatbin.o>{{.*}}
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/CUDA/simple.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s

// Attribute for global_fn
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cu.kernel_name<_Z9global_fni>{{.*}}

__host__ void host_fn(int *a, int *b, int *c) {}
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
Expand Down

0 comments on commit f9247af

Please sign in to comment.