Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][Dialect] Emit OpenCL kernel metadata #705

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -979,4 +979,6 @@ def BitfieldInfoAttr : CIR_Attr<"BitfieldInfo", "bitfield_info"> {
];
}

include "clang/CIR/Dialect/IR/CIROpenCLAttrs.td"

#endif // MLIR_CIR_DIALECT_CIR_ATTRS
95 changes: 95 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROpenCLAttrs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//===- CIROpenCLAttrs.td - CIR dialect attrs for OpenCL ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the CIR dialect attributes for OpenCL.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS
#define MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS

//===----------------------------------------------------------------------===//
// OpenCLKernelMetadataAttr
//===----------------------------------------------------------------------===//

def OpenCLKernelMetadataAttr
: CIR_Attr<"OpenCLKernelMetadata", "cl.kernel_metadata"> {

let summary = "OpenCL kernel metadata";
let description = [{
Provide the required information of an OpenCL kernel for the SPIR-V backend.

The `work_group_size_hint` and `reqd_work_group_size` parameter are integer
arrays with 3 elements that provide hints for the work-group size and the
required work-group size, respectively.

The `vec_type_hint` parameter is a type attribute that provides a hint for
the vectorization. It can be a CIR or LLVM type, depending on the lowering
stage.

The `vec_type_hint_signedness` parameter is a boolean that indicates the
signedness of the vector type hint. It's useful when LLVM type is set in
`vec_type_hint`, which is signless by design. It should be set if and only
if the `vec_type_hint` is present.

The `intel_reqd_sub_group_size` parameter is an integer that restricts the
sub-group size to the specified value.

Example:
```
#fn_attr = #cir<extra({cl.kernel_metadata = #cir.cl.kernel_metadata<
work_group_size_hint = [8 : i32, 16 : i32, 32 : i32],
reqd_work_group_size = [1 : i32, 2 : i32, 4 : i32],
vec_type_hint = !s32i,
vec_type_hint_signedness = 1,
intel_reqd_sub_group_size = 8 : i32
>})>

cir.func @kernel(%arg0: !s32i) extra(#fn_attr) {
cir.return
}
```
}];

let parameters = (ins
OptionalParameter<"ArrayAttr">:$work_group_size_hint,
OptionalParameter<"ArrayAttr">:$reqd_work_group_size,
OptionalParameter<"TypeAttr">:$vec_type_hint,
OptionalParameter<"std::optional<bool>">:$vec_type_hint_signedness,
OptionalParameter<"IntegerAttr">:$intel_reqd_sub_group_size
);

let assemblyFormat = "`<` struct(params) `>`";

let genVerifyDecl = 1;

let extraClassDeclaration = [{
/// Extract the signedness from int or int vector types.
static std::optional<bool> isSignedHint(mlir::Type vecTypeHint);
}];

let extraClassDefinition = [{
std::optional<bool> $cppClass::isSignedHint(mlir::Type hintQTy) {
// Only types in CIR carry signedness
if (!mlir::isa<mlir::cir::CIRDialect>(hintQTy.getDialect()))
return std::nullopt;

// See also clang::CodeGen::CodeGenFunction::EmitKernelMetadata
auto hintEltQTy = mlir::dyn_cast<mlir::cir::VectorType>(hintQTy);
auto isCIRSignedIntType = [](mlir::Type t) {
return mlir::isa<mlir::cir::IntType>(t) &&
mlir::cast<mlir::cir::IntType>(t).isSigned();
};
return isCIRSignedIntType(hintQTy) ||
(hintEltQTy && isCIRSignedIntType(hintEltQTy.getEltType()));
}
}];

}

