diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index 3b35892d401a..c71475c21039 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -343,7 +343,7 @@ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
 (mma.16816 section, FP32 accumulator).
 
 For example, the matrix L corresponding to blockTileSize=[32,16] is:
-                          warp 0                            warp 1
+                          warp 0                            warp 2
 -----------------/\-------------  ----------------/\-------------
 [ 0   0   1   1   2   2   3   3    32  32  33  33  34  34  35  35
 [ 4   4   5   5   6   6   7   7    36  36  37  37  38  38  39  39
@@ -354,7 +354,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
 [ ..............................   ..............................
 [ 28  28  29  29  30  30  31  31   60  60  61  61  62  62  63  63
 
-                          warp 3                            warp 4
+                          warp 1                            warp 3
 ----------------/\-------------  ----------------/\-------------
 [ 64  64  65  65  66  66  67  67    96  96  97  97  98  98  99  99
 [ 68  68  69  69  70  70  71  71   100 100 101 101 102 102 103 103
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
index 145a320eed1e..2bffe4e32807 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
@@ -1,6 +1,8 @@
 #include "../ConvertLayoutOpToLLVM.h"
 #include "../Utility.h"
 
+using namespace mlir;
+
 using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
 using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
 using ::mlir::LLVM::getStridesFromShapeAndOrder;
@@ -30,13 +32,8 @@ class MMA16816SmemLoader {
                                     Value cSwizzleOffset) {
     if (canUseLdmatrix)
       return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset);
-    else if (elemBytes == 4 && needTrans)
-      return computeB32MatOffs(warpOff, lane, cSwizzleOffset);
-    else if (elemBytes == 1 && needTrans)
-      return computeB8MatOffs(warpOff, lane, cSwizzleOffset);
     else
-      llvm::report_fatal_error("Invalid smem load config");
-
+      return computeLdsMatOffs(warpOff, lane, cSwizzleOffset, elemBytes);
     return {};
   }
 
@@ -46,14 +43,9 @@ class MMA16816SmemLoader {
   // mapped to.
   SmallVector<Value> computeLdmatrixMatOffs(Value warpId, Value lane,
                                             Value cSwizzleOffset);
-
-  // Compute 32-bit matrix offsets.
-  SmallVector<Value> computeB32MatOffs(Value warpOff, Value lane,
-                                       Value cSwizzleOffset);
-  // compute 8-bit matrix offset.
-  SmallVector<Value> computeB8MatOffs(Value warpOff, Value lane,
-                                      Value cSwizzleOffset);
+  SmallVector<Value> computeLdsMatOffs(Value warpOff, Value lane,
+                                       Value cSwizzleOffset, int elemBytes);
 
   // Load 4 matrices and returns 4 vec<2> elements.
   std::tuple<Value, Value, Value, Value>
@@ -158,85 +150,97 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
   return offs;
 }
 
-SmallVector<Value> MMA16816SmemLoader::computeB32MatOffs(Value warpOff,
+// clang-format off
+// Each `ldmatrix.x4` loads data as follows when `needTrans == False`:
+//
+//                  quad width
+// <----------------------------------------->
+//  vecWidth
+// <------->
+// t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3     ||  t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3      /|\
+// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7     ||  t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7       |
+// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11   ||  t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11     | quad height
+// ...                                                                                           |
+// t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 ||  t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31  \|/
+// --------------------------------------------- || ---------------------------------------------
+// t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3     ||  t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3
+// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7     ||  t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7
+// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11   ||  t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11
+// ...
+// t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 ||  t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31
+//
+// we assume that the phase is < 8 so we don't need to maintain a separate pointer for the two
+// lower quadrants. This pattern repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
+// along the row (resp. col) dimension.
+// clang-format on
+
+SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
                                                          Value lane,
-                                                         Value cSwizzleOffset) {
-  assert(needTrans && "Only used in transpose mode.");
-  // Load tf32 matrices with lds32
-  Value cOffInMat = udiv(lane, i32_val(4));
-  Value sOffInMat = urem(lane, i32_val(4));
-
-  Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase));
-  SmallVector<Value> offs(numPtrs);
-
-  for (int mat = 0; mat < 4; ++mat) { // Load 4 mats each time
-    int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2;
-    int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2;
-    if (kMatArrInt > 0) // we don't need pointers for k
-      continue;
-    Value kMatArr = i32_val(kMatArrInt);
-    Value nkMatArr = i32_val(nkMatArrInt);
-
-    Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)),
-                        mul(nkMatArr, i32_val(matArrStride)));
-    Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape));
-    cMatOff = add(cMatOff, cSwizzleMatOff);
-
-    Value sMatOff = kMatArr;
-    Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape)));
-    // FIXME: (kOrder == 1?) is really dirty hack
-    for (int i = 0; i < numPtrs / 2; ++i) {
-      Value cMatOffI =
-          add(cMatOff, i32_val(i * pLoadStrideInMat * (kOrder == 1 ? 1 : 2)));
-      cMatOffI = xor_(cMatOffI, phase);
-      Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
-      cOff = urem(cOff, i32_val(tileShape[order[0]]));
-      sOff = urem(sOff, i32_val(tileShape[order[1]]));
-      offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sStride));
-    }
+                                                         Value cSwizzleOffset,
+                                                         int elemBytes) {
+  assert(elemBytes <= 4);
+  int cTileShape = tileShape[order[0]];
+  int sTileShape = tileShape[order[1]];
+  if (!needTrans) {
+    std::swap(cTileShape, sTileShape);
   }
-  return offs;
-}
-
-SmallVector<Value> MMA16816SmemLoader::computeB8MatOffs(Value warpOff,
-                                                        Value lane,
-                                                        Value cSwizzleOffset) {
-  assert(needTrans && "Only used in transpose mode.");
-  Value cOffInMat = udiv(lane, i32_val(4));
-  Value sOffInMat =
-      mul(urem(lane, i32_val(4)), i32_val(4)); // each thread load 4 cols
 
   SmallVector<Value> offs(numPtrs);
-  for (int mat = 0; mat < 4; ++mat) {
-    int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2;
-    int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2;
-    if (kMatArrInt > 0) // we don't need pointers for k
-      continue;
-    Value kMatArr = i32_val(kMatArrInt);
-    Value nkMatArr = i32_val(nkMatArrInt);
-
-    Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)),
-                        mul(nkMatArr, i32_val(matArrStride)));
-    Value sMatOff = kMatArr;
-
-    for (int loadx4Off = 0; loadx4Off < numPtrs / 8; ++loadx4Off) {
-      for (int elemOff = 0; elemOff < 4; ++elemOff) {
-        int ptrOff = loadx4Off * 8 + nkMatArrInt * 4 + elemOff;
-        Value cMatOffI = add(cMatOff, i32_val(loadx4Off * pLoadStrideInMat *
-                                              (kOrder == 1 ? 1 : 2)));
-        Value sOffInMatElem = add(sOffInMat, i32_val(elemOff));
-
-        // disable swizzling ...
-
-        Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
-        Value sOff = add(sOffInMatElem, mul(sMatOff, i32_val(sMatShape)));
+
+  int threadsPerQuad[2] = {8, 4};
+  int laneWidth = 4;
+  int laneHeight = 8;
+  int vecWidth = 4 / elemBytes;
+  int quadWidth = laneWidth * vecWidth;
+  int quadHeight = laneHeight;
+  int numQuadI = 2;
+
+  // outer index base
+  Value iBase = udiv(lane, i32_val(laneWidth));
+
+  for (int rep = 0; rep < numPtrs / (2 * vecWidth); ++rep)
+    for (int quadId = 0; quadId < 2; ++quadId)
+      for (int elemId = 0; elemId < vecWidth; ++elemId) {
+        int idx = rep * 2 * vecWidth + quadId * vecWidth + elemId;
+        // inner index base
+        Value jBase = mul(urem(lane, i32_val(laneWidth)), i32_val(vecWidth));
+        jBase = add(jBase, i32_val(elemId));
+        // inner index offset
+        Value jOff = i32_val(0);
+        if (!needTrans) {
+          jOff = add(jOff, i32_val(quadId));
+          jOff = add(jOff, i32_val(rep * pLoadStrideInMat));
+        }
+        // outer index offset
+        Value iOff = mul(warpOff, i32_val(warpOffStride));
+        if (needTrans) {
+          int pStride = kOrder == 1 ? 1 : 2;
+          iOff = add(iOff, i32_val(quadId * matArrStride));
+          iOff = add(iOff, i32_val(rep * pLoadStrideInMat * pStride));
+        }
+        // swizzle
+        if (!needTrans) {
+          Value phase = urem(udiv(iBase, i32_val(perPhase)), i32_val(maxPhase));
+          jOff = add(jOff, udiv(cSwizzleOffset, i32_val(quadWidth)));
+          jOff = xor_(jOff, phase);
+        } else {
+          Value phase = urem(udiv(jBase, i32_val(perPhase)), i32_val(maxPhase));
+          iOff = add(iOff, udiv(cSwizzleOffset, i32_val(quadHeight)));
+          iOff = xor_(iOff, phase);
+        }
         // To prevent out-of-bound access when tile is too small.
-        cOff = urem(cOff, i32_val(tileShape[order[0]]));
-        sOff = urem(sOff, i32_val(tileShape[order[1]]));
-        offs[ptrOff] = add(cOff, mul(sOff, sStride));
+        Value i = add(iBase, mul(iOff, i32_val(quadHeight)));
+        Value j = add(jBase, mul(jOff, i32_val(quadWidth)));
+        // wrap around the bounds
+        i = urem(i, i32_val(cTileShape));
+        j = urem(j, i32_val(sTileShape));
+        if (needTrans) {
+          offs[idx] = add(i, mul(j, sStride));
+        } else {
+          offs[idx] = add(mul(i, sStride), j);
+        }
       }
-    }
-  }
+
   return offs;
 }
 
@@ -251,18 +255,13 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
 
   if (canUseLdmatrix)
     ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]);
-  else if (elemBytes == 4 && needTrans)
-    ptrIdx = matIdx[order[0]];
-  else if (elemBytes == 1 && needTrans)
-    ptrIdx = matIdx[order[0]] * 4;
   else
-    llvm::report_fatal_error("unsupported mma type found");
+    ptrIdx = matIdx[order[0]] * 4 / elemBytes;
 
   // The main difference with the original triton code is we removed the
   // prefetch-related logic here for the upstream optimizer phase should
   // take care with it, and that is transparent in dot conversion.
   auto getPtr = [&](int idx) { return ptrs[idx]; };
-
   Value ptr = getPtr(ptrIdx);
 
   // The struct should have exactly the same element types.
@@ -304,96 +303,47 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
     Value resV4 = builder.launch(rewriter, loc, resTy);
     return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1),
             extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)};
-  } else if (elemBytes == 4 && needTrans) { // Use lds.32 to load tf32 matrices
-    Value ptr2 = getPtr(ptrIdx + 1);
-    assert(sMatStride == 1);
-    int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
-    Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride);
-    int sOffsetArrElem = sMatStride * sMatShape;
-    Value sOffsetArrElemVal =
-        add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride));
-
-    Value elems[4];
-    if (kOrder == 1) {
-      elems[0] = load(gep(shemPtrTy, ptr, sOffsetElemVal));
-      elems[1] = load(gep(shemPtrTy, ptr2, sOffsetElemVal));
-      elems[2] = load(gep(shemPtrTy, ptr, sOffsetArrElemVal));
-      elems[3] = load(gep(shemPtrTy, ptr2, sOffsetArrElemVal));
-    } else {
-      elems[0] = load(gep(shemPtrTy, ptr, sOffsetElemVal));
-      elems[2] = load(gep(shemPtrTy, ptr2, sOffsetElemVal));
-      elems[1] = load(gep(shemPtrTy, ptr, sOffsetArrElemVal));
-      elems[3] = load(gep(shemPtrTy, ptr2, sOffsetArrElemVal));
-    }
-    std::array<Value, 4> retElems;
-    retElems.fill(undef(elemTy));
-    for (auto i = 0; i < 4; ++i) {
-      retElems[i] = insert_element(elemTy, retElems[i], elems[i], i32_val(0));
-    }
-    return {retElems[0], retElems[1], retElems[2], retElems[3]};
-  } else if (elemBytes == 1 && needTrans) { // work with int8
-    // Can't use i32 here. Use LLVM's VectorType
+  } else {
     elemTy = matTy.cast<LLVM::LLVMStructType>().getBody()[0];
+    // base pointers
     std::array<std::array<Value, 4>, 2> ptrs;
-    ptrs[0] = {
-        getPtr(ptrIdx),
-        getPtr(ptrIdx + 1),
-        getPtr(ptrIdx + 2),
-        getPtr(ptrIdx + 3),
-    };
-
-    ptrs[1] = {
-        getPtr(ptrIdx + 4),
-        getPtr(ptrIdx + 5),
-        getPtr(ptrIdx + 6),
-        getPtr(ptrIdx + 7),
-    };
-
-    assert(sMatStride == 1);
-    int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
-    Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride);
-    int sOffsetArrElem = 1 * (sMatStride * sMatShape);
-    Value sOffsetArrElemVal =
-        add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride));
-
-    std::array<Value, 4> i8v4Elems;
-    i8v4Elems.fill(undef(elemTy));
-
-    Value i8Elems[4][4];
-    if (kOrder == 1) {
-      for (int i = 0; i < 2; ++i)
-        for (int j = 0; j < 4; ++j)
-          i8Elems[i][j] = load(gep(shemPtrTy, ptrs[i][j], sOffsetElemVal));
-
-      for (int i = 2; i < 4; ++i)
-        for (int j = 0; j < 4; ++j)
-          i8Elems[i][j] =
-              load(gep(shemPtrTy, ptrs[i - 2][j], sOffsetArrElemVal));
-
-      for (int m = 0; m < 4; ++m) {
-        for (int e = 0; e < 4; ++e)
-          i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
-                                        i8Elems[m][e], i32_val(e));
-      }
-    } else { // k first
-      for (int j = 0; j < 4; ++j)
-        i8Elems[0][j] = load(gep(shemPtrTy, ptrs[0][j], sOffsetElemVal));
-      for (int j = 0; j < 4; ++j)
-        i8Elems[2][j] = load(gep(shemPtrTy, ptrs[1][j], sOffsetElemVal));
-      for (int j = 0; j < 4; ++j)
-        i8Elems[1][j] = load(gep(shemPtrTy, ptrs[0][j], sOffsetArrElemVal));
-      for (int j = 0; j < 4; ++j)
-        i8Elems[3][j] = load(gep(shemPtrTy, ptrs[1][j], sOffsetArrElemVal));
-
-      for (int m = 0; m < 4; ++m) {
-        for (int e = 0; e < 4; ++e)
-          i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
-                                        i8Elems[m][e], i32_val(e));
-      }
+    int vecWidth = 4 / elemBytes;
+    for (int i = 0; i < vecWidth; i++)
+      ptrs[0][i] = getPtr(ptrIdx + i);
+    for (int i = 0; i < vecWidth; i++)
+      ptrs[1][i] = getPtr(ptrIdx + i + vecWidth);
+    // static offsets along outer dimension
+    int _i0 = matIdx[order[1]] * (sMatStride * sMatShape);
+    int _i1 = _i0;
+    if (needTrans)
+      _i1 += sMatStride * sMatShape;
+    else
+      _i1 += (kOrder == 1 ? 1 : sMatStride) * sMatShape;
+    Value i0 = mul(i32_val(_i0), sStride);
+    Value i1 = mul(i32_val(_i1), sStride);
+    std::array<Value, 2> ii = {i0, i1};
+    // load 4 32-bit values from shared memory
+    // (equivalent to ldmatrix.x4)
+    SmallVector<SmallVector<Value>> vals(4, SmallVector<Value>(vecWidth));
+    for (int i = 0; i < 4; ++i)
+      for (int j = 0; j < vecWidth; ++j)
+        vals[i][j] = load(gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]));
+    // row + trans and col + no-trans are equivalent
+    if ((needTrans && kOrder == 1) || (!needTrans && kOrder == 0))
+      std::swap(vals[1], vals[2]);
+    // pack loaded vectors into 4 32-bit values
+    std::array<Value, 4> retElems;
+    retElems.fill(undef(elemTy));
+    for (int m = 0; m < 4; ++m) {
+      for (int e = 0; e < vecWidth; ++e)
+        retElems[m] = insert_element(retElems[m].getType(), retElems[m],
+                                     vals[m][e], i32_val(e));
    }
-
-    return {bitcast(i8v4Elems[0], i32_ty), bitcast(i8v4Elems[1], i32_ty),
-            bitcast(i8v4Elems[2], i32_ty), bitcast(i8v4Elems[3], i32_ty)};
+    if (elemBytes == 1)
+      return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty),
+              bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)};
+    else
+      return {retElems[0], retElems[1], retElems[2], retElems[3]};
   }
 
   assert(false && "Invalid smem load");
@@ -427,13 +377,13 @@ MMA16816SmemLoader::MMA16816SmemLoader(
     numPtrs = tileShape[order[0]] / (needTrans ? wpt : 1) / instrShape[order[0]];
   } else {
-    numPtrs = tileShape[order[0]] / wpt / matShape[order[0]];
+    numPtrs = tileShape[order[0]] / (needTrans ? wpt : 1) / matShape[order[0]];
+    numPtrs *= 4 / elemBytes;
   }
 
   numPtrs = std::max(numPtrs, 2);
 
   // Special rule for i8/u8, 4 ptrs for each matrix
-  if (!canUseLdmatrix && elemBytes == 1)
-    numPtrs *= 4;
+  // if (!canUseLdmatrix && elemBytes == 1)
 
   int loadStrideInMat[2];
   loadStrideInMat[kOrder] =
@@ -442,6 +392,7 @@ MMA16816SmemLoader::MMA16816SmemLoader(
       wpt * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]);
 
   pLoadStrideInMat = loadStrideInMat[order[0]];
+  sMatStride = loadStrideInMat[order[1]] / (instrShape[order[1]] / matShape[order[1]]);
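
Note (reviewer sketch, not part of the patch): the swizzled index arithmetic in computeLdsMatOffs is easier to sanity-check with plain integers than with MLIR Values. The standalone program below models only the non-transposed path for rep = 0 and elemId = 0; the parameters (elemBytes = 2, perPhase = 1, maxPhase = 8, zero warp and swizzle offsets) are illustrative assumptions, and the urem wrap-around at the tile boundary is omitted.

#include <cstdio>

int main() {
  const int laneWidth = 4;              // lanes per row of a quad
  const int elemBytes = 2;              // fp16 operand (assumed)
  const int vecWidth = 4 / elemBytes;   // elements covered by one 32-bit load
  const int quadWidth = laneWidth * vecWidth;
  const int perPhase = 1, maxPhase = 8; // swizzle parameters (assumed)

  for (int lane = 0; lane < 32; ++lane) {
    int iBase = lane / laneWidth;              // outer index base (row)
    int jBase = (lane % laneWidth) * vecWidth; // inner index base (col)
    for (int quadId = 0; quadId < 2; ++quadId) {
      // mirrors: jOff = quadId; jOff = xor_(jOff, phase(iBase))
      int phase = (iBase / perPhase) % maxPhase;
      int jOff = quadId ^ phase;
      int i = iBase;                     // warpOff assumed 0, so iOff == 0
      int j = jBase + jOff * quadWidth;  // swizzled column
      printf("lane %2d quad %d -> (%2d, %2d)\n", lane, quadId, i, j);
    }
  }
  return 0;
}

The XOR with the row-derived phase plays the same role as cMatOffI = xor_(cMatOffI, phase) in the removed computeB32MatOffs: consecutive rows address different shared-memory banks.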
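Note (reviewer sketch, not part of the patch): the unified fallback in loadX4 finishes by gathering vecWidth narrow elements per matrix with insert_element and, for elemBytes == 1, bitcasting the vector to i32. A plain-C++ stand-in for that final packing step, assuming a little-endian host for the bitcast:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Four i8 elements loaded for one matrix (vecWidth == 4 when elemBytes == 1).
  uint8_t vals[4] = {0x01, 0x02, 0x03, 0x04};
  uint32_t reg = 0;
  std::memcpy(&reg, vals, sizeof(reg)); // bitcast vec<4 x i8> -> i32
  printf("0x%08x\n", reg);              // 0x04030201 on little-endian hosts
  return 0;
}

The std::swap(vals[1], vals[2]) just before this step reorders the four matrices so that the transposed row-major case and the non-transposed column-major case hand their registers to the mma instruction in the same order.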
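Note (reviewer sketch, not part of the patch): the new pointer-count formula in the constructor folds the removed i8/u8 special case (numPtrs *= 4) into a general 4 / elemBytes factor. A small worked check, using assumed shapes (tileShape[order[0]] = 64, wpt = 2, matShape[order[0]] = 8):

#include <algorithm>
#include <cassert>

// Mirrors the non-ldmatrix branch of the MMA16816SmemLoader constructor.
int numPtrsFor(int cTile, int wpt, int matShape, int elemBytes, bool needTrans) {
  int n = cTile / (needTrans ? wpt : 1) / matShape;
  n *= 4 / elemBytes; // one 32-bit pointer covers 4 / elemBytes elements
  return std::max(n, 2);
}

int main() {
  // i8, transposed: 64 / 2 / 8 = 4 base offsets, x4 elements per 32-bit
  // load = 16, matching the old "4 ptrs for each matrix" rule.
  assert(numPtrsFor(64, 2, 8, /*elemBytes=*/1, /*needTrans=*/true) == 16);
  // tf32, transposed: 4 / elemBytes == 1, so the factor is a no-op (4 ptrs).
  assert(numPtrsFor(64, 2, 8, /*elemBytes=*/4, /*needTrans=*/true) == 4);
  return 0;
}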