Skip to content

Commit

Permalink
[CIR][Dialect] Emit OpenCL kernel metadata (llvm#705)
Browse files Browse the repository at this point in the history
This PR introduces a new attribute `OpenCLKernelMetadataAttr` to model
the OpenCL kernel metadata structurally in CIR, with its corresponding
implementations of CodeGen, Lowering and Translation.

The `"TypeAttr":$vec_type_hint` part is tricky because integer types in
LLVM IR are signless, while SPIR-V requires the signedness of the hinted
type. According to the spec, the final LLVM IR should encode signedness
with an extra `i32` boolean value.

In this PR, the signedness-dropping logic from CIR's `TypeConverter` is still
used to avoid code duplication when lowering to LLVM dialect. However, the
signedness is then restored (still encapsulated by a CIR attribute) and
dropped again in the translation into LLVM IR.
  • Loading branch information
seven-mile authored and lanza committed Oct 2, 2024
1 parent e2599f2 commit 6e11fd3
Show file tree
Hide file tree
Showing 11 changed files with 406 additions and 5 deletions.
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -979,4 +979,6 @@ def BitfieldInfoAttr : CIR_Attr<"BitfieldInfo", "bitfield_info"> {
];
}

include "clang/CIR/Dialect/IR/CIROpenCLAttrs.td"

#endif // MLIR_CIR_DIALECT_CIR_ATTRS
95 changes: 95 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROpenCLAttrs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//===- CIROpenCLAttrs.td - CIR dialect attrs for OpenCL ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the CIR dialect attributes for OpenCL.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS
#define MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS

//===----------------------------------------------------------------------===//
// OpenCLKernelMetadataAttr
//===----------------------------------------------------------------------===//

def OpenCLKernelMetadataAttr
    : CIR_Attr<"OpenCLKernelMetadata", "cl.kernel_metadata"> {

  let summary = "OpenCL kernel metadata";
  let description = [{
    Provide the required information of an OpenCL kernel for the SPIR-V backend.

    The `work_group_size_hint` and `reqd_work_group_size` parameter are integer
    arrays with 3 elements that provide hints for the work-group size and the
    required work-group size, respectively.

    The `vec_type_hint` parameter is a type attribute that provides a hint for
    the vectorization. It can be a CIR or LLVM type, depending on the lowering
    stage.

    The `vec_type_hint_signedness` parameter is a boolean that indicates the
    signedness of the vector type hint. It's useful when LLVM type is set in
    `vec_type_hint`, which is signless by design. It should be set if and only
    if the `vec_type_hint` is present.

    The `intel_reqd_sub_group_size` parameter is an integer that restricts the
    sub-group size to the specified value.

    Example:
    ```
    #fn_attr = #cir<extra({cl.kernel_metadata = #cir.cl.kernel_metadata<
    work_group_size_hint = [8 : i32, 16 : i32, 32 : i32],
    reqd_work_group_size = [1 : i32, 2 : i32, 4 : i32],
    vec_type_hint = !s32i,
    vec_type_hint_signedness = 1,
    intel_reqd_sub_group_size = 8 : i32
    >})>

    cir.func @kernel(%arg0: !s32i) extra(#fn_attr) {
    cir.return
    }
    ```
  }];

  // Every parameter is optional, so each one may be left out of the
  // `struct(params)` assembly form below.
  let parameters = (ins
    OptionalParameter<"ArrayAttr">:$work_group_size_hint,
    OptionalParameter<"ArrayAttr">:$reqd_work_group_size,
    OptionalParameter<"TypeAttr">:$vec_type_hint,
    OptionalParameter<"std::optional<bool>">:$vec_type_hint_signedness,
    OptionalParameter<"IntegerAttr">:$intel_reqd_sub_group_size
  );

  let assemblyFormat = "`<` struct(params) `>`";

  // Request the declaration of a hand-written `verify` method.
  let genVerifyDecl = 1;

  let extraClassDeclaration = [{
    /// Report whether the hinted type is a signed integer, or a vector of
    /// signed integers. Yields no value for types outside the CIR dialect.
    static std::optional<bool> isSignedHint(mlir::Type vecTypeHint);
  }];

  let extraClassDefinition = [{
    std::optional<bool> $cppClass::isSignedHint(mlir::Type hintQTy) {
      // Types from other dialects do not carry signedness at all.
      if (!mlir::isa<mlir::cir::CIRDialect>(hintQTy.getDialect()))
        return std::nullopt;

      // See also clang::CodeGen::CodeGenFunction::EmitKernelMetadata: the
      // hint is judged by the type itself or, for a vector, by its element
      // type.
      mlir::Type scalarTy = hintQTy;
      if (auto vecTy = mlir::dyn_cast<mlir::cir::VectorType>(hintQTy))
        scalarTy = vecTy.getEltType();

      if (auto intTy = mlir::dyn_cast<mlir::cir::IntType>(scalarTy))
        return intTy.isSigned();
      return false;
    }
  }];

}