#endif // MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS
1 change: 1 addition & 0 deletions clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ struct MissingFeatures {
static bool getFPFeaturesInEffect() { return false; }
static bool cxxABI() { return false; }
static bool openCL() { return false; }
static bool openCLGenKernelMetadata() { return false; }
static bool CUDA() { return false; }
static bool openMP() { return false; }
static bool openMPRuntime() { return false; }
Expand Down
67 changes: 65 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -993,8 +993,7 @@ void CIRGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
llvm_unreachable("NYI");

if (FD && getLangOpts().OpenCL) {
// TODO(cir): Emit OpenCL kernel metadata
assert(!MissingFeatures::openCL());
buildKernelMetadata(FD, Fn);
}

// If we are checking function types, emit a function type signature as
Expand Down Expand Up @@ -1720,3 +1719,67 @@ CIRGenFunction::buildArrayLength(const clang::ArrayType *origArrayType,

return numElements;
}

void CIRGenFunction::buildKernelMetadata(const FunctionDecl *FD,
mlir::cir::FuncOp Fn) {
if (!FD->hasAttr<OpenCLKernelAttr>() && !FD->hasAttr<CUDAGlobalAttr>())
return;

// TODO(cir): CGM.genKernelArgMetadata(Fn, FD, this);
assert(!MissingFeatures::openCLGenKernelMetadata());

if (!getLangOpts().OpenCL)
return;

using mlir::cir::OpenCLKernelMetadataAttr;

mlir::ArrayAttr workGroupSizeHintAttr, reqdWorkGroupSizeAttr;
mlir::TypeAttr vecTypeHintAttr;
std::optional<bool> vecTypeHintSignedness;
mlir::IntegerAttr intelReqdSubGroupSizeAttr;

if (const VecTypeHintAttr *A = FD->getAttr<VecTypeHintAttr>()) {
mlir::Type typeHintValue = getTypes().ConvertType(A->getTypeHint());
vecTypeHintAttr = mlir::TypeAttr::get(typeHintValue);
vecTypeHintSignedness =
OpenCLKernelMetadataAttr::isSignedHint(typeHintValue);
}

if (const WorkGroupSizeHintAttr *A = FD->getAttr<WorkGroupSizeHintAttr>()) {
workGroupSizeHintAttr = builder.getI32ArrayAttr({
static_cast<int32_t>(A->getXDim()),
static_cast<int32_t>(A->getYDim()),
static_cast<int32_t>(A->getZDim()),
});
}

if (const ReqdWorkGroupSizeAttr *A = FD->getAttr<ReqdWorkGroupSizeAttr>()) {
reqdWorkGroupSizeAttr = builder.getI32ArrayAttr({
static_cast<int32_t>(A->getXDim()),
static_cast<int32_t>(A->getYDim()),
static_cast<int32_t>(A->getZDim()),
});
}

if (const OpenCLIntelReqdSubGroupSizeAttr *A =
FD->getAttr<OpenCLIntelReqdSubGroupSizeAttr>()) {
intelReqdSubGroupSizeAttr = builder.getI32IntegerAttr(A->getSubGroupSize());
}

// Skip the metadata attr if no hints are present.
if (!vecTypeHintAttr && !workGroupSizeHintAttr && !reqdWorkGroupSizeAttr &&
!intelReqdSubGroupSizeAttr)
return;

// Append the kernel metadata to the extra attributes dictionary.
mlir::NamedAttrList attrs;
attrs.append(Fn.getExtraAttrs().getElements());

auto kernelMetadataAttr = OpenCLKernelMetadataAttr::get(
builder.getContext(), workGroupSizeHintAttr, reqdWorkGroupSizeAttr,
vecTypeHintAttr, vecTypeHintSignedness, intelReqdSubGroupSizeAttr);
attrs.append(kernelMetadataAttr.getMnemonic(), kernelMetadataAttr);

Fn.setExtraAttrsAttr(mlir::cir::ExtraFuncAttributesAttr::get(
builder.getContext(), attrs.getDictionary(builder.getContext())));
}
4 changes: 4 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class CIRGenFunction : public CIRGenTypeCache {
// enter/leave scopes.
llvm::DenseMap<const Expr *, mlir::Value> VLASizeMap;

/// Add OpenCL kernel arg metadata and the kernel attribute metadata to
/// the function metadata.
void buildKernelMetadata(const FunctionDecl *FD, mlir::cir::FuncOp Fn);

public:
/// A non-RAII class containing all the information about a bound
/// opaque value. OpaqueValueMapping, below, is a RAII wrapper for
Expand Down
55 changes: 55 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
Expand Down Expand Up @@ -498,6 +499,60 @@ LogicalResult DynamicCastInfoAttr::verify(
return success();
}

//===----------------------------------------------------------------------===//
// OpenCLKernelMetadataAttr definitions
//===----------------------------------------------------------------------===//

LogicalResult OpenCLKernelMetadataAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
ArrayAttr workGroupSizeHint, ArrayAttr reqdWorkGroupSize,
TypeAttr vecTypeHint, std::optional<bool> vecTypeHintSignedness,
IntegerAttr intelReqdSubGroupSize) {
// If no field is present, the attribute is considered invalid.
if (!workGroupSizeHint && !reqdWorkGroupSize && !vecTypeHint &&
!vecTypeHintSignedness && !intelReqdSubGroupSize) {
return emitError()
<< "metadata attribute without any field present is invalid";
}

// Check for 3-dim integer tuples
auto is3dimIntTuple = [](ArrayAttr arr) {
auto isInt = [](Attribute dim) { return mlir::isa<IntegerAttr>(dim); };
return arr.size() == 3 && llvm::all_of(arr, isInt);
};
if (workGroupSizeHint && !is3dimIntTuple(workGroupSizeHint)) {
return emitError()
<< "work_group_size_hint must have exactly 3 integer elements";
}
if (reqdWorkGroupSize && !is3dimIntTuple(reqdWorkGroupSize)) {
return emitError()
<< "reqd_work_group_size must have exactly 3 integer elements";
}

// Check for co-presence of vecTypeHintSignedness
if (!!vecTypeHint != vecTypeHintSignedness.has_value()) {
return emitError() << "vec_type_hint_signedness should be present if and "
"only if vec_type_hint is set";
}

if (vecTypeHint) {
Type vecTypeHintValue = vecTypeHint.getValue();
if (mlir::isa<cir::CIRDialect>(vecTypeHintValue.getDialect())) {
// Check for signedness alignment in CIR
if (isSignedHint(vecTypeHintValue) != vecTypeHintSignedness) {
return emitError() << "vec_type_hint_signedness must match the "
"signedness of the vec_type_hint type";
}
// Check for the dialect of type hint
} else if (!LLVM::isCompatibleType(vecTypeHintValue)) {
return emitError() << "vec_type_hint must be a type from the CIR or LLVM "
"dialect";
}
}

return success();
}

//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//
Expand Down
41 changes: 38 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1523,12 +1523,13 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
/// to the name of the attribute in ODS.
static StringRef getLinkageAttrNameString() { return "linkage"; }

/// Convert the `cir.func` attributes to `llvm.func` attributes.
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out
/// argument attributes.
void
filterFuncAttributes(mlir::cir::FuncOp func, bool filterArgAndResAttrs,
SmallVectorImpl<mlir::NamedAttribute> &result) const {
lowerFuncAttributes(mlir::cir::FuncOp func, bool filterArgAndResAttrs,
SmallVectorImpl<mlir::NamedAttribute> &result) const {
for (auto attr : func->getAttrs()) {
if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() ||
attr.getName() == func.getFunctionTypeAttrName() ||
Expand All @@ -1543,11 +1544,45 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
if (attr.getName() == func.getExtraAttrsAttrName()) {
std::string cirName = "cir." + func.getExtraAttrsAttrName().str();
attr.setName(mlir::StringAttr::get(getContext(), cirName));

lowerFuncOpenCLKernelMetadata(attr);
}
result.push_back(attr);
}
}

/// When do module translation, we can only translate LLVM-compatible types.
/// Here we lower possible OpenCLKernelMetadataAttr to use the converted type.
void
lowerFuncOpenCLKernelMetadata(mlir::NamedAttribute &extraAttrsEntry) const {
const auto attrKey = mlir::cir::OpenCLKernelMetadataAttr::getMnemonic();
auto oldExtraAttrs =
cast<mlir::cir::ExtraFuncAttributesAttr>(extraAttrsEntry.getValue());
if (!oldExtraAttrs.getElements().contains(attrKey))
return;

mlir::NamedAttrList newExtraAttrs;
for (auto entry : oldExtraAttrs.getElements()) {
if (entry.getName() == attrKey) {
auto clKernelMetadata =
cast<mlir::cir::OpenCLKernelMetadataAttr>(entry.getValue());
if (auto vecTypeHint = clKernelMetadata.getVecTypeHint()) {
auto newType = typeConverter->convertType(vecTypeHint.getValue());
auto newTypeHint = mlir::TypeAttr::get(newType);
auto newCLKMAttr = mlir::cir::OpenCLKernelMetadataAttr::get(
getContext(), clKernelMetadata.getWorkGroupSizeHint(),
clKernelMetadata.getReqdWorkGroupSize(), newTypeHint,
clKernelMetadata.getVecTypeHintSignedness(),
clKernelMetadata.getIntelReqdSubGroupSize());
entry.setValue(newCLKMAttr);
}
}
newExtraAttrs.push_back(entry);
}
extraAttrsEntry.setValue(mlir::cir::ExtraFuncAttributesAttr::get(
getContext(), newExtraAttrs.getDictionary(getContext())));
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::FuncOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -1585,7 +1620,7 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {

auto linkage = convertLinkage(op.getLinkage());
SmallVector<mlir::NamedAttribute, 4> attributes;
filterFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);
lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);

auto fn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
Loc, op.getName(), llvmFnTy, linkage, isDsoLocal, mlir::LLVM::CConv::C,
Expand Down
Loading
Loading