diff --git a/test/Conversion/tritongpu_to_llvm_blackwell.mlir b/test/Conversion/tritongpu_to_llvm_blackwell.mlir
index 5d36067a7629..032040d6368f 100644
--- a/test/Conversion/tritongpu_to_llvm_blackwell.mlir
+++ b/test/Conversion/tritongpu_to_llvm_blackwell.mlir
@@ -341,3 +341,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     tt.return
   }
 }
+
+// -----
+
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 256], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 256, unpacked = false>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @tensor_memory_ld_128x256
+  // CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32
+  // CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32
+  // CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32
+  // CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32
+  // CHECK: tcgen05.wait::st.sync.aligned
+  // CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32
+  // CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32
+  // CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32
+  // CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32
+  // CHECK: tcgen05.wait::ld.sync.aligned
+  tt.func public @tensor_memory_ld_128x256(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked1>
+    %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked1>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
+    %20 = ttng.tmem_load %0 : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked1>
+    tt.return
+  }
+}
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
index 4436376efd6d..48ae6b88c973 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
@@ -12,9 +12,133 @@ using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 
-static const int largestTmemLoadStore = 128;
+// The maximum number of tensor memory registers that can be accessed
+// by a single message regardless of shape or repetitions
+static constexpr int largestTmemLoadStore = 128;
+// The maximum number of thread registers that can be populated by
+// multiple messages
+static constexpr int maxRegisters = 256;
+
 namespace {
 
+struct TMemAccessAtom {
+  int opBitWidth;
+  int colsPerThread;
+  int rowsPerThread;
+  const char *opShape;
+  bool usesSecondHalfOffset;
+};
+
+constexpr TMemAccessAtom TMemAccess32x32b{.opBitWidth = 32,
+                                          .colsPerThread = 1,
+                                          .rowsPerThread = 1,
+                                          .opShape = "32x32b",
+                                          .usesSecondHalfOffset = false};
+
+constexpr TMemAccessAtom TMemAccess16x32bx2{.opBitWidth = 32,
+                                            .colsPerThread = 1,
+                                            .rowsPerThread = 1,
+                                            .opShape = "16x32bx2",
+                                            .usesSecondHalfOffset = true};
+
+constexpr TMemAccessAtom TMemAccess16x256b{.opBitWidth = 256,
+                                           .colsPerThread = 2,
+                                           .rowsPerThread = 2,
+                                           .opShape = "16x256b",
+                                           .usesSecondHalfOffset = false};
+
+struct TMemMessageTraits {
+  TMemAccessAtom atom;
+  bool usesSecondHalfOffset;
+  int numThreadsPerWarp;
+  int maxNumRepeats;
+  int maxCols;
+  int numRows;
+  int numCols;
+  int numRepeats;
+  int numRegs;
+
+  bool operator<(const TMemMessageTraits &other) const {
+    return numRegs < other.numRegs;
+  }
+};
+
+struct TMemRuntimeInfo {
+  static constexpr int numRowsPerWarp = 32;
+  int numWarps;
+  int numWarpGroups;
+  int numElementsPer32B;
+  int numElements;
+  int numCols;
+  int blockM;
+  int blockN;
+  bool unpackedb16;
+  bool useStridedMessage;
+  int numBlocks;
+  int numWarpGroupsPerBlock;
+  bool blocksInterleaved;
+  int numColsPerBlock;
+  int colsPerWarpGroup;
+};
+
+TMemMessageTraits getTMemMessageFromAtom(const TMemAccessAtom &atom,
+                                         int narrowingFactor) {
+  TMemMessageTraits m;
+  m.atom = atom;
+  m.usesSecondHalfOffset = atom.usesSecondHalfOffset;
+  m.numThreadsPerWarp = 32;
+  m.maxNumRepeats =
+      largestTmemLoadStore / (atom.colsPerThread * atom.rowsPerThread);
+  m.maxCols = (atom.opBitWidth / 32) * m.maxNumRepeats;
+  m.numRows = m.numThreadsPerWarp / atom.rowsPerThread;
+  m.numCols = m.maxCols / narrowingFactor;
+  m.numRepeats = m.numCols / (atom.opBitWidth / 32);
+  m.numRegs = atom.colsPerThread * atom.rowsPerThread * m.numRepeats;
+  return m;
+}
+
+// Only allows half of the thread registers to be used for tensor memory access
+// to avoid register pressure. This ensures the largest tmem message width is
+// used for the workload without inducing spills.
+int getTMemMessageNarrowingFactor(int workloadThreadRegs) {
+  const int allowedRegUsage = maxRegisters / 2;
+  int narrowingFactor = 1;
+  while (workloadThreadRegs > allowedRegUsage) {
+    workloadThreadRegs /= 2;
+    narrowingFactor *= 2;
+  }
+  return narrowingFactor;
+}
+
+int getEffectiveRegs(bool unpackedb16, bool useStridedMessage, int numRegs) {
+  // The effective register count is less when using unpacked or strided
+  // messages
+  if (unpackedb16) {
+    numRegs /= 2;
+  }
+  if (useStridedMessage) {
+    numRegs /= 2;
+  }
+  return numRegs;
+}
+
+// If the workload runtime requires fewer registers than the default message
+// width, use the widest possible message that matches the workload
+TMemMessageTraits constrainMessageFromWorkload(TMemMessageTraits m,
+                                               const TMemRuntimeInfo &info,
+                                               int numRegs) {
+  m.numRegs =
+      getEffectiveRegs(info.unpackedb16, info.useStridedMessage, numRegs);
+  m.numRegs = std::min(largestTmemLoadStore, m.numRegs);
+  // Invert the above formulas to calculate the effective runtime message width
+  m.numCols = (m.numRegs * (m.atom.opBitWidth / 32)) /
+              (m.atom.colsPerThread * m.atom.rowsPerThread);
+  // Half as many registers are needed for 16-bit packed elements,
+  // so twice as many columns are accessed per message.
+  m.numCols *= info.numElementsPer32B;
+  return m;
+}
+
 SmallVector<Value> packToI32(const SmallVector<Value> &values, Location loc,
                              ConversionPatternRewriter &rewriter) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -38,143 +162,137 @@ SmallVector<Value> packToI32(const SmallVector<Value> &values, Location loc,
   return packedValues;
 }
 
-// Helper to compute how many registers are needed to load/store `numCols` Tmem
-// coloumns.
-int getNum32BRegs(int rowsPerMessage, bool unpackedb16, int numElementsPer32B,
-                  int numCols) {
-  int numRegPerMessage = numCols;
-  if (rowsPerMessage == 16)
-    numRegPerMessage = numRegPerMessage / 2;
-  if (unpackedb16)
-    numRegPerMessage = numRegPerMessage / numElementsPer32B;
-  return numRegPerMessage;
-}
-
-// Map the distributed layout onto the tmem. Calculate the address and emit one
-// or more tmem messages.
-void calculateAddressAndEmitTmemMessage(
-    Location loc, Operation *op, Value baseAddress, RankedTensorType tensorType,
-    MemDescType memType, ConversionPatternRewriter &rewriter,
-    const std::function<void(Value, int, bool, int, bool)> &createMemoryOp) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  const int numRowsPerWarp = 32;
-
+TMemRuntimeInfo getTMemRuntimeInfo(Operation *op, RankedTensorType tensorType,
+                                   MemDescType memType) {
+  TMemRuntimeInfo info;
+  static_assert(TMemRuntimeInfo::numRowsPerWarp == 32,
+                "A single warp must access exactly 32 rows of tmem");
   assert(
       nvidia_gpu::isDistributedLayoutTMemCompatible(op, tensorType, memType) &&
       "unsupported distributed layout for tensor memory");
-  int numWarps = triton::gpu::lookupNumWarps(op);
-  assert(numWarps % 4 == 0);
-  int numWarpGroups = numWarps / 4;
-  int numElementsPer32B = 32 / tensorType.getElementTypeBitWidth();
+  info.numWarps = triton::gpu::lookupNumWarps(op);
+  assert(info.numWarps % 4 == 0 && "Unexpected number of warps");
+  info.numWarpGroups = info.numWarps / 4;
+  info.numElementsPer32B = 32 / tensorType.getElementTypeBitWidth();
   auto shapePerCTA = mlir::triton::gpu::getShapePerCTA(tensorType);
-  int numElements = product(shapePerCTA);
+  info.numElements = product(shapePerCTA);
+
   triton::nvidia_gpu::TMemAllocation tmemAlloc =
       triton::nvidia_gpu::getTmemAllocSizes(memType);
-  int numCols = tmemAlloc.numCols;
-  int blockM = 0;
-  int blockN = 0;
-  bool unpackedb16 = false;
+  info.numCols = tmemAlloc.numCols;
+
+  info.blockM = 0;
+  info.blockN = 0;
+  info.unpackedb16 = false;
   if (auto attr = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
           memType.getEncoding())) {
-    blockM = attr.getBlockM();
-    blockN = attr.getBlockN();
-    assert((!attr.getUnpacked() || numElementsPer32B <= 2) &&
+    info.blockM = attr.getBlockM();
+    info.blockN = attr.getBlockN();
+    assert((!attr.getUnpacked() || info.numElementsPer32B <= 2) &&
           "unsupported unpacked layout");
-    unpackedb16 = attr.getUnpacked() && numElementsPer32B == 2;
+    info.unpackedb16 = attr.getUnpacked() && (info.numElementsPer32B == 2);
   } else {
    assert(isa<triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
               memType.getEncoding()) &&
          "Expecting a tensor memory encoding attribute");
-    blockM = 128;
-    blockN = 32;
+    info.blockM = 128;
+    info.blockN = 32;
   }
-  int numBlocks = ceil(numElements, blockM * blockN);
-  bool useStridedMessage = blockM == 64;
-  int numWarpGroupsPerBlock = ceil(numWarpGroups, numBlocks);
+  info.useStridedMessage = (info.blockM == 64);
 
-  bool blocksInterleaved = numBlocks > 1 && blockM == 64;
-  int numColsPerBlock = numCols / numBlocks;
-  if (blocksInterleaved)
-    // We pack two blocks in the same column group of tmem.
-    numColsPerBlock *= 2;
+  info.numBlocks = ceil(info.numElements, info.blockM * info.blockN);
+  info.numWarpGroupsPerBlock = ceil(info.numWarpGroups, info.numBlocks);
+  info.blocksInterleaved = (info.numBlocks > 1 && info.useStridedMessage);
+  info.numColsPerBlock = info.numCols / info.numBlocks;
+  if (info.blocksInterleaved) {
+    info.numColsPerBlock *= 2;
+  }
+  info.colsPerWarpGroup = info.numColsPerBlock / info.numWarpGroupsPerBlock;
+  // If more than one warp group processes the same block,
+  // then fewer columns must be processed per message per warp group
+  info.numColsPerBlock /= info.numWarpGroupsPerBlock;
+  return info;
+}
+
+void calculateAddressAndEmitTmemMessage(
+    Location loc, Value baseAddress, const TMemRuntimeInfo &info,
+    const TMemMessageTraits &message, ConversionPatternRewriter &rewriter,
+    const std::function<void(Value, int, bool, int, bool)> &createMemoryOp) {
+  TritonLLVMOpBuilder b(loc, rewriter);
   Value warpId = rewriter.create<nvgpu::WarpIdOp>(loc);
   Value warpIdInGroup = b.urem(warpId, b.i32_val(4));
   Value warpGroupId = b.udiv(warpId, b.i32_val(4));
-  Value rowId = b.mul(warpIdInGroup, b.i32_val(numRowsPerWarp));
-
-  int colsPerWarpGroup = numColsPerBlock / numWarpGroupsPerBlock;
-
-  int rowsPerMessage = blockM == 64 ? 16 : 32;
-  int numRegs = getNum32BRegs(rowsPerMessage, unpackedb16, numElementsPer32B,
-                              colsPerWarpGroup);
-
-  int numRegsPerMessage = std::min(largestTmemLoadStore, numRegs);
-  int numColsPerMessage = (numColsPerBlock * numRegsPerMessage) / numRegs;
-  if (blockM == 64 && numColsPerMessage > blockN) {
-    numColsPerMessage = blockN;
-    numRegsPerMessage = (numColsPerMessage * numRegs) / numColsPerBlock;
-  }
-  for (int block = 0; block < numBlocks; block += numWarpGroups) {
+
+  for (int block = 0; block < info.numBlocks; block += info.numWarpGroups) {
     Value address = b.ptrtoint(i32_ty, baseAddress);
     Value blockId =
        b.add(b.i32_val(block),
-              b.udiv(warpGroupId, b.i32_val(numWarpGroupsPerBlock)));
-    Value blockRowId = rowId;
+              b.udiv(warpGroupId, b.i32_val(info.numWarpGroupsPerBlock)));
     Value warpGroupIdInBlock =
-        b.urem(warpGroupId, b.i32_val(numWarpGroupsPerBlock));
+        b.urem(warpGroupId, b.i32_val(info.numWarpGroupsPerBlock));
     Value startColumnId =
-        b.mul(warpGroupIdInBlock, b.i32_val(colsPerWarpGroup));
-    if (blocksInterleaved) {
+        b.mul(warpGroupIdInBlock, b.i32_val(info.colsPerWarpGroup));
+    Value blockRowId =
+        b.mul(warpIdInGroup, b.i32_val(TMemRuntimeInfo::numRowsPerWarp));
+
+    if (info.blocksInterleaved) {
      Value blockIdIsOdd = b.urem(blockId, b.i32_val(2));
      Value blockIdPrevEven = b.sub(blockId, blockIdIsOdd);
      blockRowId = b.add(blockRowId, b.mul(blockIdIsOdd, b.i32_val(16)));
      startColumnId =
          b.add(startColumnId,
-                b.mul(blockIdPrevEven, b.i32_val(numColsPerBlock / 2)));
+                b.mul(blockIdPrevEven, b.i32_val(info.numColsPerBlock / 2)));
    } else {
      startColumnId =
-          b.add(startColumnId, b.mul(blockId, b.i32_val(numColsPerBlock)));
+          b.add(startColumnId, b.mul(blockId, b.i32_val(info.numColsPerBlock)));
    }
-    address =
-        b.add(b.add(address, b.shl(blockRowId, b.i32_val(16))), startColumnId);
-
-    for (int colStart = 0; colStart < numColsPerBlock;
-         colStart += numColsPerMessage) {
-      Value startAddress = b.add(address, b.i32_val(colStart));
-
-      // Column offset of second half of the message in case of 16x32bx2
-      // message.
-      int secondHalfColOffset = useStridedMessage ? colsPerWarpGroup / 2 : 0;
-      createMemoryOp(startAddress, secondHalfColOffset, unpackedb16,
-                     numRegsPerMessage, useStridedMessage);
+
+    // A strided message accesses twice as many columns per message,
+    // thus half as many messages are required
+    int numColumns = info.useStridedMessage ? info.numColsPerBlock / 2
+                                            : info.numColsPerBlock;
+    for (int colStart = 0; colStart < numColumns; colStart += message.numCols) {
+      // For messages that span only 16 rows (e.g. 16x256b), multiple messages
+      // are required to cover the entire set of rows per warp.
+      for (int rowStart = 0; rowStart < TMemRuntimeInfo::numRowsPerWarp;
+           rowStart += message.numRows) {
+        Value rowOffset = b.add(blockRowId, b.i32_val(rowStart));
+        Value warpGroupAddress =
+            b.add(address, b.shl(rowOffset, b.i32_val(16)));
+        warpGroupAddress = b.add(warpGroupAddress, startColumnId);
+
+        Value msgAddress = b.add(warpGroupAddress, b.i32_val(colStart));
+        int secondHalfColOffset = 0;
+        if (info.useStridedMessage) {
+          // Offset to half way through the set of columns for this warpgroup.
+          secondHalfColOffset = numColumns;
+        }
+        createMemoryOp(msgAddress, secondHalfColOffset, info.unpackedb16,
+                       message.numRegs, info.useStridedMessage);
+      }
     }
  }
 }
 
-static void createTensorMemoryStore(Location loc, Value address,
-                                    SmallVector<Value> &srcs,
-                                    bool stridedMessage, int secondHalfOffset,
-                                    Value pred, bool unpacked,
-                                    ConversionPatternRewriter &rewriter) {
+void createTensorMemoryStore(Location loc, Value address,
+                             SmallVector<Value> &srcs, int secondHalfOffset,
+                             Value pred, bool unpacked,
+                             const TMemAccessAtom &atom,
+                             ConversionPatternRewriter &rewriter) {
   PTXBuilder ptxBuilder;
-  std::string opcode;
   std::string packedStr = unpacked ? ".unpack::16b" : "";
-  if (stridedMessage) {
-    opcode = "@$0 tcgen05.st.sync.aligned.16x32bx2.x" +
-             std::to_string(srcs.size()) + packedStr + ".b32 [$1], " +
-             std::to_string(secondHalfOffset) + ", {";
+  unsigned numRepeats = srcs.size() / (atom.rowsPerThread * atom.colsPerThread);
+  std::string opcode = "@$0 tcgen05.st.sync.aligned." +
+                       std::string(atom.opShape) + ".x" +
+                       std::to_string(numRepeats) + packedStr;
+  if (secondHalfOffset)
+    opcode += ".b32 [$1], " + std::to_string(secondHalfOffset) + ", {";
+  else
+    opcode += ".b32 [$1], {";
-  } else {
-    opcode = "@$0 tcgen05.st.sync.aligned.32x32b.x" +
-             std::to_string(srcs.size()) + packedStr + ".b32 [$1], {";
-  }
 
   SmallVector<PTXBuilder::Operand *> operands;
   operands.push_back(ptxBuilder.newOperand(pred, "b"));
   operands.push_back(ptxBuilder.newOperand(address, "r"));
@@ -183,7 +301,7 @@ static void createTensorMemoryStore(Location loc, Value address,
     auto *resultOp = ptxBuilder.newOperand(srcs[i], "r");
     operands.push_back(resultOp);
     if (i < srcs.size() - 1)
-      opcode = opcode + ", ";
+      opcode += ", ";
  }
  opcode += "};";
 
@@ -219,6 +337,22 @@ static void reorderScales(SmallVector<Value> &srcValues, int64_t k) {
  srcValues = std::move(reorderedValues);
 }
 
+TMemMessageTraits selectTMemMessage(const TMemRuntimeInfo &info) {
+  auto atom = info.useStridedMessage ? TMemAccess16x32bx2 : TMemAccess32x32b;
+
+  int totalRegsNeeded =
+      getEffectiveRegs(info.unpackedb16, info.useStridedMessage, info.numCols);
+  int narrowingFactor = getTMemMessageNarrowingFactor(totalRegsNeeded);
+  auto narrowedMessage = getTMemMessageFromAtom(atom, narrowingFactor);
+  narrowedMessage = constrainMessageFromWorkload(narrowedMessage, info,
+                                                 narrowedMessage.numRegs);
+
+  auto maxWidthMessage = getTMemMessageFromAtom(atom, /*narrowingFactor=*/1);
+  maxWidthMessage = constrainMessageFromWorkload(maxWidthMessage, info,
+                                                 info.colsPerWarpGroup);
+  return std::min(narrowedMessage, maxWidthMessage);
+}
+
 static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src,
                                      Value dest, Value llSrc, Value pred,
                                      Value tmemBase,
@@ -233,21 +367,23 @@ static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src,
    // along K are stored separately.
    reorderScales(srcValues, dstType.getShape().back());
  }
+
+  auto info = getTMemRuntimeInfo(op, cast<RankedTensorType>(src.getType()),
+                                 cast<MemDescType>(dest.getType()));
+  const TMemMessageTraits message = selectTMemMessage(info);
  int regIdx = 0;
  calculateAddressAndEmitTmemMessage(
-      loc, op, tmemBase, cast<RankedTensorType>(src.getType()),
-      cast<MemDescType>(dest.getType()), rewriter,
+      loc, tmemBase, info, message, rewriter,
      [&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
-          int regsPerMessage, bool useStridedMessage) {
+          int regsPerMsg, bool useStridedMessage) {
        SmallVector<Value> srcValuesSlice(srcValues.begin() + regIdx,
                                          srcValues.begin() + regIdx +
-                                             regsPerMessage);
-        regIdx += regsPerMessage;
+                                             regsPerMsg);
+        regIdx += regsPerMsg;
        createTensorMemoryStore(loc, startAddress, srcValuesSlice,
-                                useStridedMessage, secondHalfColOffset, pred,
-                                unpackedb16, rewriter);
+                                secondHalfColOffset, pred, unpackedb16,
+                                message.atom, rewriter);
      });
-
  createWaitOpSt(loc, rewriter);
  // Emit a barrier to ensure all threads have finished writing to tensor memory
@@ -292,35 +428,29 @@ struct TensorMemoryAllocOpConversion
  }
 };
 
-static Value createTensorMemoryLoad(Location loc,
-                                    triton::nvidia_gpu::TMEMLoadOp op,
-                                    Value address, int secondHalfOffset,
-                                    bool unpacked, int numRegPerMessage,
-                                    bool stridedMessage,
-                                    ConversionPatternRewriter &rewriter) {
+Value createTensorMemoryLoad(Location loc, triton::nvidia_gpu::TMEMLoadOp op,
+                             Value address, int secondHalfOffset, bool unpacked,
+                             int numRegPerMessage, const TMemAccessAtom &atom,
+                             ConversionPatternRewriter &rewriter) {
  PTXBuilder ptxBuilder;
-  std::string opcode;
  // If the memory is unpacked we need to pack on the fly when loading.
  std::string packedStr = unpacked ? ".pack::16b" : "";
-  if (stridedMessage) {
-    opcode = "tcgen05.ld.sync.aligned.16x32bx2.x" +
-             std::to_string(numRegPerMessage) + packedStr + ".b32 {";
-  } else {
-    opcode = "tcgen05.ld.sync.aligned.32x32b.x" +
-             std::to_string(numRegPerMessage) + packedStr + ".b32 {";
-  }
+  unsigned numRepeats =
+      numRegPerMessage / (atom.rowsPerThread * atom.colsPerThread);
+  std::string opcode = "tcgen05.ld.sync.aligned." + std::string(atom.opShape) +
+                       ".x" + std::to_string(numRepeats) + packedStr + ".b32 {";
+
  SmallVector<PTXBuilder::Operand *> operands;
  for (int i = 0; i < numRegPerMessage; i++) {
    opcode += "$" + std::to_string(i);
    auto *resultOp = ptxBuilder.newOperand("=r");
    operands.push_back(resultOp);
    if (i < numRegPerMessage - 1)
-      opcode = opcode + ", ";
+      opcode += ", ";
  }
  opcode += "}, [$" + std::to_string(numRegPerMessage) + "]";
-  if (stridedMessage) {
+  if (secondHalfOffset)
    opcode += ", " + std::to_string(secondHalfOffset);
-  }
  opcode += ";";
  operands.push_back(ptxBuilder.newOperand(address, "r"));
  auto &ld = *ptxBuilder.create(opcode);
@@ -376,20 +506,22 @@ struct TensorMemoryLoadOpConversion
        getTypeConverter()->convertType(op.getSrc().getType().getElementType());
    auto tmemBase = adaptor.getSrc();
 
+    auto info = getTMemRuntimeInfo(op, cast<RankedTensorType>(op.getType()),
+                                   cast<MemDescType>(op.getSrc().getType()));
+    const TMemMessageTraits message = selectTMemMessage(info);
    SmallVector<Value> resultVals;
    calculateAddressAndEmitTmemMessage(
-        loc, op, tmemBase, op.getType(), op.getSrc().getType(), rewriter,
+        loc, tmemBase, info, message, rewriter,
        [&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
            int regsPerMessage, bool useStridedMessage) {
          Value packedValues = createTensorMemoryLoad(
              loc, op, startAddress, secondHalfColOffset, unpackedb16,
-              regsPerMessage, useStridedMessage, rewriter);
+              regsPerMessage, message.atom, rewriter);
          auto results =
              unpackResults(packedValues, op.getType().getElementType(),
                            regsPerMessage, loc, rewriter);
          resultVals.append(results.begin(), results.end());
        });
-
    Type structTy = getTypeConverter()->convertType(op.getType());
    Value resultStruct = packLLElements(loc, getTypeConverter(), resultVals,
                                        rewriter, structTy);
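
For reference, a small standalone C++ sketch (not part of the patch) of the message-width arithmetic introduced above, specialized to the 128x256xf32 case exercised by the new test. It assumes the same constants as the patch (largestTmemLoadStore = 128, maxRegisters = 256) and the 32x32b access atom; all variable names below are illustrative only.

#include <algorithm>
#include <cstdio>

int main() {
  const int largestTmemLoadStore = 128; // regs per message upper bound
  const int maxRegisters = 256;         // thread register budget
  // 128x256 x f32, one warp group, packed elements: 256 tmem columns total.
  int numCols = 256;
  int numElementsPer32B = 1;

  // Mirrors getTMemMessageNarrowingFactor: keep usage under half the budget.
  int regs = numCols, narrowingFactor = 1;
  while (regs > maxRegisters / 2) {
    regs /= 2;
    narrowingFactor *= 2;
  }

  // Mirrors getTMemMessageFromAtom for the 32x32b atom (1 col x 1 row/thread).
  int maxCols = (32 / 32) * largestTmemLoadStore; // 128 columns max per message
  int msgCols = maxCols / narrowingFactor;        // 64 columns
  int msgRegs = msgCols;                          // 64 registers -> ".x64"

  // Mirrors constrainMessageFromWorkload: cap by what the workload needs.
  msgRegs = std::min(largestTmemLoadStore, msgRegs);
  msgCols = msgRegs * numElementsPer32B;

  int numMessages = numCols / msgCols;
  // Prints: narrowingFactor=2 msgRegs=64 msgCols=64 numMessages=4
  printf("narrowingFactor=%d msgRegs=%d msgCols=%d numMessages=%d\n",
         narrowingFactor, msgRegs, msgCols, numMessages);
  return 0;
}

This reproduces the four tcgen05.st/ld ".x64" messages that the new lit test checks for.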
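Similarly, a minimal sketch of the address packing used in calculateAddressAndEmitTmemMessage: the row offset lands in bits 16 and up and the column offset in the low bits, mirroring b.add(address, b.shl(rowOffset, b.i32_val(16))) followed by the column additions. The tmemAddress helper is hypothetical and only illustrates the arithmetic.

#include <cstdio>

// Pack a tensor-memory (row, column) pair into the 32-bit address word.
unsigned tmemAddress(unsigned base, unsigned row, unsigned col) {
  return base + (row << 16) + col;
}

int main() {
  // Warp 2 of a warp group covers rows 64..95; message starts at column 128.
  printf("0x%08x\n", tmemAddress(0u, 64u, 128u)); // prints 0x00400080
  return 0;
}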