Skip to content

Commit

Permalink
[CIR][CUDA] Generate attribute for kernel name of device stubs (#1317)
Browse files Browse the repository at this point in the history
Now a `__global__` function on host will be generated to a device stub,
with an attribute recording the corresponding kernel name (mangled name
on device of the same function). The dynamic registration phase will be
implemented in LLVM lowering.

For example, CIR generated for `__global__ void global_fn();` looks like
```
#fn_attr1 = #cir<extra({cuda_kernel_name = #cir.cuda_kernel_name<_Z9global_fnv>})>
cir.func private @_Z24__device_stub__global_fnv() extra(#fn_attr1)
```
  • Loading branch information
AdUhTkJm authored Feb 10, 2025
1 parent 53335ae commit 2294d5f
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 31 deletions.
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1327,5 +1327,6 @@ def CIR_TBAAAttr : CIR_Attr<"TBAA", "tbaa", []> {
}

include "clang/CIR/Dialect/IR/CIROpenCLAttrs.td"
include "clang/CIR/Dialect/IR/CIRCUDAAttrs.td"

#endif // MLIR_CIR_DIALECT_CIR_ATTRS
38 changes: 38 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRCUDAAttrs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===---- CIRCUDAAttrs.td - CIR dialect attrs for CUDA -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the CIR dialect attributes for OpenCL.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CIR_DIALECT_CIR_CUDA_ATTRS
#define MLIR_CIR_DIALECT_CIR_CUDA_ATTRS

//===----------------------------------------------------------------------===//
// CUDAKernelNameAttr
//===----------------------------------------------------------------------===//

def CUDAKernelNameAttr : CIR_Attr<"CUDAKernelName",
"cuda_kernel_name"> {
let summary = "Device-side function name for this stub.";
let description =
[{
This attribute is attached to function definitions and records the
mangled name of the kernel function used on the device.

In CUDA, global functions (kernels) are processed differently for host
and device. On host, Clang generates device stubs; on device, they are
treated as normal functions. As they probably have different mangled
names, we must record the corresponding device-side name for a stub.
}];

let parameters = (ins "std::string":$kernel_name);
let assemblyFormat = "`<` $kernel_name `>`";
}

#endif // MLIR_CIR_DIALECT_CIR_CUDA_ATTRS
10 changes: 10 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,16 @@ void CIRGenModule::constructAttributeList(
getLangOpts().OffloadUniformBlock)
assert(!cir::MissingFeatures::CUDA());

if (langOpts.CUDA && !langOpts.CUDAIsDevice &&
TargetDecl->hasAttr<CUDAGlobalAttr>()) {
GlobalDecl kernel(CalleeInfo.getCalleeDecl());
llvm::StringRef kernelName = getMangledName(
kernel.getWithKernelReferenceKind(KernelReferenceKind::Kernel));
auto attr =
cir::CUDAKernelNameAttr::get(&getMLIRContext(), kernelName.str());
funcAttrs.set(attr.getMnemonic(), attr);
}

if (TargetDecl->hasAttr<ArmLocallyStreamingAttr>())
;
}
Expand Down
18 changes: 8 additions & 10 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
// This is the internal per-translation-unit state used for CIR translation.
//
//===----------------------------------------------------------------------===//
#include "CIRGenModule.h"
#include "CIRGenCXXABI.h"
#include "CIRGenCstEmitter.h"
#include "CIRGenFunction.h"
Expand Down Expand Up @@ -528,10 +527,9 @@ void CIRGenModule::emitGlobal(GlobalDecl GD) {
if (langOpts.HIPStdPar)
llvm_unreachable("NYI");

if (Global->hasAttr<CUDAGlobalAttr>())
llvm_unreachable("NYI");

if (!Global->hasAttr<CUDADeviceAttr>())
// Global functions reside on device, so it shouldn't be skipped.
if (!Global->hasAttr<CUDAGlobalAttr>() &&
!Global->hasAttr<CUDADeviceAttr>())
return;
} else {
// We must skip __device__ functions when compiling for host.
Expand Down Expand Up @@ -2352,10 +2350,10 @@ cir::FuncOp CIRGenModule::GetAddrOfFunction(clang::GlobalDecl GD, mlir::Type Ty,
auto F = GetOrCreateCIRFunction(MangledName, Ty, GD, ForVTable, DontDefer,
/*IsThunk=*/false, IsForDefinition);

// As __global__ functions always reside on device,
// we need special care when accessing them from host;
// otherwise, CUDA functions behave as normal functions
if (langOpts.CUDA && !langOpts.CUDAIsDevice &&
// As __global__ functions (kernels) always reside on device,
// when we access them from host, we must refer to the kernel handle.
// For CUDA, it's just the device stub. For HIP, it's something different.
if (langOpts.CUDA && !langOpts.CUDAIsDevice && langOpts.HIP &&
cast<FunctionDecl>(GD.getDecl())->hasAttr<CUDAGlobalAttr>()) {
llvm_unreachable("NYI");
}
Expand Down Expand Up @@ -2398,7 +2396,7 @@ static std::string getMangledNameImpl(CIRGenModule &CGM, GlobalDecl GD,
assert(0 && "NYI");
} else if (FD && FD->hasAttr<CUDAGlobalAttr>() &&
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
assert(0 && "NYI");
Out << "__device_stub__";
} else {
Out << II->getName();
}
Expand Down
14 changes: 0 additions & 14 deletions clang/test/CIR/CodeGen/CUDA/simple-device.cu

This file was deleted.

29 changes: 22 additions & 7 deletions clang/test/CIR/CodeGen/CUDA/simple.cu
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
#include "../Inputs/cuda.h"

// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
// RUN: -x cuda -emit-cir %s -o %t.cir
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s

// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
// RUN: -emit-cir %s -o %t.cir
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
// RUN: -fcuda-is-device -emit-cir %s -o %t.cir
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s

// Attribute for global_fn
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fnv>{{.*}}

// This should emit as a normal C++ function.
__host__ void host_fn(int *a, int *b, int *c) {}
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
// CIR-DEVICE-NOT: cir.func @_Z7host_fnPiS_S_

// CIR: cir.func @_Z7host_fnPiS_S_

// This shouldn't emit.
__device__ void device_fn(int* a, double b, float c) {}
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
// CIR-DEVICE: cir.func @_Z9device_fnPidf

#ifdef __CUDA_ARCH__
__global__ void global_fn() {}
#else
__global__ void global_fn();
#endif
// CIR-HOST: @_Z24__device_stub__global_fnv(){{.*}}extra([[Kernel]])
// CIR-DEVICE: @_Z9global_fnv

// CHECK-NOT: cir.func @_Z9device_fnPidf
// Make sure `global_fn` indeed gets emitted
__host__ void x() { auto v = global_fn; }

0 comments on commit 2294d5f

Please sign in to comment.