#endif // MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS
1 change: 1 addition & 0 deletions clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ struct MissingFeatures {
static bool getFPFeaturesInEffect() { return false; }
static bool cxxABI() { return false; }
static bool openCL() { return false; }
static bool openCLGenKernelMetadata() { return false; }
static bool CUDA() { return false; }
static bool openMP() { return false; }
static bool openMPRuntime() { return false; }
Expand Down
67 changes: 65 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -993,8 +993,7 @@ void CIRGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
llvm_unreachable("NYI");

if (FD && getLangOpts().OpenCL) {
// TODO(cir): Emit OpenCL kernel metadata
assert(!MissingFeatures::openCL());
buildKernelMetadata(FD, Fn);
}

// If we are checking function types, emit a function type signature as
Expand Down Expand Up @@ -1720,3 +1719,67 @@ CIRGenFunction::buildArrayLength(const clang::ArrayType *origArrayType,

return numElements;
}

void CIRGenFunction::buildKernelMetadata(const FunctionDecl *FD,
mlir::cir::FuncOp Fn) {
if (!FD->hasAttr<OpenCLKernelAttr>() && !FD->hasAttr<CUDAGlobalAttr>())
return;

// TODO(cir): CGM.genKernelArgMetadata(Fn, FD, this);
assert(!MissingFeatures::openCLGenKernelMetadata());

if (!getLangOpts().OpenCL)
return;

using mlir::cir::OpenCLKernelMetadataAttr;

mlir::ArrayAttr workGroupSizeHintAttr, reqdWorkGroupSizeAttr;
mlir::TypeAttr vecTypeHintAttr;
std::optional<bool> vecTypeHintSignedness;
mlir::IntegerAttr intelReqdSubGroupSizeAttr;

if (const VecTypeHintAttr *A = FD->getAttr<VecTypeHintAttr>()) {
mlir::Type typeHintValue = getTypes().ConvertType(A->getTypeHint());
vecTypeHintAttr = mlir::TypeAttr::get(typeHintValue);
vecTypeHintSignedness =
OpenCLKernelMetadataAttr::isSignedHint(typeHintValue);
}

if (const WorkGroupSizeHintAttr *A = FD->getAttr<WorkGroupSizeHintAttr>()) {
workGroupSizeHintAttr = builder.getI32ArrayAttr({
static_cast<int32_t>(A->getXDim()),
static_cast<int32_t>(A->getYDim()),
static_cast<int32_t>(A->getZDim()),
});
}

if (const ReqdWorkGroupSizeAttr *A = FD->getAttr<ReqdWorkGroupSizeAttr>()) {
reqdWorkGroupSizeAttr = builder.getI32ArrayAttr({
static_cast<int32_t>(A->getXDim()),
static_cast<int32_t>(A->getYDim()),
static_cast<int32_t>(A->getZDim()),
});
}

if (const OpenCLIntelReqdSubGroupSizeAttr *A =
FD->getAttr<OpenCLIntelReqdSubGroupSizeAttr>()) {
intelReqdSubGroupSizeAttr = builder.getI32IntegerAttr(A->getSubGroupSize());
}

// Skip the metadata attr if no hints are present.
if (!vecTypeHintAttr && !workGroupSizeHintAttr && !reqdWorkGroupSizeAttr &&
!intelReqdSubGroupSizeAttr)
return;

// Append the kernel metadata to the extra attributes dictionary.
mlir::NamedAttrList attrs;
attrs.append(Fn.getExtraAttrs().getElements());

auto kernelMetadataAttr = OpenCLKernelMetadataAttr::get(
builder.getContext(), workGroupSizeHintAttr, reqdWorkGroupSizeAttr,
vecTypeHintAttr, vecTypeHintSignedness, intelReqdSubGroupSizeAttr);
attrs.append(kernelMetadataAttr.getMnemonic(), kernelMetadataAttr);

Fn.setExtraAttrsAttr(mlir::cir::ExtraFuncAttributesAttr::get(
builder.getContext(), attrs.getDictionary(builder.getContext())));
}
4 changes: 4 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class CIRGenFunction : public CIRGenTypeCache {
// enter/leave scopes.
llvm::DenseMap<const Expr *, mlir::Value> VLASizeMap;

/// Add OpenCL kernel arg metadata and the kernel attribute metadata to
/// the function metadata.
void buildKernelMetadata(const FunctionDecl *FD, mlir::cir::FuncOp Fn);

public:
/// A non-RAII class containing all the information about a bound
/// opaque value. OpaqueValueMapping, below, is a RAII wrapper for
Expand Down
55 changes: 55 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
Expand Down Expand Up @@ -499,6 +500,60 @@ LogicalResult DynamicCastInfoAttr::verify(
return success();
}

//===----------------------------------------------------------------------===//
// OpenCLKernelMetadataAttr definitions
//===----------------------------------------------------------------------===//

/// Check the invariants of an OpenCLKernelMetadataAttr:
///  - at least one field is present;
///  - each work-group size array, when present, holds exactly 3 integer
///    attributes;
///  - `vec_type_hint_signedness` is present exactly when `vec_type_hint` is;
///  - a CIR `vec_type_hint` agrees with the recorded signedness, and a
///    non-CIR one is an LLVM-compatible type.
LogicalResult OpenCLKernelMetadataAttr::verify(
    ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
    ArrayAttr workGroupSizeHint, ArrayAttr reqdWorkGroupSize,
    TypeAttr vecTypeHint, std::optional<bool> vecTypeHintSignedness,
    IntegerAttr intelReqdSubGroupSize) {
  // An attribute carrying no information is rejected outright.
  const bool hasAnyField = workGroupSizeHint || reqdWorkGroupSize ||
                           vecTypeHint || vecTypeHintSignedness.has_value() ||
                           intelReqdSubGroupSize;
  if (!hasAnyField)
    return emitError()
           << "metadata attribute without any field present is invalid";

  // True when `dims` consists of exactly three integer attributes.
  auto hasThreeIntegers = [](ArrayAttr dims) {
    if (dims.size() != 3)
      return false;
    for (Attribute dim : dims)
      if (!mlir::isa<IntegerAttr>(dim))
        return false;
    return true;
  };

  if (workGroupSizeHint && !hasThreeIntegers(workGroupSizeHint))
    return emitError()
           << "work_group_size_hint must have exactly 3 integer elements";

  if (reqdWorkGroupSize && !hasThreeIntegers(reqdWorkGroupSize))
    return emitError()
           << "reqd_work_group_size must have exactly 3 integer elements";

  // The type hint and its signedness flag must come as a pair.
  const bool hasTypeHint = static_cast<bool>(vecTypeHint);
  if (hasTypeHint != vecTypeHintSignedness.has_value())
    return emitError() << "vec_type_hint_signedness should be present if and "
                          "only if vec_type_hint is set";

  if (!hasTypeHint)
    return success();

  Type hintTy = vecTypeHint.getValue();
  if (!mlir::isa<cir::CIRDialect>(hintTy.getDialect())) {
    // Outside of CIR only LLVM-compatible types are accepted; they are
    // signless, so there is nothing to compare the flag against.
    if (!LLVM::isCompatibleType(hintTy))
      return emitError() << "vec_type_hint must be a type from the CIR or LLVM "
                            "dialect";
    return success();
  }

  // A CIR type knows its own signedness, which must agree with the flag.
  if (isSignedHint(hintTy) != vecTypeHintSignedness)
    return emitError() << "vec_type_hint_signedness must match the "
                          "signedness of the vec_type_hint type";

  return success();
}

//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//
Expand Down
41 changes: 38 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1523,12 +1523,13 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
/// to the name of the attribute in ODS.
static StringRef getLinkageAttrNameString() { return "linkage"; }

/// Convert the `cir.func` attributes to `llvm.func` attributes.
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAndResAttrs` is set, also filter out
/// argument attributes.
void
filterFuncAttributes(mlir::cir::FuncOp func, bool filterArgAndResAttrs,
SmallVectorImpl<mlir::NamedAttribute> &result) const {
lowerFuncAttributes(mlir::cir::FuncOp func, bool filterArgAndResAttrs,
SmallVectorImpl<mlir::NamedAttribute> &result) const {
for (auto attr : func->getAttrs()) {
if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() ||
attr.getName() == func.getFunctionTypeAttrName() ||
Expand All @@ -1543,11 +1544,45 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
if (attr.getName() == func.getExtraAttrsAttrName()) {
std::string cirName = "cir." + func.getExtraAttrsAttrName().str();
attr.setName(mlir::StringAttr::get(getContext(), cirName));

lowerFuncOpenCLKernelMetadata(attr);
}
result.push_back(attr);
}
}

/// Module translation can only translate LLVM-compatible types. If the given
/// extra-attributes entry carries an OpenCLKernelMetadataAttr with a
/// `vec_type_hint`, rebuild that attribute so the hint holds the type
/// produced by the type converter. All other fields, including the
/// signedness flag, are carried over unchanged.
///
/// \param extraAttrsEntry the named attribute whose value is the function's
///        ExtraFuncAttributesAttr; its value is replaced in place.
void
lowerFuncOpenCLKernelMetadata(mlir::NamedAttribute &extraAttrsEntry) const {
  const auto attrKey = mlir::cir::OpenCLKernelMetadataAttr::getMnemonic();
  auto oldExtraAttrs =
      cast<mlir::cir::ExtraFuncAttributesAttr>(extraAttrsEntry.getValue());
  // Nothing to lower when the function has no kernel metadata.
  if (!oldExtraAttrs.getElements().contains(attrKey))
    return;

  mlir::NamedAttrList newExtraAttrs;
  for (auto entry : oldExtraAttrs.getElements()) {
    if (entry.getName() == attrKey) {
      auto clKernelMetadata =
          cast<mlir::cir::OpenCLKernelMetadataAttr>(entry.getValue());
      if (auto vecTypeHint = clKernelMetadata.getVecTypeHint()) {
        // convertType yields a null type when the conversion fails, and a
        // null type must not be wrapped in a TypeAttr. In that case the
        // entry is kept as it was instead of crashing here.
        if (mlir::Type newType =
                typeConverter->convertType(vecTypeHint.getValue())) {
          auto newTypeHint = mlir::TypeAttr::get(newType);
          auto newCLKMAttr = mlir::cir::OpenCLKernelMetadataAttr::get(
              getContext(), clKernelMetadata.getWorkGroupSizeHint(),
              clKernelMetadata.getReqdWorkGroupSize(), newTypeHint,
              clKernelMetadata.getVecTypeHintSignedness(),
              clKernelMetadata.getIntelReqdSubGroupSize());
          entry.setValue(newCLKMAttr);
        }
      }
    }
    newExtraAttrs.push_back(entry);
  }
  extraAttrsEntry.setValue(mlir::cir::ExtraFuncAttributesAttr::get(
      getContext(), newExtraAttrs.getDictionary(getContext())));
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::FuncOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -1585,7 +1620,7 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {

auto linkage = convertLinkage(op.getLinkage());
SmallVector<mlir::NamedAttribute, 4> attributes;
filterFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);
lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);

auto fn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
Loc, op.getName(), llvmFnTy, linkage, isDsoLocal, mlir::LLVM::CConv::C,
Expand Down
Loading

0 comments on commit 6e11fd3

Please sign in to comment.