Skip to content

Commit

Permalink
[CIR][CodeGen][LowerToLLVM] Emit OpenCL version metadata for SPIR-V t…
Browse files Browse the repository at this point in the history
…arget (llvm#773)

Similar to llvm#767, this PR emit the module level OpenCL version metadata
following the OG CodeGen skeleton.

We use a full qualified `cir.cl.version` attribute on the module op to
store the info in CIR.
  • Loading branch information
seven-mile authored and smeenai committed Oct 9, 2024
1 parent 6363d87 commit 32dc02e
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 0 deletions.
21 changes: 21 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROpenCLAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,25 @@ def OpenCLKernelArgMetadataAttr
let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
// OpenCLVersionAttr
//===----------------------------------------------------------------------===//

def OpenCLVersionAttr : CIR_Attr<"OpenCLVersion", "cl.version"> {
let summary = "OpenCL version";
let parameters = (ins "int32_t":$major, "int32_t":$minor);
let description = [{
Represents the version of OpenCL.

Example:
```
// Module compiled from OpenCL 1.2.
module attributes {cir.cl.version = cir.cl.version<1, 2>} {}
// Module compiled from OpenCL 3.0.
module attributes {cir.cl.version = cir.cl.version<3, 0>} {}
```
}];
let assemblyFormat = "`<` $major `,` $minor `>`";
}

#endif // MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS
24 changes: 24 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2768,6 +2768,16 @@ void CIRGenModule::Release() {
// TODO: buildModuleLinkOptions
}

// Emit OpenCL specific module metadata: OpenCL/SPIR version.
if (langOpts.CUDAIsDevice && getTriple().isSPIRV())
llvm_unreachable("CUDA SPIR-V NYI");
if (langOpts.OpenCL) {
buildOpenCLMetadata();
// Emit SPIR version.
if (getTriple().isSPIR())
llvm_unreachable("SPIR target NYI");
}

// TODO: FINISH THE REST OF THIS
}

Expand Down Expand Up @@ -3235,3 +3245,17 @@ void CIRGenModule::genKernelArgMetadata(mlir::cir::FuncOp Fn,
llvm_unreachable("NYI HIPSaveKernelArgName");
}
}

void CIRGenModule::buildOpenCLMetadata() {
// SPIR v2.0 s2.13 - The OpenCL version used by the module is stored in the
// opencl.ocl.version named metadata node.
// C++ for OpenCL has a distinct mapping for versions compatibile with OpenCL.
unsigned version = langOpts.getOpenCLCompatibleVersion();
unsigned major = version / 100;
unsigned minor = (version % 100) / 10;

auto clVersionAttr =
mlir::cir::OpenCLVersionAttr::get(builder.getContext(), major, minor);

theModule->setAttr("cir.cl.version", clVersionAttr);
}
3 changes: 3 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,9 @@ class CIRGenModule : public CIRGenTypeCache {
const FunctionDecl *FD = nullptr,
CIRGenFunction *CGF = nullptr);

/// Emits OpenCL specific Metadata e.g. OpenCL version.
void buildOpenCLMetadata();

private:
// An ordered map of canonical GlobalDecls to their mangled names.
llvm::MapVector<clang::GlobalDecl, llvm::StringRef> MangledDeclNames;
Expand Down
25 changes: 25 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVMIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class CIRDialectLLVMIRTranslationInterface
mlir::LLVM::ModuleTranslation &moduleTranslation) const override {
if (auto func = dyn_cast<mlir::LLVM::LLVMFuncOp>(op)) {
amendFunction(func, instructions, attribute, moduleTranslation);
} else if (auto mod = dyn_cast<mlir::ModuleOp>(op)) {
amendModule(mod, attribute, moduleTranslation);
}
return mlir::success();
}
Expand All @@ -60,6 +62,29 @@ class CIRDialectLLVMIRTranslationInterface
}

private:
// Translate CIR's module attributes to LLVM's module metadata
void amendModule(mlir::ModuleOp module, mlir::NamedAttribute attribute,
mlir::LLVM::ModuleTranslation &moduleTranslation) const {
llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
llvm::LLVMContext &llvmContext = llvmModule->getContext();

if (auto openclVersionAttr = mlir::dyn_cast<mlir::cir::OpenCLVersionAttr>(
attribute.getValue())) {
auto *int32Ty = llvm::IntegerType::get(llvmContext, 32);
llvm::Metadata *oclVerElts[] = {
llvm::ConstantAsMetadata::get(
llvm::ConstantInt::get(int32Ty, openclVersionAttr.getMajor())),
llvm::ConstantAsMetadata::get(
llvm::ConstantInt::get(int32Ty, openclVersionAttr.getMinor()))};
llvm::NamedMDNode *oclVerMD =
llvmModule->getOrInsertNamedMetadata("opencl.ocl.version");
oclVerMD->addOperand(llvm::MDNode::get(llvmContext, oclVerElts));
}

// Drop ammended CIR attribute from LLVM op.
module->removeAttr(attribute.getName());
}

// Translate CIR's extra function attributes to LLVM's function attributes.
void amendFunction(mlir::LLVM::LLVMFuncOp func,
llvm::ArrayRef<llvm::Instruction *> instructions,
Expand Down
16 changes: 16 additions & 0 deletions clang/test/CIR/CodeGen/OpenCL/opencl-version.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: %clang_cc1 -cl-std=CL3.0 -O0 -fclangir -emit-cir -triple spirv64-unknown-unknown %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR-CL30
// RUN: %clang_cc1 -cl-std=CL3.0 -O0 -fclangir -emit-llvm -triple spirv64-unknown-unknown %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=LLVM-CL30
// RUN: %clang_cc1 -cl-std=CL1.2 -O0 -fclangir -emit-cir -triple spirv64-unknown-unknown %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR-CL12
// RUN: %clang_cc1 -cl-std=CL1.2 -O0 -fclangir -emit-llvm -triple spirv64-unknown-unknown %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=LLVM-CL12

// CIR-CL30: module {{.*}} attributes {{{.*}}cir.cl.version = #cir.cl.version<3, 0>
// LLVM-CL30: !opencl.ocl.version = !{![[MDCL30:[0-9]+]]}
// LLVM-CL30: ![[MDCL30]] = !{i32 3, i32 0}

// CIR-CL12: module {{.*}} attributes {{{.*}}cir.cl.version = #cir.cl.version<1, 2>
// LLVM-CL12: !opencl.ocl.version = !{![[MDCL12:[0-9]+]]}
// LLVM-CL12: ![[MDCL12]] = !{i32 1, i32 2}

0 comments on commit 32dc02e

Please sign in to comment.