Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Add definition for OpGroupNonUniformBallotBitCount #126055

Merged

Conversation

IgWod-IMG
Copy link
Contributor

A new constraint is also added to restrict attributes values for SPIR-V attributes. Ideally this should use ConfinedAttr with a custom constraint directly on the operand, however it seems TableGen does not allow using that with SPIR-V attributes. I suspect it is because SPIR-V attributes do not derive from the generic MLIR attribute class - TableGen complains about missing enum field.

@llvmbot
Copy link
Member

llvmbot commented Feb 6, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Igor Wodiany (IgWod-IMG)

Changes

A new constraint is also added to restrict attributes values for SPIR-V attributes. Ideally this should use ConfinedAttr with a custom constraint directly on the operand, however it seems TableGen does not allow using that with SPIR-V attributes. I suspect it is because SPIR-V attributes do not derive from the generic MLIR attribute class - TableGen complains about missing enum field.


Patch is 46.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126055.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+233-231)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td (+76)
  • (modified) mlir/test/Dialect/SPIRV/IR/group-ops.mlir (+28)
  • (modified) mlir/test/Target/SPIRV/group-ops.mlir (+6-1)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index ff738fc2555734a..10ddb490fc167ac 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4301,237 +4301,238 @@ class SPIRV_OpCode<string name, int val> {
 
 // Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
 
-def SPIRV_OC_OpNop                          : I32EnumAttrCase<"OpNop", 0>;
-def SPIRV_OC_OpUndef                        : I32EnumAttrCase<"OpUndef", 1>;
-def SPIRV_OC_OpSourceContinued              : I32EnumAttrCase<"OpSourceContinued", 2>;
-def SPIRV_OC_OpSource                       : I32EnumAttrCase<"OpSource", 3>;
-def SPIRV_OC_OpSourceExtension              : I32EnumAttrCase<"OpSourceExtension", 4>;
-def SPIRV_OC_OpName                         : I32EnumAttrCase<"OpName", 5>;
-def SPIRV_OC_OpMemberName                   : I32EnumAttrCase<"OpMemberName", 6>;
-def SPIRV_OC_OpString                       : I32EnumAttrCase<"OpString", 7>;
-def SPIRV_OC_OpLine                         : I32EnumAttrCase<"OpLine", 8>;
-def SPIRV_OC_OpExtension                    : I32EnumAttrCase<"OpExtension", 10>;
-def SPIRV_OC_OpExtInstImport                : I32EnumAttrCase<"OpExtInstImport", 11>;
-def SPIRV_OC_OpExtInst                      : I32EnumAttrCase<"OpExtInst", 12>;
-def SPIRV_OC_OpMemoryModel                  : I32EnumAttrCase<"OpMemoryModel", 14>;
-def SPIRV_OC_OpEntryPoint                   : I32EnumAttrCase<"OpEntryPoint", 15>;
-def SPIRV_OC_OpExecutionMode                : I32EnumAttrCase<"OpExecutionMode", 16>;
-def SPIRV_OC_OpCapability                   : I32EnumAttrCase<"OpCapability", 17>;
-def SPIRV_OC_OpTypeVoid                     : I32EnumAttrCase<"OpTypeVoid", 19>;
-def SPIRV_OC_OpTypeBool                     : I32EnumAttrCase<"OpTypeBool", 20>;
-def SPIRV_OC_OpTypeInt                      : I32EnumAttrCase<"OpTypeInt", 21>;
-def SPIRV_OC_OpTypeFloat                    : I32EnumAttrCase<"OpTypeFloat", 22>;
-def SPIRV_OC_OpTypeVector                   : I32EnumAttrCase<"OpTypeVector", 23>;
-def SPIRV_OC_OpTypeMatrix                   : I32EnumAttrCase<"OpTypeMatrix", 24>;
-def SPIRV_OC_OpTypeImage                    : I32EnumAttrCase<"OpTypeImage", 25>;
-def SPIRV_OC_OpTypeSampledImage             : I32EnumAttrCase<"OpTypeSampledImage", 27>;
-def SPIRV_OC_OpTypeArray                    : I32EnumAttrCase<"OpTypeArray", 28>;
-def SPIRV_OC_OpTypeRuntimeArray             : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
-def SPIRV_OC_OpTypeStruct                   : I32EnumAttrCase<"OpTypeStruct", 30>;
-def SPIRV_OC_OpTypePointer                  : I32EnumAttrCase<"OpTypePointer", 32>;
-def SPIRV_OC_OpTypeFunction                 : I32EnumAttrCase<"OpTypeFunction", 33>;
-def SPIRV_OC_OpTypeForwardPointer           : I32EnumAttrCase<"OpTypeForwardPointer", 39>;
-def SPIRV_OC_OpConstantTrue                 : I32EnumAttrCase<"OpConstantTrue", 41>;
-def SPIRV_OC_OpConstantFalse                : I32EnumAttrCase<"OpConstantFalse", 42>;
-def SPIRV_OC_OpConstant                     : I32EnumAttrCase<"OpConstant", 43>;
-def SPIRV_OC_OpConstantComposite            : I32EnumAttrCase<"OpConstantComposite", 44>;
-def SPIRV_OC_OpConstantNull                 : I32EnumAttrCase<"OpConstantNull", 46>;
-def SPIRV_OC_OpSpecConstantTrue             : I32EnumAttrCase<"OpSpecConstantTrue", 48>;
-def SPIRV_OC_OpSpecConstantFalse            : I32EnumAttrCase<"OpSpecConstantFalse", 49>;
-def SPIRV_OC_OpSpecConstant                 : I32EnumAttrCase<"OpSpecConstant", 50>;
-def SPIRV_OC_OpSpecConstantComposite        : I32EnumAttrCase<"OpSpecConstantComposite", 51>;
-def SPIRV_OC_OpSpecConstantOp               : I32EnumAttrCase<"OpSpecConstantOp", 52>;
-def SPIRV_OC_OpFunction                     : I32EnumAttrCase<"OpFunction", 54>;
-def SPIRV_OC_OpFunctionParameter            : I32EnumAttrCase<"OpFunctionParameter", 55>;
-def SPIRV_OC_OpFunctionEnd                  : I32EnumAttrCase<"OpFunctionEnd", 56>;
-def SPIRV_OC_OpFunctionCall                 : I32EnumAttrCase<"OpFunctionCall", 57>;
-def SPIRV_OC_OpVariable                     : I32EnumAttrCase<"OpVariable", 59>;
-def SPIRV_OC_OpLoad                         : I32EnumAttrCase<"OpLoad", 61>;
-def SPIRV_OC_OpStore                        : I32EnumAttrCase<"OpStore", 62>;
-def SPIRV_OC_OpCopyMemory                   : I32EnumAttrCase<"OpCopyMemory", 63>;
-def SPIRV_OC_OpAccessChain                  : I32EnumAttrCase<"OpAccessChain", 65>;
-def SPIRV_OC_OpPtrAccessChain               : I32EnumAttrCase<"OpPtrAccessChain", 67>;
-def SPIRV_OC_OpInBoundsPtrAccessChain       : I32EnumAttrCase<"OpInBoundsPtrAccessChain", 70>;
-def SPIRV_OC_OpDecorate                     : I32EnumAttrCase<"OpDecorate", 71>;
-def SPIRV_OC_OpMemberDecorate               : I32EnumAttrCase<"OpMemberDecorate", 72>;
-def SPIRV_OC_OpVectorExtractDynamic         : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
-def SPIRV_OC_OpVectorInsertDynamic          : I32EnumAttrCase<"OpVectorInsertDynamic", 78>;
-def SPIRV_OC_OpVectorShuffle                : I32EnumAttrCase<"OpVectorShuffle", 79>;
-def SPIRV_OC_OpCompositeConstruct           : I32EnumAttrCase<"OpCompositeConstruct", 80>;
-def SPIRV_OC_OpCompositeExtract             : I32EnumAttrCase<"OpCompositeExtract", 81>;
-def SPIRV_OC_OpCompositeInsert              : I32EnumAttrCase<"OpCompositeInsert", 82>;
-def SPIRV_OC_OpTranspose                    : I32EnumAttrCase<"OpTranspose", 84>;
-def SPIRV_OC_OpImageDrefGather              : I32EnumAttrCase<"OpImageDrefGather", 97>;
-def SPIRV_OC_OpImage                        : I32EnumAttrCase<"OpImage", 100>;
-def SPIRV_OC_OpImageQuerySize               : I32EnumAttrCase<"OpImageQuerySize", 104>;
-def SPIRV_OC_OpConvertFToU                  : I32EnumAttrCase<"OpConvertFToU", 109>;
-def SPIRV_OC_OpConvertFToS                  : I32EnumAttrCase<"OpConvertFToS", 110>;
-def SPIRV_OC_OpConvertSToF                  : I32EnumAttrCase<"OpConvertSToF", 111>;
-def SPIRV_OC_OpConvertUToF                  : I32EnumAttrCase<"OpConvertUToF", 112>;
-def SPIRV_OC_OpUConvert                     : I32EnumAttrCase<"OpUConvert", 113>;
-def SPIRV_OC_OpSConvert                     : I32EnumAttrCase<"OpSConvert", 114>;
-def SPIRV_OC_OpFConvert                     : I32EnumAttrCase<"OpFConvert", 115>;
-def SPIRV_OC_OpConvertPtrToU                : I32EnumAttrCase<"OpConvertPtrToU", 117>;
-def SPIRV_OC_OpConvertUToPtr                : I32EnumAttrCase<"OpConvertUToPtr", 120>;
-def SPIRV_OC_OpPtrCastToGeneric             : I32EnumAttrCase<"OpPtrCastToGeneric", 121>;
-def SPIRV_OC_OpGenericCastToPtr             : I32EnumAttrCase<"OpGenericCastToPtr", 122>;
-def SPIRV_OC_OpGenericCastToPtrExplicit     : I32EnumAttrCase<"OpGenericCastToPtrExplicit", 123>;
-def SPIRV_OC_OpBitcast                      : I32EnumAttrCase<"OpBitcast", 124>;
-def SPIRV_OC_OpSNegate                      : I32EnumAttrCase<"OpSNegate", 126>;
-def SPIRV_OC_OpFNegate                      : I32EnumAttrCase<"OpFNegate", 127>;
-def SPIRV_OC_OpIAdd                         : I32EnumAttrCase<"OpIAdd", 128>;
-def SPIRV_OC_OpFAdd                         : I32EnumAttrCase<"OpFAdd", 129>;
-def SPIRV_OC_OpISub                         : I32EnumAttrCase<"OpISub", 130>;
-def SPIRV_OC_OpFSub                         : I32EnumAttrCase<"OpFSub", 131>;
-def SPIRV_OC_OpIMul                         : I32EnumAttrCase<"OpIMul", 132>;
-def SPIRV_OC_OpFMul                         : I32EnumAttrCase<"OpFMul", 133>;
-def SPIRV_OC_OpUDiv                         : I32EnumAttrCase<"OpUDiv", 134>;
-def SPIRV_OC_OpSDiv                         : I32EnumAttrCase<"OpSDiv", 135>;
-def SPIRV_OC_OpFDiv                         : I32EnumAttrCase<"OpFDiv", 136>;
-def SPIRV_OC_OpUMod                         : I32EnumAttrCase<"OpUMod", 137>;
-def SPIRV_OC_OpSRem                         : I32EnumAttrCase<"OpSRem", 138>;
-def SPIRV_OC_OpSMod                         : I32EnumAttrCase<"OpSMod", 139>;
-def SPIRV_OC_OpFRem                         : I32EnumAttrCase<"OpFRem", 140>;
-def SPIRV_OC_OpFMod                         : I32EnumAttrCase<"OpFMod", 141>;
-def SPIRV_OC_OpVectorTimesScalar            : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
-def SPIRV_OC_OpMatrixTimesScalar            : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
-def SPIRV_OC_OpVectorTimesMatrix            : I32EnumAttrCase<"OpVectorTimesMatrix", 144>;
-def SPIRV_OC_OpMatrixTimesVector            : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
-def SPIRV_OC_OpMatrixTimesMatrix            : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
-def SPIRV_OC_OpDot                          : I32EnumAttrCase<"OpDot", 148>;
-def SPIRV_OC_OpIAddCarry                    : I32EnumAttrCase<"OpIAddCarry", 149>;
-def SPIRV_OC_OpISubBorrow                   : I32EnumAttrCase<"OpISubBorrow", 150>;
-def SPIRV_OC_OpUMulExtended                 : I32EnumAttrCase<"OpUMulExtended", 151>;
-def SPIRV_OC_OpSMulExtended                 : I32EnumAttrCase<"OpSMulExtended", 152>;
-def SPIRV_OC_OpIsNan                        : I32EnumAttrCase<"OpIsNan", 156>;
-def SPIRV_OC_OpIsInf                        : I32EnumAttrCase<"OpIsInf", 157>;
-def SPIRV_OC_OpOrdered                      : I32EnumAttrCase<"OpOrdered", 162>;
-def SPIRV_OC_OpUnordered                    : I32EnumAttrCase<"OpUnordered", 163>;
-def SPIRV_OC_OpLogicalEqual                 : I32EnumAttrCase<"OpLogicalEqual", 164>;
-def SPIRV_OC_OpLogicalNotEqual              : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
-def SPIRV_OC_OpLogicalOr                    : I32EnumAttrCase<"OpLogicalOr", 166>;
-def SPIRV_OC_OpLogicalAnd                   : I32EnumAttrCase<"OpLogicalAnd", 167>;
-def SPIRV_OC_OpLogicalNot                   : I32EnumAttrCase<"OpLogicalNot", 168>;
-def SPIRV_OC_OpSelect                       : I32EnumAttrCase<"OpSelect", 169>;
-def SPIRV_OC_OpIEqual                       : I32EnumAttrCase<"OpIEqual", 170>;
-def SPIRV_OC_OpINotEqual                    : I32EnumAttrCase<"OpINotEqual", 171>;
-def SPIRV_OC_OpUGreaterThan                 : I32EnumAttrCase<"OpUGreaterThan", 172>;
-def SPIRV_OC_OpSGreaterThan                 : I32EnumAttrCase<"OpSGreaterThan", 173>;
-def SPIRV_OC_OpUGreaterThanEqual            : I32EnumAttrCase<"OpUGreaterThanEqual", 174>;
-def SPIRV_OC_OpSGreaterThanEqual            : I32EnumAttrCase<"OpSGreaterThanEqual", 175>;
-def SPIRV_OC_OpULessThan                    : I32EnumAttrCase<"OpULessThan", 176>;
-def SPIRV_OC_OpSLessThan                    : I32EnumAttrCase<"OpSLessThan", 177>;
-def SPIRV_OC_OpULessThanEqual               : I32EnumAttrCase<"OpULessThanEqual", 178>;
-def SPIRV_OC_OpSLessThanEqual               : I32EnumAttrCase<"OpSLessThanEqual", 179>;
-def SPIRV_OC_OpFOrdEqual                    : I32EnumAttrCase<"OpFOrdEqual", 180>;
-def SPIRV_OC_OpFUnordEqual                  : I32EnumAttrCase<"OpFUnordEqual", 181>;
-def SPIRV_OC_OpFOrdNotEqual                 : I32EnumAttrCase<"OpFOrdNotEqual", 182>;
-def SPIRV_OC_OpFUnordNotEqual               : I32EnumAttrCase<"OpFUnordNotEqual", 183>;
-def SPIRV_OC_OpFOrdLessThan                 : I32EnumAttrCase<"OpFOrdLessThan", 184>;
-def SPIRV_OC_OpFUnordLessThan               : I32EnumAttrCase<"OpFUnordLessThan", 185>;
-def SPIRV_OC_OpFOrdGreaterThan              : I32EnumAttrCase<"OpFOrdGreaterThan", 186>;
-def SPIRV_OC_OpFUnordGreaterThan            : I32EnumAttrCase<"OpFUnordGreaterThan", 187>;
-def SPIRV_OC_OpFOrdLessThanEqual            : I32EnumAttrCase<"OpFOrdLessThanEqual", 188>;
-def SPIRV_OC_OpFUnordLessThanEqual          : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>;
-def SPIRV_OC_OpFOrdGreaterThanEqual         : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>;
-def SPIRV_OC_OpFUnordGreaterThanEqual       : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>;
-def SPIRV_OC_OpShiftRightLogical            : I32EnumAttrCase<"OpShiftRightLogical", 194>;
-def SPIRV_OC_OpShiftRightArithmetic         : I32EnumAttrCase<"OpShiftRightArithmetic", 195>;
-def SPIRV_OC_OpShiftLeftLogical             : I32EnumAttrCase<"OpShiftLeftLogical", 196>;
-def SPIRV_OC_OpBitwiseOr                    : I32EnumAttrCase<"OpBitwiseOr", 197>;
-def SPIRV_OC_OpBitwiseXor                   : I32EnumAttrCase<"OpBitwiseXor", 198>;
-def SPIRV_OC_OpBitwiseAnd                   : I32EnumAttrCase<"OpBitwiseAnd", 199>;
-def SPIRV_OC_OpNot                          : I32EnumAttrCase<"OpNot", 200>;
-def SPIRV_OC_OpBitFieldInsert               : I32EnumAttrCase<"OpBitFieldInsert", 201>;
-def SPIRV_OC_OpBitFieldSExtract             : I32EnumAttrCase<"OpBitFieldSExtract", 202>;
-def SPIRV_OC_OpBitFieldUExtract             : I32EnumAttrCase<"OpBitFieldUExtract", 203>;
-def SPIRV_OC_OpBitReverse                   : I32EnumAttrCase<"OpBitReverse", 204>;
-def SPIRV_OC_OpBitCount                     : I32EnumAttrCase<"OpBitCount", 205>;
-def SPIRV_OC_OpEmitVertex                   : I32EnumAttrCase<"OpEmitVertex", 218>;
-def SPIRV_OC_OpEndPrimitive                 : I32EnumAttrCase<"OpEndPrimitive", 219>;
-def SPIRV_OC_OpControlBarrier               : I32EnumAttrCase<"OpControlBarrier", 224>;
-def SPIRV_OC_OpMemoryBarrier                : I32EnumAttrCase<"OpMemoryBarrier", 225>;
-def SPIRV_OC_OpAtomicExchange               : I32EnumAttrCase<"OpAtomicExchange", 229>;
-def SPIRV_OC_OpAtomicCompareExchange        : I32EnumAttrCase<"OpAtomicCompareExchange", 230>;
-def SPIRV_OC_OpAtomicCompareExchangeWeak    : I32EnumAttrCase<"OpAtomicCompareExchangeWeak", 231>;
-def SPIRV_OC_OpAtomicIIncrement             : I32EnumAttrCase<"OpAtomicIIncrement", 232>;
-def SPIRV_OC_OpAtomicIDecrement             : I32EnumAttrCase<"OpAtomicIDecrement", 233>;
-def SPIRV_OC_OpAtomicIAdd                   : I32EnumAttrCase<"OpAtomicIAdd", 234>;
-def SPIRV_OC_OpAtomicISub                   : I32EnumAttrCase<"OpAtomicISub", 235>;
-def SPIRV_OC_OpAtomicSMin                   : I32EnumAttrCase<"OpAtomicSMin", 236>;
-def SPIRV_OC_OpAtomicUMin                   : I32EnumAttrCase<"OpAtomicUMin", 237>;
-def SPIRV_OC_OpAtomicSMax                   : I32EnumAttrCase<"OpAtomicSMax", 238>;
-def SPIRV_OC_OpAtomicUMax                   : I32EnumAttrCase<"OpAtomicUMax", 239>;
-def SPIRV_OC_OpAtomicAnd                    : I32EnumAttrCase<"OpAtomicAnd", 240>;
-def SPIRV_OC_OpAtomicOr                     : I32EnumAttrCase<"OpAtomicOr", 241>;
-def SPIRV_OC_OpAtomicXor                    : I32EnumAttrCase<"OpAtomicXor", 242>;
-def SPIRV_OC_OpPhi                          : I32EnumAttrCase<"OpPhi", 245>;
-def SPIRV_OC_OpLoopMerge                    : I32EnumAttrCase<"OpLoopMerge", 246>;
-def SPIRV_OC_OpSelectionMerge               : I32EnumAttrCase<"OpSelectionMerge", 247>;
-def SPIRV_OC_OpLabel                        : I32EnumAttrCase<"OpLabel", 248>;
-def SPIRV_OC_OpBranch                       : I32EnumAttrCase<"OpBranch", 249>;
-def SPIRV_OC_OpBranchConditional            : I32EnumAttrCase<"OpBranchConditional", 250>;
-def SPIRV_OC_OpReturn                       : I32EnumAttrCase<"OpReturn", 253>;
-def SPIRV_OC_OpReturnValue                  : I32EnumAttrCase<"OpReturnValue", 254>;
-def SPIRV_OC_OpUnreachable                  : I32EnumAttrCase<"OpUnreachable", 255>;
-def SPIRV_OC_OpGroupBroadcast               : I32EnumAttrCase<"OpGroupBroadcast", 263>;
-def SPIRV_OC_OpGroupIAdd                    : I32EnumAttrCase<"OpGroupIAdd", 264>;
-def SPIRV_OC_OpGroupFAdd                    : I32EnumAttrCase<"OpGroupFAdd", 265>;
-def SPIRV_OC_OpGroupFMin                    : I32EnumAttrCase<"OpGroupFMin", 266>;
-def SPIRV_OC_OpGroupUMin                    : I32EnumAttrCase<"OpGroupUMin", 267>;
-def SPIRV_OC_OpGroupSMin                    : I32EnumAttrCase<"OpGroupSMin", 268>;
-def SPIRV_OC_OpGroupFMax                    : I32EnumAttrCase<"OpGroupFMax", 269>;
-def SPIRV_OC_OpGroupUMax                    : I32EnumAttrCase<"OpGroupUMax", 270>;
-def SPIRV_OC_OpGroupSMax                    : I32EnumAttrCase<"OpGroupSMax", 271>;
-def SPIRV_OC_OpNoLine                       : I32EnumAttrCase<"OpNoLine", 317>;
-def SPIRV_OC_OpModuleProcessed              : I32EnumAttrCase<"OpModuleProcessed", 330>;
-def SPIRV_OC_OpGroupNonUniformElect         : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
-def SPIRV_OC_OpGroupNonUniformBroadcast     : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
-def SPIRV_OC_OpGroupNonUniformBallot        : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
-def SPIRV_OC_OpGroupNonUniformBallotFindLSB : I32EnumAttrCase<"OpGroupNonUniformBallotFindLSB", 343>;
-def SPIRV_OC_OpGroupNonUniformBallotFindMSB : I32EnumAttrCase<"OpGroupNonUniformBallotFindMSB", 344>;
-def SPIRV_OC_OpGroupNonUniformShuffle       : I32EnumAttrCase<"OpGroupNonUniformShuffle", 345>;
-def SPIRV_OC_OpGroupNonUniformShuffleXor    : I32EnumAttrCase<"OpGroupNonUniformShuffleXor", 346>;
-def SPIRV_OC_OpGroupNonUniformShuffleUp     : I32EnumAttrCase<"OpGroupNonUniformShuffleUp", 347>;
-def SPIRV_OC_OpGroupNonUniformShuffleDown   : I32EnumAttrCase<"OpGroupNonUniformShuffleDown", 348>;
-def SPIRV_OC_OpGroupNonUniformIAdd          : I32EnumAttrCase<"OpGroupNonUniformIAdd", 349>;
-def SPIRV_OC_OpGroupNonUniformFAdd          : I32EnumAttrCase<"OpGroupNonUniformFAdd", 350>;
-def SPIRV_OC_OpGroupNonUniformIMul          : I32EnumAttrCase<"OpGroupNonUniformIMul", 351>;
-def SPIRV_OC_OpGroupNonUniformFMul          : I32EnumAttrCase<"OpGroupNonUniformFMul", 352>;
-def SPIRV_OC_OpGroupNonUniformSMin          : I32EnumAttrCase<"OpGroupNonUniformSMin", 353>;
-def SPIRV_OC_OpGroupNonUniformUMin          : I32EnumAttrCase<"OpGroupNonUniformUMin", 354>;
-def SPIRV_OC_OpGroupNonUniformFMin          : I32EnumAttrCase<"OpGroupNonUniformFMin", 355>;
-def SPIRV_OC_OpGroupNonUniformSMax          : I32EnumAttrCase<"OpGroupNonUniformSMax", 356>;
-def SPIRV_OC_OpGroupNonUniformUMax          : I32EnumAttrCase<"OpGroupNonUniformUMax", 357>;
-def SPIRV_OC_OpGroupNonUniformFMax          : I32EnumAttrCase<"OpGroupNonUniformFMax", 358>;
-def SPIRV_OC_OpGroupNonUniformBitwiseAnd    : I32EnumAttrCase<"OpGroupNonUniformBitwiseAnd", 359>;
-def SPIRV_OC_OpGroupNonUniformBitwiseOr     : I32EnumAttrCase<"OpGroupNonUniformBitwiseOr", 360>;
-def SPIRV_OC_OpGroupNonUniformBitwiseXor    : I32EnumAttrCase<"OpGroupNonUniformBitwiseXor", 361>;
-def SPIRV_OC_OpGroupNonUniformLogicalAnd    : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
-def SPIRV_OC_OpGroupNonUniformLogicalOr     : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
-def SPIRV_OC_OpGroupNonUniformLogicalXor    : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
-def SPIRV_OC_OpSubgroupBallotKHR            : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
-def SPIRV_OC_OpSDot                         : I32EnumAttrCase<"OpSDot", 4450>;
-def SPIRV_OC_OpUDot                         : I32EnumAttrCase<"OpUDot", 4451>;
-def SPIRV_OC_OpSUDot                        : I32EnumAttrCase<"OpSUDot", 4452>;
-def SPIRV_OC_OpSDotAccSat                   : I32EnumAttrCase<"OpSDotAccSat", 4453>;
-def SPIRV_OC_OpUDotAccSat                   : I32EnumAttrCase<"OpUDotAccSat", 4454>;
-def SPIRV_OC_OpSUDotAccSat                  : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
-def SPIRV_OC_OpTypeCooperativeMatrixKHR     : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
-def SPIRV_OC_OpCooperativeMatrixLoadKHR     : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
-def SPIRV_OC_OpCooperativeMatrixStoreKHR    : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
-def SPIRV_OC_OpCooperativeMatrixMulAddKHR   : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
-def SPIRV_OC_OpCooperativeMatrixLengthKHR   : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
-def SPIRV_OC_OpSubgroupBlockReadINTEL       : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
-def SPIRV_OC_OpSubgroupBlockWriteINTEL      : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
-def SPIRV_OC_OpAssumeTrueKHR                : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
-def SPIRV_OC_OpAtomicFAddEXT                : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>;
-def SPIRV_OC_OpConvertFToBF16INTEL          : I32EnumAttrCase<"OpConvertFToBF16INTEL", 6116>;
-def SPIRV_OC_OpConvertBF16ToFINTEL          : I32EnumAttrCase<"OpConvertBF16ToFINTEL", 6117>;
-def SPIRV_OC_OpControlBarrierArriveINTEL    : I32EnumAttrCase<"OpControlBarrierArriveINTEL", 6142>;
-def SPIRV_OC_OpControlBarrierWaitINTEL      : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
-def SPIRV_OC...
[truncated]

@IgWod-IMG
Copy link
Contributor Author

Also, should SPIRV_AttrIs be moved to something more top-level, .e.g., SPIRVBase.td? I put it in the SPIRVNonUniformOps.td, as it is currently only used there.


func.func @group_non_uniform_ballot_bit_count(%value: vector<4xi32>) -> i32 {
// CHECK: {{%.*}} = spirv.GroupNonUniformBallotBitCount <Subgroup> <Reduce> {{%.*}} : vector<4xi32> -> i32
%0 = spirv.GroupNonUniformBallotBitCount <Subgroup> <Reduce> %value : vector<4xi32> -> i32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add test cases with vector<3xi32> and vector<4xi8>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will do it!

@@ -103,5 +103,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.KHR.GroupFMul <Workgroup> <Reduce> %value : f32
spirv.ReturnValue %0: f32
}

// CHECK-LABEL: @group_non_uniform_ballot_bit_count
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this test in both files?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean here and in mlir/test/Dialect/SPIRV/IR/group-ops.mlir? It seems to be the common practise to have the same test in Target and in Dialect. This test is a bit different as it uses spirv.func, spirv.Return and it is wrapped in a module. Although probably, I should have it in a separate module that defines a correct capability. N.b. other tests currently in the file aren't completely correct as the module only defines Shader capability. I can correct the whole file or just update my test. Up to you.

Comment on lines 17 to 22
class SPIRV_AttrIs<string operand, string type, string value> : PredOpTrait<
operand # " must be " # type # " of value " # value,
CPred<"::llvm::cast<::mlir::spirv::" # type # "Attr>(getProperties()." # operand # ").getValue() == ::mlir::spirv::" # type # "::" # value>
>;

class SPIRV_GroupOperationAttrIs<string operand, string value> : SPIRV_AttrIs<operand, "GroupOperation", value>;
class SPIRV_ExecutionScopeAttrIs<string operand, string value> : SPIRV_AttrIs<operand, "Scope", value>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat!

@kuhar kuhar requested a review from andfau-amd February 6, 2025 16:08
@@ -1287,4 +1297,70 @@ def SPIRV_GroupNonUniformLogicalXorOp :

// -----

def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCount", [
SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">,
SPIRV_GroupOperationAttrIs<"group_operation", "Reduce">
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does the requirement that the group operation must be Reduce come from? The requirement that the execution scope is Subgroup seems to come from the SPIR-V op definition, but this doesn't?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It comes from me misreading the spec. I interpreted:

The identity I for Operation is 0.

As: "Operation must be 0", where 0 happens to be Reduce. Thanks for pointing this out, my mistake.

@IgWod-IMG IgWod-IMG force-pushed the img_group-non-uniform-ballot-bit-count branch from 8df0943 to 54fbbf6 Compare February 6, 2025 18:08
@IgWod-IMG
Copy link
Contributor Author

I had to re-push the updated patch, after GitHub got stuck processing it and didn't display an update for over an hour, but it should be all good now. I have constrained $value to be 32-bits and kept SPIRV_SignlessOrUnsignedInt as discussed. I added more tests (2 requested + few more around types), and of course corrected the Reduce mishap. Also, I wrapped my op into more suitable module in mlir/test/Target/SPIRV/group-ops.mlir. Please let me know if there is anything else I should change, and if not please merge it.

@@ -1287,4 +1296,69 @@ def SPIRV_GroupNonUniformLogicalXorOp :

// -----

def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCount", [
SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should a bunch of other operations also have this predicate? It seems like this "must be Subgroup" language applies to many other operations. I'm not sure if that's something this PR should address though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely, more we can verify in ODS, the better; however in my opinion this should be a separate PR (which it looks like we agree on), so I am not going to change anything in this ticket.

@@ -4262,6 +4263,7 @@ def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
class SPIRV_Vec4<Type type> : VectorOfLengthAndType<[4], [type]>;
def SPIRV_IntVec4 : SPIRV_Vec4<SPIRV_Integer>;
def SPIRV_IOrUIVec4 : SPIRV_Vec4<SPIRV_SignlessOrUnsignedInt>;
def SPIRV_IOrUI32Vec4 : SPIRV_Vec4<SPIRV_SignlessOrUnsignedInt32>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is rare enough that I'd use SPIRV_Vec4<SignlessOrUnsignedIntOfWidths<[32]>> directly where you need it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! I have just pushed an updated patch, so it's ready to be merged if there are no more comments.

@IgWod-IMG IgWod-IMG force-pushed the img_group-non-uniform-ballot-bit-count branch from 54fbbf6 to 225a6ca Compare February 6, 2025 21:15
@IgWod-IMG
Copy link
Contributor Author

Could you please commit it? I'm still waiting for the committer privilege to be granted.

@andfau-amd andfau-amd merged commit 1454fc9 into llvm:main Feb 7, 2025
8 checks passed
@IgWod-IMG IgWod-IMG deleted the img_group-non-uniform-ballot-bit-count branch February 7, 2025 13:44
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…m#126055)

A new constraint is also added to restrict attributes values for SPIR-V
attributes. Ideally this should use `ConfinedAttr` with a custom
constraint directly on the operand, however it seems TableGen does not
allow using that with SPIR-V attributes. I suspect it is because SPIR-V
attributes do not derive from the generic MLIR attribute class -
TableGen complains about missing enum field.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants