diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index ee39725a27cf..4784453f6d08 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -620,10 +620,10 @@ struct ConvertLayoutOpConversion // is implemented SmallVector reorderedVals; for (unsigned i = 0; i < vecVals.size(); i += 4) { - reorderedVals.push_back(vecVals[i]); - reorderedVals.push_back(vecVals[i + 2]); - reorderedVals.push_back(vecVals[i + 1]); - reorderedVals.push_back(vecVals[i + 3]); + reorderedVals.push_back(bitcast(vecVals[i], i32_ty)); + reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty)); + reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty)); + reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty)); } Value view = getTypeConverter()->packLLElements(loc, reorderedVals, @@ -642,19 +642,19 @@ struct ConvertLayoutOpConversion auto loc = op.getLoc(); Value src = op.getSrc(); Value dst = op.getResult(); - bool isHMMA = supportMMA(dst, mmaLayout.getVersionMajor()); auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter); Value res; - if (!isOuter && mmaLayout.isAmpere() && isHMMA) { // tensor core v2 + if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2 res = SharedToDotOperandMMAv2::convertLayout( dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, smemObj, getTypeConverter(), tid_val()); - } else if (!isOuter && mmaLayout.isVolta() && isHMMA) { // tensor core v1 + } else if (!isOuter && mmaLayout.isVolta() && + supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1 bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow(); auto srcSharedLayout = src.getType() .cast() diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index 2bffe4e32807..21ed46ba445a 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -19,10 +19,10 @@ using ::mlir::triton::gpu::SharedEncodingAttr; class MMA16816SmemLoader { public: MMA16816SmemLoader(int wpt, ArrayRef order, uint32_t kOrder, - ArrayRef smemStrides, ArrayRef tileShape, - ArrayRef instrShape, ArrayRef matShape, - int perPhase, int maxPhase, int elemBytes, - ConversionPatternRewriter &rewriter, + int kWidth, ArrayRef smemStrides, + ArrayRef tileShape, ArrayRef instrShape, + ArrayRef matShape, int perPhase, int maxPhase, + int elemBytes, ConversionPatternRewriter &rewriter, TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc); @@ -33,7 +33,7 @@ class MMA16816SmemLoader { if (canUseLdmatrix) return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset); else - return computeLdsMatOffs(warpOff, lane, cSwizzleOffset, elemBytes); + return computeLdsMatOffs(warpOff, lane, cSwizzleOffset); return {}; } @@ -45,7 +45,7 @@ class MMA16816SmemLoader { Value cSwizzleOffset); // compute 8-bit matrix offset. SmallVector computeLdsMatOffs(Value warpOff, Value lane, - Value cSwizzleOffset, int elemBytes); + Value cSwizzleOffset); // Load 4 matrices and returns 4 vec<2> elements. 
std::tuple @@ -55,6 +55,7 @@ class MMA16816SmemLoader { private: SmallVector order; int kOrder; + int kWidth; SmallVector tileShape; SmallVector instrShape; SmallVector matShape; @@ -176,9 +177,7 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane, SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value warpOff, Value lane, - Value cSwizzleOffset, - int elemBytes) { - assert(elemBytes <= 4); + Value cSwizzleOffset) { int cTileShape = tileShape[order[0]]; int sTileShape = tileShape[order[1]]; if (!needTrans) { @@ -187,10 +186,10 @@ SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value warpOff, SmallVector offs(numPtrs); + int vecWidth = kWidth; int threadsPerQuad[2] = {8, 4}; int laneWidth = 4; int laneHeight = 8; - int vecWidth = 4 / elemBytes; int quadWidth = laneWidth * vecWidth; int quadHeight = laneHeight; int numQuadI = 2; @@ -232,8 +231,8 @@ SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value warpOff, Value i = add(iBase, mul(iOff, i32_val(quadHeight))); Value j = add(jBase, mul(jOff, i32_val(quadWidth))); // wrap around the bounds - i = urem(i, i32_val(cTileShape)); - j = urem(j, i32_val(sTileShape)); + // i = urem(i, i32_val(cTileShape)); + // j = urem(j, i32_val(sTileShape)); if (needTrans) { offs[idx] = add(i, mul(j, sStride)); } else { @@ -304,7 +303,6 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef offs, return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1), extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)}; } else { - elemTy = matTy.cast().getBody()[0]; // base pointers std::array, 2> ptrs; int vecWidth = 4 / elemBytes; @@ -324,39 +322,50 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef offs, std::array ii = {i0, i1}; // load 4 32-bit values from shared memory // (equivalent to ldmatrix.x4) - SmallVector> vals(4, SmallVector(vecWidth)); + SmallVector> vptrs(4, SmallVector(vecWidth)); for (int i = 0; i < 4; ++i) for (int j = 0; j < vecWidth; ++j) - vals[i][j] = load(gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2])); + vptrs[i][j] = gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]); // row + trans and col + no-trans are equivalent - if ((needTrans && kOrder == 1) || (!needTrans && kOrder == 0)) - std::swap(vals[1], vals[2]); + bool isActualTrans = + (needTrans && kOrder == 1) || (!needTrans && kOrder == 0); + if (isActualTrans) + std::swap(vptrs[1], vptrs[2]); // pack loaded vectors into 4 32-bit values + int inc = needTrans ? 
1 : kWidth; + VectorType packedTy = vec_ty(int_ty(8 * elemBytes), inc); + int canonBits = std::min(32, 8 * elemBytes * inc); + int canonWidth = (8 * elemBytes * inc) / canonBits; + Type canonInt = int_ty(canonBits); std::array retElems; - retElems.fill(undef(elemTy)); - for (int m = 0; m < 4; ++m) { - for (int e = 0; e < vecWidth; ++e) - retElems[m] = insert_element(retElems[m].getType(), retElems[m], - vals[m][e], i32_val(e)); + retElems.fill(undef(vec_ty(canonInt, 32 / canonBits))); + for (int r = 0; r < 2; ++r) { + for (int em = 0; em < 2 * vecWidth; em += inc) { + int e = em % vecWidth; + int m = em / vecWidth; + int idx = m * 2 + r; + Value ptr = bitcast(vptrs[idx][e], ptr_ty(packedTy, 3)); + Value val = load(ptr); + Value canonval = bitcast(val, vec_ty(canonInt, canonWidth)); + for (int w = 0; w < canonWidth; ++w) { + retElems[idx + w * kWidth / vecWidth] = + insert_element(retElems[idx + w * kWidth / vecWidth], + extract_element(canonval, i32_val(w)), i32_val(e)); + } + } } - if (elemBytes == 1) - return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty), - bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)}; - else - return {retElems[0], retElems[1], retElems[2], retElems[3]}; + return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty), + bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)}; } - - assert(false && "Invalid smem load"); - return {Value{}, Value{}, Value{}, Value{}}; } MMA16816SmemLoader::MMA16816SmemLoader( - int wpt, ArrayRef order, uint32_t kOrder, + int wpt, ArrayRef order, uint32_t kOrder, int kWidth, ArrayRef smemStrides, ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, int perPhase, int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter, TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc) - : order(order.begin(), order.end()), kOrder(kOrder), + : order(order.begin(), order.end()), kOrder(kOrder), kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()), instrShape(instrShape.begin(), instrShape.end()), matShape(matShape.begin(), matShape.end()), perPhase(perPhase), @@ -369,7 +378,8 @@ MMA16816SmemLoader::MMA16816SmemLoader( // rule: k must be the fast-changing axis. needTrans = kOrder != order[0]; - canUseLdmatrix = elemBytes == 2 || (!needTrans); // b16 + canUseLdmatrix = elemBytes == 2 || (!needTrans); + canUseLdmatrix = canUseLdmatrix && (kWidth == 4 / elemBytes); if (canUseLdmatrix) { // Each CTA, the warps is arranged as [1xwpt] if not transposed, @@ -409,42 +419,12 @@ Type getShemPtrTy(Type argType) { return ptr_ty(type::i16Ty(ctx), 3); else if (argType.isF32()) return ptr_ty(type::f32Ty(ctx), 3); - else if (argType.isInteger(8)) + else if (argType.getIntOrFloatBitWidth() == 8) return ptr_ty(type::i8Ty(ctx), 3); else llvm::report_fatal_error("mma16816 data type not supported"); } -Type getMatType(Type argType) { - MLIRContext *ctx = argType.getContext(); - // floating point types - Type fp32x1Ty = vec_ty(type::f32Ty(ctx), 1); - Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2); - Type i16x2Ty = vec_ty(type::i16Ty(ctx), 2); - Type fp16x2Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp16x2Ty)); - // LLVM 14.0 does not support bf16 type, so we use i16 instead. 
- Type bf16x2Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i16x2Ty)); - Type fp32Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32x1Ty)); - // integer types - Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4); - Type i8x4Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i8x4Ty)); - - if (argType.isF16()) - return fp16x2Pack4Ty; - else if (argType.isBF16()) - return bf16x2Pack4Ty; - else if (argType.isF32()) - return fp32Pack4Ty; - else if (argType.isInteger(8)) - return i8x4Pack4Ty; - else - llvm::report_fatal_error("mma16816 data type not supported"); -} - Value composeValuesToDotOperandLayoutStruct( const ValueTable &vals, int n0, int n1, TritonGPUToLLVMTypeConverter *typeConverter, Location loc, @@ -470,7 +450,7 @@ Value composeValuesToDotOperandLayoutStruct( std::function getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj, - MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, + MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, int kWidth, SmallVector instrShape, SmallVector matShape, Value warpId, Value lane, ValueTable &vals, bool isA, TritonGPUToLLVMTypeConverter *typeConverter, @@ -485,143 +465,105 @@ getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj, const int elemBytes = tensorTy.getElementTypeBitWidth() / 8; auto order = sharedLayout.getOrder(); - // the original register_lds2, but discard the prefetch logic. - auto ld2 = [](ValueTable &vals, int mn, int k, Value val) { - vals[{mn, k}] = val; - }; - // (a, b) is the coordinate. - auto load = [=, &rewriter, &vals, &ld2](int a, int b) { + auto load = [=, &rewriter, &vals](int a, int b) { MMA16816SmemLoader loader( - wpt, sharedLayout.getOrder(), kOrder, smemObj.strides, + wpt, sharedLayout.getOrder(), kOrder, kWidth, smemObj.strides, tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter, typeConverter, loc); Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offs = loader.computeOffsets(warpId, lane, cSwizzleOffset); + // initialize pointers const int numPtrs = loader.getNumPtrs(); SmallVector ptrs(numPtrs); - Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); - Type smemPtrTy = getShemPtrTy(eltTy); - for (int i = 0; i < numPtrs; ++i) { - ptrs[i] = - bitcast(gep(smemPtrTy, smemBase, ValueRange({offs[i]})), smemPtrTy); - } - + for (int i = 0; i < numPtrs; ++i) + ptrs[i] = bitcast(gep(smemPtrTy, smemBase, offs[i]), smemPtrTy); + // actually load from shared memory + auto matTy = LLVM::LLVMStructType::getLiteral(eltTy.getContext(), + SmallVector(4, i32_ty)); auto [ha0, ha1, ha2, ha3] = loader.loadX4( (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? 
b : a /*mat1*/, offs, - ptrs, getMatType(eltTy), getShemPtrTy(eltTy)); - - if (isA) { - ld2(vals, a, b, ha0); - ld2(vals, a + 1, b, ha1); - ld2(vals, a, b + 1, ha2); - ld2(vals, a + 1, b + 1, ha3); - } else { - ld2(vals, a, b, ha0); - ld2(vals, a + 1, b, ha2); - ld2(vals, a, b + 1, ha1); - ld2(vals, a + 1, b + 1, ha3); - } + ptrs, matTy, getShemPtrTy(eltTy)); + if (!isA) + std::swap(ha1, ha2); + // the following is incorrect + // but causes dramatically better performance in ptxas + // although it only changes the order of operands in + // `mma.sync` + // if(isA) + // std::swap(ha1, ha2); + // update user-provided values in-place + vals[{a, b}] = ha0; + vals[{a + 1, b}] = ha1; + vals[{a, b + 1}] = ha2; + vals[{a + 1, b + 1}] = ha3; }; return load; } -Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value tensor, - DotOperandEncodingAttr aEncoding, const SharedMemoryObject &smemObj, - TritonGPUToLLVMTypeConverter *typeConverter, Value thread) { - auto aTensorTy = tensor.getType().cast(); - int bitwidth = aTensorTy.getElementTypeBitWidth(); - auto mmaLayout = aEncoding.getParent().cast(); - - SmallVector shape(aTensorTy.getShape().begin(), - aTensorTy.getShape().end()); - - ValueTable ha; - std::function loadFn; - int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; - int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; - - auto numRep = aEncoding.getMMAv2Rep(aTensorTy.getShape(), bitwidth); - int numRepM = numRep[0]; - int numRepK = numRep[1]; - - if (aTensorTy.getEncoding().isa()) { - int wpt0 = mmaLayout.getWarpsPerCTA()[0]; - Value warp = udiv(thread, i32_val(32)); - Value lane = urem(thread, i32_val(32)); - Value warpM = urem(urem(warp, i32_val(wpt0)), i32_val(shape[0] / 16)); - // load from smem - // we use ldmatrix.x4 so each warp processes 16x16 elements. - int wpt = std::min(wpt0, shape[0] / 16); - loadFn = getLoadMatrixFn( - tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, - {mmaInstrM, mmaInstrK} /*instrShape*/, - {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/, - ha /*vals*/, true /*isA*/, typeConverter /* typeConverter */, - rewriter /*rewriter*/, loc /*loc*/); - } else if (aTensorTy.getEncoding().isa()) { - // load from registers, used in gemm fuse - // TODO(Superjomn) Port the logic. - assert(false && "Loading A from register is not supported yet."); - } else { - assert(false && "A's layout is not supported."); - } - - // step1. Perform loading. - for (int m = 0; m < numRepM; ++m) - for (int k = 0; k < numRepK; ++k) - loadFn(2 * m, 2 * k); - - // step2. Format the values to LLVM::Struct to passing to mma codegen. 
- return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK, - typeConverter, loc, rewriter); -} - -Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value tensor, - DotOperandEncodingAttr bEncoding, const SharedMemoryObject &smemObj, - TritonGPUToLLVMTypeConverter *typeConverter, Value thread) { - ValueTable hb; +Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor, + DotOperandEncodingAttr encoding, + const SharedMemoryObject &smemObj, + TritonGPUToLLVMTypeConverter *typeConverter, Value thread, + bool isA) { auto tensorTy = tensor.getType().cast(); int bitwidth = tensorTy.getElementTypeBitWidth(); - auto mmaLayout = bEncoding.getParent().cast(); + auto mmaLayout = encoding.getParent().cast(); SmallVector shape(tensorTy.getShape().begin(), tensorTy.getShape().end()); + ValueTable vals; int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; - auto numRep = bEncoding.getMMAv2Rep(tensorTy.getShape(), bitwidth); - int numRepK = numRep[0]; - int numRepN = numRep[1]; + auto numRep = encoding.getMMAv2Rep(tensorTy.getShape(), bitwidth); + int kWidth = encoding.getMMAv2kWidth(); int wpt0 = mmaLayout.getWarpsPerCTA()[0]; int wpt1 = mmaLayout.getWarpsPerCTA()[1]; Value warp = udiv(thread, i32_val(32)); Value lane = urem(thread, i32_val(32)); + Value warpM = urem(urem(warp, i32_val(wpt0)), i32_val(shape[0] / 16)); Value warpMN = udiv(warp, i32_val(wpt0)); Value warpN = urem(urem(warpMN, i32_val(wpt1)), i32_val(shape[1] / 8)); - // we use ldmatrix.x4 so each warp processes 16x16 elements. - int wpt = std::min(wpt1, shape[1] / 16); - auto loadFn = getLoadMatrixFn( - tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, - {mmaInstrK, mmaInstrN} /*instrShape*/, - {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/, - hb /*vals*/, false /*isA*/, typeConverter /* typeConverter */, - rewriter /*rewriter*/, loc /*loc*/); - - for (int n = 0; n < std::max(numRepN / 2, 1); ++n) { + + int wpt; + if (isA) + wpt = std::min(wpt0, shape[0] / 16); + else + wpt = std::min(wpt1, shape[1] / 16); + + std::function loadFn; + if (isA) + loadFn = getLoadMatrixFn( + tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, kWidth, + {mmaInstrM, mmaInstrK} /*instrShape*/, + {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/, + vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */, + rewriter /*rewriter*/, loc /*loc*/); + else + loadFn = getLoadMatrixFn( + tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, kWidth, + {mmaInstrK, mmaInstrN} /*instrShape*/, + {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/, + vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */, + rewriter /*rewriter*/, loc /*loc*/); + + // Perform loading. + int numRepOuter = isA ? numRep[0] : std::max(numRep[1] / 2, 1); + int numRepK = isA ? numRep[1] : numRep[0]; + for (int m = 0; m < numRepOuter; ++m) for (int k = 0; k < numRepK; ++k) - loadFn(2 * n, 2 * k); - } + loadFn(2 * m, 2 * k); - Value result = composeValuesToDotOperandLayoutStruct( - hb, std::max(numRepN / 2, 1), numRepK, typeConverter, loc, rewriter); - return result; + // Format the values to LLVM::Struct to passing to mma codegen. 
+ return composeValuesToDotOperandLayoutStruct(vals, numRepOuter, numRepK, + typeConverter, loc, rewriter); } namespace SharedToDotOperandMMAv2 { @@ -630,12 +572,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, const SharedMemoryObject &smemObj, TritonGPUToLLVMTypeConverter *typeConverter, Value thread) { if (opIdx == 0) - return loadA(rewriter, loc, tensor, encoding, smemObj, typeConverter, - thread); + return loadArg(rewriter, loc, tensor, encoding, smemObj, typeConverter, + thread, true); else { assert(opIdx == 1); - return loadB(rewriter, loc, tensor, encoding, smemObj, typeConverter, - thread); + return loadArg(rewriter, loc, tensor, encoding, smemObj, typeConverter, + thread, false); } } } // namespace SharedToDotOperandMMAv2 diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 26726910cd51..c4fd4acfbd40 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -4,11 +4,140 @@ using namespace mlir; using namespace mlir::triton; using ::mlir::triton::gpu::getTotalElemsPerThread; +static SmallVector reorderValues(const SmallVector &values, + Type inType, Type ouType) { + auto inTensorTy = inType.dyn_cast(); + auto ouTensorTy = ouType.dyn_cast(); + if (!inTensorTy || !ouTensorTy) + return values; + auto inEncoding = + dyn_cast(inTensorTy.getEncoding()); + auto ouEncoding = + dyn_cast(ouTensorTy.getEncoding()); + assert(inEncoding == ouEncoding); + if (!inEncoding) + return values; + size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); + size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); + auto ouEltTy = ouTensorTy.getElementType(); + if (inBitWidth == ouBitWidth) + return values; + if (inBitWidth == 16 && ouBitWidth == 32) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 8) { + ret.push_back(values[i]); + ret.push_back(values[i + 1]); + ret.push_back(values[i + 4]); + ret.push_back(values[i + 5]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); + ret.push_back(values[i + 6]); + ret.push_back(values[i + 7]); + } + return ret; + } + if (inBitWidth == 8 && ouBitWidth == 16) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 16) { + ret.push_back(values[i + 0]); + ret.push_back(values[i + 1]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); + ret.push_back(values[i + 8]); + ret.push_back(values[i + 9]); + ret.push_back(values[i + 10]); + ret.push_back(values[i + 11]); + ret.push_back(values[i + 4]); + ret.push_back(values[i + 5]); + ret.push_back(values[i + 6]); + ret.push_back(values[i + 7]); + ret.push_back(values[i + 12]); + ret.push_back(values[i + 13]); + ret.push_back(values[i + 14]); + ret.push_back(values[i + 15]); + } + return ret; + // for (unsigned i = 0; i < values.size(); i += 16) { + // ret.push_back(values[i]); + // ret.push_back(values[i + 1]); + // ret.push_back(values[i + 4]); + // ret.push_back(values[i + 5]); + // ret.push_back(values[i + 8]); + // ret.push_back(values[i + 9]); + // ret.push_back(values[i + 12]); + // ret.push_back(values[i + 13]); + + // ret.push_back(values[i + 2]); + // ret.push_back(values[i + 3]); + // ret.push_back(values[i + 6]); + // ret.push_back(values[i + 7]); + // ret.push_back(values[i + 10]); + // ret.push_back(values[i + 11]); + // ret.push_back(values[i + 14]); + // ret.push_back(values[i + 15]); + // } + return values; + } + llvm_unreachable("unimplemented 
code path"); +} + +inline SmallVector unpackI32(const SmallVector &inValues, + Type srcTy, + ConversionPatternRewriter &rewriter, + Location loc, + TypeConverter *typeConverter) { + auto tensorTy = srcTy.dyn_cast(); + if (!tensorTy) + return inValues; + auto encoding = tensorTy.getEncoding().dyn_cast(); + if (!(encoding && encoding.getParent().isa())) + return inValues; + SmallVector outValues; + for (auto v : inValues) { + // cast i32 to appropriate eltType vector and extract elements + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecType); + for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + +inline SmallVector packI32(const SmallVector &inValues, + Type srcTy, + ConversionPatternRewriter &rewriter, + Location loc, TypeConverter *typeConverter) { + auto tensorTy = srcTy.dyn_cast(); + if (!tensorTy) + return inValues; + auto encoding = tensorTy.getEncoding().dyn_cast(); + if (!(encoding && encoding.getParent().isa())) + return inValues; + SmallVector outValues; + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); + auto vecType = vec_ty(eltType, vecWidth); + for (int i = 0; i < inValues.size(); i += vecWidth) { + Value vec = undef(vecType); + for (int j = 0; j < vecWidth; j++) { + vec = insert_element(vec, inValues[i + j], i32_val(j)); + } + outValues.push_back(bitcast(vec, i32_ty)); + } + return outValues; +} + struct FpToFpOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern; + typedef std::function( + Location, ConversionPatternRewriter &, const Value &, const Value &, + const Value &, const Value &)> + ConvertorT; /* ------------------ */ // FP8 -> FP16 /* ------------------ */ @@ -490,35 +619,14 @@ struct FpToFpOpConversion return builder.launch(rewriter, loc, f16_ty, false); } - LogicalResult - matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcTensorType = op.getFrom().getType().cast(); - auto dstTensorType = - op.getResult().getType().cast(); - auto srcEltType = srcTensorType.getElementType(); - auto dstEltType = dstTensorType.getElementType(); - auto loc = op->getLoc(); - auto elems = getTotalElemsPerThread(dstTensorType); - SmallVector resultVals; - bool isSrcFP8 = - srcEltType.isa(); - bool isDstFP8 = - dstEltType.isa(); - - // Select convertor - typedef std::function( - Location, ConversionPatternRewriter &, const Value &, const Value &, - const Value &, const Value &)> - ConvertorT; - + ConvertorT getConversionFunc(Type srcTy, Type dstTy) const { auto F8E4M3TyID = TypeID::get(); auto F8E5M2TyID = TypeID::get(); auto F16TyID = TypeID::get(); auto BF16TyID = TypeID::get(); auto F32TyID = TypeID::get(); auto F64TyID = TypeID::get(); - DenseMap, ConvertorT> convertorMap = { + static DenseMap, ConvertorT> convertorMap = { // F8 -> F16 {{F8E4M3TyID, F16TyID}, convertFp8E4M3x4ToFp16x4}, {{F8E5M2TyID, F16TyID}, convertFp8E5M2x4ToFp16x4}, @@ -539,28 +647,46 @@ struct FpToFpOpConversion {{F32TyID, F8E5M2TyID}, convertFp32x4ToFp8E5M2x4}, }; - std::pair key = {srcEltType.getTypeID(), - dstEltType.getTypeID()}; + std::pair key = {srcTy.getTypeID(), dstTy.getTypeID()}; if (convertorMap.count(key) == 0) { - llvm::errs() << "Unsupported 
conversion from " << srcEltType << " to " - << dstEltType << "\n"; + llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy + << "\n"; llvm_unreachable(""); } - auto convertor = convertorMap.lookup(key); + return convertorMap.lookup(key); + } - // Vectorized casting + LogicalResult + matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // llvm::outs() << 0 << "\n"; + auto srcTensorType = op.getFrom().getType().cast(); + auto dstTensorType = + op.getResult().getType().cast(); + auto loc = op->getLoc(); + // check that the number of elements is divisible by 4 + // Get convertor + auto cvtFunc = getConversionFunc(srcTensorType.getElementType(), + dstTensorType.getElementType()); + // Unpack value + auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getFrom(), + rewriter, srcTensorType); + inVals = + unpackI32(inVals, srcTensorType, rewriter, loc, getTypeConverter()); + // Cast + SmallVector outVals; + auto elems = inVals.size(); assert(elems % 4 == 0 && "FP8 casting only support tensors with 4-aligned sizes"); - auto elements = getTypeConverter()->unpackLLElements( - loc, adaptor.getFrom(), rewriter, srcTensorType); - for (size_t i = 0; i < elems; i += 4) { - auto converted = convertor(loc, rewriter, elements[i], elements[i + 1], - elements[i + 2], elements[i + 3]); - resultVals.append(converted); - } - - assert(resultVals.size() == elems); - auto result = getTypeConverter()->packLLElements(loc, resultVals, rewriter, + for (size_t i = 0; i < elems; i += 4) + outVals.append(cvtFunc(loc, rewriter, inVals[i], inVals[i + 1], + inVals[i + 2], inVals[i + 3])); + // Pack values + assert(outVals.size() == elems); + outVals = reorderValues(outVals, srcTensorType, dstTensorType); + outVals = + packI32(outVals, dstTensorType, rewriter, loc, getTypeConverter()); + auto result = getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTensorType); rewriter.replaceOp(op, result); return success(); @@ -582,43 +708,44 @@ class ElementwiseOpConversionBase ConversionPatternRewriter &rewriter) const override { auto resultTy = op.getType(); Location loc = op->getLoc(); - - unsigned elems = getTotalElemsPerThread(resultTy); + // element type auto resultElementTy = getElementTypeOrSelf(resultTy); Type elemTy = this->getTypeConverter()->convertType(resultElementTy); - SmallVector types(elems, elemTy); - Type structTy = this->getTypeConverter()->convertType(resultTy); - - auto *concreteThis = static_cast(this); - auto operands = getOperands(rewriter, adaptor, resultTy, elems, loc); - SmallVector resultVals(elems); - for (unsigned i = 0; i < elems; ++i) { - resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy, - operands[i], loc); - if (!bool(resultVals[i])) + SmallVector resultVals; + // + SmallVector> allOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto sub_operands = this->getTypeConverter()->unpackLLElements( + loc, operand, rewriter, argTy); + sub_operands = unpackI32(sub_operands, argTy, rewriter, loc, + this->getTypeConverter()); + allOperands.resize(sub_operands.size()); + for (auto v : llvm::enumerate(sub_operands)) + allOperands[v.index()].push_back(v.value()); + } + if (allOperands.size() == 0) + allOperands.push_back({}); + for (const SmallVector &operands : allOperands) { + Value curr = + ((ConcreteT *)(this)) + ->createDestOp(op, adaptor, rewriter, elemTy, operands, loc); + if (!bool(curr)) return failure(); + 
resultVals.push_back(curr); + } + if (op->getNumOperands() > 0) { + auto argTy = op->getOperand(0).getType(); + resultVals = reorderValues(resultVals, argTy, resultTy); } + resultVals = + packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); Value view = this->getTypeConverter()->packLLElements(loc, resultVals, rewriter, resultTy); rewriter.replaceOp(op, view); return success(); } - -protected: - SmallVector> - getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor, - Type operandTy, const unsigned elems, Location loc) const { - SmallVector> operands(elems); - for (auto operand : adaptor.getOperands()) { - auto sub_operands = this->getTypeConverter()->unpackLLElements( - loc, operand, rewriter, operandTy); - for (size_t i = 0; i < elems; ++i) { - operands[i].push_back(sub_operands[i]); - } - } - return operands; - } }; template diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 347988b83e2a..ef8c474bd34c 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -106,17 +106,8 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( return elemTy; if (mmaParent.isAmpere()) { int bitwidth = elemTy.getIntOrFloatBitWidth(); - // sub-word integer types need to be packed for perf reasons - if (elemTy.isa() && bitwidth < 32) - return IntegerType::get(ctx, 32); - // TODO: unify everything to use packed integer-types - // otherwise, vector types are ok - const llvm::DenseMap elemTyMap = { - {32, vec_ty(elemTy, 1)}, - {16, vec_ty(elemTy, 2)}, - {8, vec_ty(elemTy, 4)}, - }; - return elemTyMap.lookup(bitwidth); + assert(bitwidth <= 32); + return IntegerType::get(ctx, 32); } else { assert(mmaParent.isVolta()); return vec_ty(elemTy, 2); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 8a569156eeef..8d27f868cae4 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -80,6 +80,7 @@ #define call(...) 
rewriter.create(loc, __VA_ARGS__) // Types +#define int_ty(width) rewriter.getIntegerType(width) #define i64_ty rewriter.getIntegerType(64) #define i32_ty rewriter.getIntegerType(32) #define i16_ty rewriter.getIntegerType(16) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index a30af6d6be38..32605e605dea 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -72,6 +72,76 @@ class ConvertTransConvert : public mlir::RewritePattern { } }; +// + +class MoveOpAfterLayoutConversion : public mlir::RewritePattern { + +public: + MoveOpAfterLayoutConversion(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), + 1, context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto cvt = cast(op); + auto srcTy = cvt.getOperand().getType().cast(); + auto retTy = cvt.getResult().getType().dyn_cast(); + auto retEncoding = + retTy.getEncoding().dyn_cast(); + auto srcEncoding = + srcTy.getEncoding().dyn_cast(); + if (!retTy) + return failure(); + if (!retEncoding) + return failure(); + auto retEncodingParent = + retEncoding.getParent().dyn_cast(); + if (!retEncodingParent || retEncodingParent.isVolta()) + return failure(); + if (!srcEncoding) + return failure(); + // don't move things around when cvt operand is a block arg + Operation *argOp = cvt.getOperand().getDefiningOp(); + if (!argOp) + return failure(); + // + SetVector processed; + SetVector layout; + llvm::MapVector toConvert; + int numCvts = simulateBackwardRematerialization(cvt, processed, layout, + toConvert, retEncoding); + if (numCvts > 1 || toConvert.size() == 1) + return failure(); + for (Operation *op : processed) { + if (op->getNumOperands() != 1) + continue; + auto srcTy = op->getOperand(0).getType().cast(); + auto dstTy = op->getResult(0).getType().cast(); + // we don't want to push conversions backward if there is a downcast + // since it would result in more shared memory traffic + if (srcTy.getElementType().getIntOrFloatBitWidth() > + dstTy.getElementType().getIntOrFloatBitWidth()) + return failure(); + // we only push back when the first op in the chain has a load operand + if ((op == processed.back()) && + !isa(op->getOperand(0).getDefiningOp())) + return failure(); + // we don't want to use ldmatrix for 8-bit data that requires trans + // since Nvidia GPUs can't do it efficiently + bool isTrans = + (retEncoding.getOpIdx() == 1) ^ (srcEncoding.getOrder()[0] == 0); + bool isInt8 = srcTy.getElementType().getIntOrFloatBitWidth() == 8; + if (isTrans && isInt8) + return failure(); + } + IRMapping mapping; + rematerializeConversionChain(toConvert, rewriter, processed, mapping); + rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0))); + return mlir::success(); + } +}; + } // namespace #define GEN_PASS_CLASSES @@ -93,6 +163,7 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); + patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); if (fixupLoops(m).failed()) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 3ab91f808dc6..4dd0a2186fb0 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -571,12 +571,15 @@ scf::ForOp 
LoopPipeliner::createNewForOp() { // 2. clone the loop body, replace original args with args of the new ForOp // Insert async wait if necessary. + DenseSet isModified; for (Operation &op : forOp.getBody()->without_terminator()) { + // is modified auto it = std::find(loads.begin(), loads.end(), op.getOperand(0)); if (it == loads.end()) { - Operation *newOp = builder.clone(op, mapping); + Operation *newOp = cloneWithInferType(builder, &op, mapping); continue; } + // we replace the use new load use with a convert layout size_t i = std::distance(loads.begin(), it); auto cvtDstTy = op.getResult(0).getType().cast(); @@ -590,6 +593,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { op.getResult(0).getLoc(), newDstTy, newForOp.getRegionIterArgs()[loadIdx + i]); mapping.map(op.getResult(0), cvt.getResult()); + isModified.insert(op.getResult(0)); } // 3. prefetch the next iteration diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index fde9fc9c2148..60c2cb95fca8 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -45,7 +45,7 @@ class Prefetcher { scf::YieldOp yieldOp; /// // TODO: add a hook to infer prefetchWidth - unsigned prefetchWidth = 16; + unsigned prefetchWidth = 32; /// dots to be prefetched SetVector dots; @@ -56,6 +56,8 @@ class Prefetcher { DenseMap dot2bHeaderDef; DenseMap dot2aYield; DenseMap dot2bYield; + DenseMap> dot2aVals; + DenseMap> dot2bVals; /// operand => defining DenseMap operand2headPrefetch; @@ -66,6 +68,9 @@ class Prefetcher { std::optional offsetK = std::nullopt, std::optional shapeK = std::nullopt); + void cloneElementwiseOps(Value &bRem, const SmallVector &vals, + OpBuilder &builder); + public: Prefetcher() = delete; @@ -80,6 +85,24 @@ class Prefetcher { scf::ForOp createNewForOp(); }; +void Prefetcher::cloneElementwiseOps(Value &ret, const SmallVector &vals, + OpBuilder &builder) { + IRMapping mapping; + mapping.map(vals[0], ret); + for (int i = 1; i < vals.size(); i++) { + Value v = vals[i]; + Value curr = builder.clone(*v.getDefiningOp(), mapping)->getResult(0); + auto retType = RankedTensorType::get( + ret.getType().cast().getShape(), + curr.getType().cast().getElementType(), + curr.getType().cast().getEncoding()); + curr.setType(retType); + mapping.map(v, curr); + } + if (vals.size() > 1) + ret = mapping.lookup(vals.back()); +} + Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, Attribute dotEncoding, OpBuilder &builder, std::optional offsetK, @@ -135,11 +158,32 @@ LogicalResult Prefetcher::initialize() { return failure(); // returns source of cvt - auto getPrefetchSrc = [](Value v) -> Value { - if (auto cvt = v.getDefiningOp()) - if (isSharedEncoding(cvt.getOperand())) - return cvt.getSrc(); - return Value(); + + // returns source of cvt + auto getPrefetchSrc = [](Value v) -> SmallVector { + // walk back to conversion + Operation *op = v.getDefiningOp(); + bool foundConvertFromShared = false; + SmallVector rets; + rets.push_back(op->getResult(0)); + while (op) { + if (op->getNumOperands() != 1) + break; + if (!op->getResult(0).hasOneUse()) + break; + rets.push_back(op->getOperand(0)); + if (auto cvt = dyn_cast_or_null(op)) + if (isSharedEncoding(cvt.getOperand())) { + foundConvertFromShared = true; + break; + } + op = op->getOperand(0).getDefiningOp(); + } + std::reverse(rets.begin(), rets.end()); + + if (foundConvertFromShared) + return rets; + return {}; }; auto getIncomingOp = [this](Value v) -> Value { @@ -176,14 
+220,19 @@ LogicalResult Prefetcher::initialize() { // Skip prefetching if kSize is less than prefetchWidth if (kSize < prefetchWidth) continue; - Value aSmem = getPrefetchSrc(dot.getA()); - Value bSmem = getPrefetchSrc(dot.getB()); - if (aSmem && bSmem) { + auto aVals = getPrefetchSrc(dot.getA()); + auto bVals = getPrefetchSrc(dot.getB()); + + if (aVals.size() && bVals.size()) { + Value aSmem = aVals.front(); + Value bSmem = bVals.front(); Value aHeaderDef = getIncomingOp(aSmem); Value bHeaderDef = getIncomingOp(bSmem); // Only prefetch loop arg if (aHeaderDef && bHeaderDef) { dots.insert(dot); + dot2aVals[dot] = aVals; + dot2bVals[dot] = bVals; dot2aHeaderDef[dot] = aHeaderDef; dot2bHeaderDef[dot] = bHeaderDef; dot2aLoopArg[dot] = aSmem; @@ -205,10 +254,13 @@ void Prefetcher::emitPrologue() { dot.getType().cast().getEncoding(); Value aPrefetched = generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder); - operand2headPrefetch[dot.getDefiningOp().getA()] = - aPrefetched; + cloneElementwiseOps(aPrefetched, dot2aVals[dot], builder); Value bPrefetched = generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder); + cloneElementwiseOps(bPrefetched, dot2bVals[dot], builder); + + operand2headPrefetch[dot.getDefiningOp().getA()] = + aPrefetched; operand2headPrefetch[dot.getDefiningOp().getB()] = bPrefetched; } @@ -266,9 +318,11 @@ scf::ForOp Prefetcher::createNewForOp() { Value aRem = generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false, dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(aRem, dot2aVals[dot], builder); Value bRem = generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false, dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(bRem, dot2bVals[dot], builder); builder.restoreInsertionPoint(insertionPoint); newOp = builder.clone(*dot, mapping); newOp->setOperand(0, aRem); @@ -291,10 +345,15 @@ scf::ForOp Prefetcher::createNewForOp() { for (Value dot : dots) { Attribute dotEncoding = dot.getType().cast().getEncoding(); - yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, - true, dotEncoding, builder)); - yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, - true, dotEncoding, builder)); + Value aToYield = generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, true, + dotEncoding, builder); + cloneElementwiseOps(aToYield, dot2aVals[dot], builder); + yieldValues.push_back(aToYield); + // bToYield + Value bToYield = generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, true, + dotEncoding, builder); + cloneElementwiseOps(bToYield, dot2bVals[dot], builder); + yieldValues.push_back(bToYield); } // Update ops of yield if (!yieldValues.empty()) diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index bc513cef9468..10b8e29dde68 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -209,6 +209,15 @@ int simulateBackwardRematerialization( Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping) { Operation *newOp = rewriter.clone(*op, mapping); + // if input types haven't changed, we're done + bool preserveTypes = + std::all_of(op->operand_begin(), op->operand_end(), [&](Value v) { + return !mapping.contains(v) || + v.getType() == mapping.lookup(v).getType(); + }); + if (preserveTypes) + return newOp; + if (newOp->getNumResults() == 0) return newOp; auto origType = op->getResult(0).getType().dyn_cast(); diff --git a/test/Conversion/tritongpu_to_llvm.mlir 
b/test/Conversion/tritongpu_to_llvm.mlir index e5c352454095..d05ee816fa97 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -978,10 +978,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 - // CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) + // CHECK-SAME: (i32, i32, i32, i32) // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 - // CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) + // CHECK-SAME: (i32, i32, i32, i32) %a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a> %b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b> diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir new file mode 100644 index 000000000000..41a65cce45ec --- /dev/null +++ b/test/TritonGPU/dot-operands.mlir @@ -0,0 +1,52 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -tritongpu-remove-layout-conversions -canonicalize | FileCheck %s + +#Cv2 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#Av2 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}> +#Bv2 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}> +#Cv1 = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [4, 1]}> +#Av1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv1}> +#Bv1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv1}> +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> + +// CHECK: tt.func @push_elementwise1 +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] +// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] +// CHECK: %[[C:.*]] = tt.dot %[[AF16]] +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> +tt.func @push_elementwise1( + %pa: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ + %ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL> + %b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> + %af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL> + %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL> + %dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av2> + %dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv2> + %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2> + tt.return %newc : tensor<16x16xf32, #Cv2> +} + +// CHECK: tt.func @push_elementwise2 +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]] +// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] +// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]] +// CHECK: %[[C:.*]] = tt.dot %[[ACVT]] +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma1> 
+tt.func @push_elementwise2( + %pa: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #Cv1>) -> tensor<16x16xf32, #Cv1>{ + %ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL> + %b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> + %af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL> + %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL> + %dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av1> + %dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv1> + %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av1> * tensor<16x16xf16, #Bv1> -> tensor<16x16xf32, #Cv1> + tt.return %newc : tensor<16x16xf32, #Cv1> +} diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 4f666b008d93..b820f4034abb 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-prefetch | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-prefetch -canonicalize | FileCheck %s // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 @@ -11,29 +11,32 @@ #B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> -// CHECK: tt.func @matmul_loop +// CHECK: tt.func @matmul_loop_mixed // CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[A0:.*]][0, 0] [128, 16] // CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] // CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[B0:.*]][0, 0] [16, 128] // CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.convert_layout %[[B0_PREFETCH_SMEM]] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_PREFETCH]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] // CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_a0]][0, 16] [128, 16] // CHECK-DAG: %[[A_REM:.*]] = triton_gpu.convert_layout %[[A_REM_SMEM]] +// CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] // CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_b0]][16, 0] [16, 128] // CHECK-DAG: %[[B_REM:.*]] = triton_gpu.convert_layout %[[B_REM_SMEM]] // CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} -// CHECK: tt.dot %[[A_REM]], %[[B_REM]], %[[D_FIRST:.*]] +// CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] // CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [128, 16] // CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] // CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [16, 128] // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]] -tt.func @matmul_loop(%lb : index, %ub : index, %step 
: index, %A : !tt.ptr, %B : !tt.ptr) { - %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] +tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ + %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> %a_mask = arith.constant dense : tensor<128x32xi1, #AL> - %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf8E5M2, #AL> %b_mask = arith.constant dense : tensor<32x128xi1, #BL> %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> @@ -41,24 +44,25 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> - %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - %a_init = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf8E5M2, #AL> + %a_init = triton_gpu.convert_layout %a_ : (tensor<128x32xf8E5M2, #AL>) -> tensor<128x32xf8E5M2, #A> %b_ = tt.load %b_ptr_init, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b_init = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>) { - %a_op = triton_gpu.convert_layout %a : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A_OP> + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32xf8E5M2, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>) { + %a_op_ = triton_gpu.convert_layout %a : (tensor<128x32xf8E5M2, #A>) -> tensor<128x32xf8E5M2, #A_OP> + %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> %b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> - %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - %next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf8E5M2, #AL> + %next_a 
= triton_gpu.convert_layout %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> tensor<128x32xf8E5M2, #A> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %next_b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf8E5M2, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C> - tt.return + tt.return %loop#4 : tensor<128x128xf32, #C> }
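For context on the convention this patch converges on: Ampere dot operands are now always carried as packed i32 words (four i8, two f16/bf16, or one f32 per word), which is why getElementTypeForStruct returns i32, why the elementwise lowering gains packI32/unpackI32, and why reorderValues must permute per-thread values whenever an elementwise op changes the element bitwidth. The sketch below is not Triton code; pack_i32, unpack_i32 and reorder16to32 are hypothetical helpers that only mirror the little-endian lane packing and the {0, 1, 4, 5, 2, 3, 6, 7} permutation that appear literally in the diff.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Pack runs of four 8-bit lanes into 32-bit words, least-significant lane
// first (mirroring packI32, which builds a <4 x i8> vector and bitcasts it
// to i32 on a little-endian target).
static std::vector<uint32_t> pack_i32(const std::vector<uint8_t> &lanes) {
  assert(lanes.size() % 4 == 0);
  std::vector<uint32_t> words;
  for (std::size_t i = 0; i < lanes.size(); i += 4) {
    uint32_t w = 0;
    for (int j = 0; j < 4; ++j)
      w |= uint32_t(lanes[i + j]) << (8 * j);
    words.push_back(w);
  }
  return words;
}

// Inverse of pack_i32, mirroring unpackI32's bitcast + extract_element loop.
static std::vector<uint8_t> unpack_i32(const std::vector<uint32_t> &words) {
  std::vector<uint8_t> lanes;
  for (uint32_t w : words)
    for (int j = 0; j < 4; ++j)
      lanes.push_back(uint8_t(w >> (8 * j)));
  return lanes;
}

// Permutation reorderValues applies for a 16-bit -> 32-bit elementwise op:
// within each group of 8 per-thread values, the middle two pairs swap places
// so the widened values line up with the dot-operand layout of the wider type.
static std::vector<int> reorder16to32(const std::vector<int> &vals) {
  static const int perm[8] = {0, 1, 4, 5, 2, 3, 6, 7};
  std::vector<int> ret;
  for (std::size_t i = 0; i < vals.size(); i += 8)
    for (int p : perm)
      ret.push_back(vals[i + p]);
  return ret;
}

int main() {
  std::vector<uint8_t> lanes = {0, 1, 2, 3, 4, 5, 6, 7};
  auto words = pack_i32(lanes);
  assert(unpack_i32(words) == lanes); // packing round-trips losslessly
  printf("packed word 0: 0x%08x\n", (unsigned)words[0]); // 0x03020100

  std::vector<int> idx = {0, 1, 2, 3, 4, 5, 6, 7};
  for (int v : reorder16to32(idx))
    printf("%d ", v); // 0 1 4 5 2 3 6 7
  printf("\n");
  return 0;
}

Compiled with any C++11 compiler, this prints 0x03020100 for the packed word and 0 1 4 5 2 3 6 7 for the permuted indices, i.e. the same index pattern reorderValues emits for its 16-bit to 32-bit case.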