Skip to content

Commit

Permalink
[OPTIMIZER][BACKEND] Enabled elementwise ops (including casts) betwee…
Browse files Browse the repository at this point in the history
…n ldmatrix and mma.sync (triton-lang#1595)
  • Loading branch information
ptillet authored May 1, 2023
1 parent e7ef4ce commit e8093ff
Show file tree
Hide file tree
Showing 12 changed files with 549 additions and 289 deletions.
14 changes: 7 additions & 7 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,10 +620,10 @@ struct ConvertLayoutOpConversion
// is implemented
SmallVector<Value> reorderedVals;
for (unsigned i = 0; i < vecVals.size(); i += 4) {
reorderedVals.push_back(vecVals[i]);
reorderedVals.push_back(vecVals[i + 2]);
reorderedVals.push_back(vecVals[i + 1]);
reorderedVals.push_back(vecVals[i + 3]);
reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
}

Value view = getTypeConverter()->packLLElements(loc, reorderedVals,
Expand All @@ -642,19 +642,19 @@ struct ConvertLayoutOpConversion
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
bool isHMMA = supportMMA(dst, mmaLayout.getVersionMajor());

auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
Value res;

if (!isOuter && mmaLayout.isAmpere() && isHMMA) { // tensor core v2
if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2

res = SharedToDotOperandMMAv2::convertLayout(
dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
smemObj, getTypeConverter(), tid_val());

} else if (!isOuter && mmaLayout.isVolta() && isHMMA) { // tensor core v1
} else if (!isOuter && mmaLayout.isVolta() &&
supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1
bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
Expand Down
Loading

0 comments on commit e8093ff

Please sign in to comment.