forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathReduceOpToLLVM.cpp
436 lines (388 loc) · 17.8 KB
/
ReduceOpToLLVM.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
#include "ReduceScanCommon.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <vector>
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getTotalElemsPerThread;
namespace {
/// Lowers `triton::ReduceOp` to LLVM dialect operations.
///
/// The reduction is performed in up to three stages:
///   1. each thread reduces the register values it holds along the reduction
///      axis;
///   2. lanes of a warp are combined with shuffles (or a target-specific warp
///      reduction);
///   3. if the reduction spans several warps, per-warp partial results are
///      exchanged through shared memory, reduced again, and read back.
struct ReduceOpConversion
    : public ConvertTritonGPUReduceScanToLLVMPattern<triton::ReduceOp> {
public:
  ReduceOpConversion(LLVMTypeConverter &typeConverter,
                     const TargetInfoBase &targetInfo, PatternBenefit benefit)
      : ConvertTritonGPUReduceScanToLLVMPattern<triton::ReduceOp>(typeConverter,
                                                                  benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ReduceOpHelper helper(op);
    assert(helper.isSupportedLayout() &&
           "Unexpected srcLayout in ReduceOpConversion");
    Location loc = op->getLoc();

    auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
    std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
    std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
    // First reduce all the values along axis within each thread.
    reduceWithinThreads(helper, srcValues, accs, indices, rewriter);

    // Then reduce across threads within a warp.
    reduceWithinWarps(helper, accs, rewriter);

    if (helper.isWarpSynchronous()) {
      // If all the values to be reduced are within the same warp there is
      // nothing left to do.
      packResults(helper, accs, rewriter);
      return success();
    }

    // Compute a shared memory base per operand.
    auto smemShape = helper.getScratchConfig();

    SmallVector<Value> smemBases =
        getSmemBases(op, product<unsigned>(smemShape), rewriter);

    storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter);

    // All per-warp partials must be visible before any thread reads them.
    sync(rewriter, loc, op);

    // The second round of shuffle reduction
    //   now the problem size: sizeInterWarps, s1, s2, .. , sn
    //   where sizeInterWarps is 2^m
    //
    // Each thread needs to process:
    //   elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
    accumulatePartialReductions(helper, smemBases, rewriter);

    // We could avoid this barrier in some of the layouts, however this is not
    // the general case.
    // TODO: optimize the barrier in case the layouts are accepted.
    sync(rewriter, loc, op);

    // set output values
    loadReductionAndPackResult(helper, smemShape, smemBases, rewriter);

    return success();
  }

private:
  /// Target-specific hooks (shuffles, warp reductions, shared loads/stores).
  /// Held by reference: the TargetInfoBase must outlive this pattern.
  const TargetInfoBase &targetInfo;

  /// Folds `cur` into the running accumulator `acc` with the op's combine
  /// region.
  ///
  /// If `isFirst` is true, `acc` is simply initialized from `cur`. Otherwise a
  /// copy of `combineOp` is cloned at the front of the current parent region
  /// and inlined at the rewriter's insertion point with the block arguments
  /// `(acc[0..n), cur[0..n))`. The operands of its `ReduceReturnOp` become the
  /// new `acc`, and the terminator is then erased.
  void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
                  SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
    if (isFirst) {
      acc = SmallVector<Value>(cur.begin(), cur.end());
      return;
    }

    // Create a new copy of the reduce block, and inline it
    Block *currentBlock = rewriter.getBlock();
    Region &parent = *currentBlock->getParent();
    rewriter.cloneRegionBefore(combineOp, &parent.front());
    auto &newReduce = parent.front();
    // The combine region is required to end in a ReduceReturnOp; `cast`
    // asserts that instead of yielding a null op that is dereferenced below.
    auto returnOp = cast<triton::ReduceReturnOp>(newReduce.getTerminator());

    // Block arguments are laid out as all accumulators followed by all
    // current values.
    llvm::SmallVector<Value> combineArgs(2 * acc.size());
    for (unsigned i = 0; i < acc.size(); ++i) {
      combineArgs[i] = acc[i];
      combineArgs[acc.size() + i] = cur[i];
    }

    rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
                               combineArgs);

    auto results = returnOp.getResult();
    for (unsigned i = 0; i < acc.size(); ++i) {
      acc[i] = results[i];
    }

    // Delete the terminator, which is no longer used
    rewriter.eraseOp(returnOp);
  }

  /// Unpacks every converted operand into per-thread element lists.
  ///
  /// Returns a vector indexed by element position within the thread; entry
  /// `j` holds the `j`-th element of every operand, in operand order. All
  /// operands are asserted to have the same element count as operand 0.
  SmallVector<SmallVector<Value>>
  unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const {
    auto types = op.getInputTypes();
    auto operands = adaptor.getOperands();
    unsigned srcElems = getTotalElemsPerThread(types[0]);
    SmallVector<SmallVector<Value>> srcValues(srcElems);
    for (unsigned i = 0; i < op.getNumOperands(); ++i) {
      auto values = unpackLLElements(loc, operands[i], rewriter);

      assert(values.size() == srcValues.size());
      for (unsigned j = 0; j < srcValues.size(); ++j) {
        srcValues[j].push_back(values[j]);
      }
    }
    return srcValues;
  }

  /// Emits a barrier through the `barrier()` helper macro.
  /// NOTE(review): the macro presumably relies on the `rewriter` and `loc`
  /// names in scope; `op` is currently unused.
  void sync(ConversionPatternRewriter &rewriter, Location loc,
            triton::ReduceOp op) const {
    barrier();
  }

  // Reduce along op axis for elements that are in the same thread. The
  // accumulated value is stored in accs.
  //
  // Keys of `accs` are element offsets with the reduction-axis coordinate
  // forced to 0. `indices[key]` records the emitted indices of the first
  // element folded into that key; they are used later to compute
  // shared-memory write addresses.
  void reduceWithinThreads(
      ReduceOpHelper &helper, SmallVector<SmallVector<Value>> &srcValues,
      std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
      std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
      ConversionPatternRewriter &rewriter) const {
    triton::ReduceOp op = helper.getOperation();
    RankedTensorType operandType = op.getInputTypes()[0];
    // Assumes offsets don't actually depend on type
    SmallVector<SmallVector<unsigned>> offsets =
        emitOffsetForLayout(helper.getSrcLayout(), operandType);

    // Thread X might hold the same input value in two registers. Get the
    // indices in `offsets` that hold unique values, and only accumulate over
    // those.
    llvm::MapVector<ArrayRef<unsigned>, int> uniqueOffsets;
    for (int i = 0, e = offsets.size(); i < e; ++i) {
      uniqueOffsets.insert({offsets[i], i});
    }

    auto *combineOp = &op.getCombineOp();
    auto srcIndices = emitIndices(op.getLoc(), rewriter, targetInfo,
                                  helper.getSrcLayout(), operandType, true);
    // reduce within threads
    for (const auto &[_, i] : uniqueOffsets) {
      SmallVector<unsigned> key = offsets[i];
      key[op.getAxis()] = 0;
      bool isFirst = accs.find(key) == accs.end();
      accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
      if (isFirst)
        indices[key] = srcIndices[i];
    }
  }

  // Apply warp reduction across the given number of contiguous lanes using op
  // region and the accumulator values as source.
  //
  // The target's own warp reduction is tried first. If it declines, a
  // shuffle-xor ladder is emitted with lane distances
  // (numLaneToReduce/2, ..., 1) * interleave. This combines all lanes only when
  // `numLaneToReduce` is a power of two.
  void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
                  SmallVector<Value> &acc, triton::ReduceOp op,
                  unsigned numLaneToReduce, unsigned interleave) const {
    if (targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce))
      return;

    for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
      SmallVector<Value> shfl(acc.size());
      for (unsigned i = 0; i < acc.size(); ++i) {
        shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave);
      }
      accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
    }
  }

  // Reduce across threads within each warp.
  //
  // Every accumulator is reduced in place over the lanes that hold unique data
  // along the reduction axis, strided by the thread offset along that axis.
  void
  reduceWithinWarps(ReduceOpHelper &helper,
                    std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
                    ConversionPatternRewriter &rewriter) const {
    triton::ReduceOp op = helper.getOperation();
    unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
    unsigned threadOffsetOnReductionAxis =
        helper.getThreadOffsetOnReductionAxis();
    // Bind by reference: avoids copying each key/value pair and the redundant
    // map lookup the accumulator previously required.
    for (auto &[_, acc] : accs) {
      warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps,
                 threadOffsetOnReductionAxis);
    }
  }

  // Pack the accumulator values and replace the reduce op with the result.
  //
  // Ranked-tensor results are gathered by re-inserting a 0 coordinate at the
  // reduction axis into each result offset and looking up `accs`. For
  // non-tensor (scalar) results, the first accumulator is used; presumably
  // the input was 1-D, so every key collapsed to {0}.
  void packResults(ReduceOpHelper &helper,
                   std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
                   ConversionPatternRewriter &rewriter) const {
    triton::ReduceOp op = helper.getOperation();
    Location loc = op.getLoc();
    unsigned axis = op.getAxis();
    SmallVector<Value> results(op.getNumOperands());
    for (unsigned i = 0; i < op.getNumOperands(); ++i) {
      if (auto resultTy =
              dyn_cast<RankedTensorType>(op.getResult()[i].getType())) {
        auto resultLayout = cast<SliceEncodingAttr>(resultTy.getEncoding());
        unsigned resultElems = getTotalElemsPerThread(resultTy);
        SmallVector<SmallVector<unsigned>> resultOffset =
            emitOffsetForLayout(resultLayout, resultTy);
        SmallVector<Value> resultVals;
        resultVals.reserve(resultElems);
        for (unsigned j = 0; j < resultElems; ++j) {
          auto key = resultOffset[j];
          key.insert(key.begin() + axis, 0);
          // `find` instead of operator[]: a missing key must not silently
          // insert an empty vector that is then indexed out of bounds.
          auto accIt = accs.find(key);
          assert(accIt != accs.end() && "missing accumulator for result offset");
          resultVals.push_back(accIt->second[i]);
        }
        results[i] = packLLElements(loc, getTypeConverter(), resultVals,
                                    rewriter, resultTy);
      } else
        results[i] = accs.begin()->second[i];
    }
    rewriter.replaceOp(op, results);
  }

  /// Delinearizes `warpId` into per-dimension warp coordinates of the source
  /// layout.
  ///
  /// For a SliceEncodingAttr source, the parent's warpsPerCTA/order are used
  /// and the sliced dimension is dropped afterwards. Otherwise the layout's
  /// warpsPerCTA is used, with the reduction-axis entry replaced by the number
  /// of warps along that axis that hold unique data.
  SmallVector<Value>
  getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc,
                    ConversionPatternRewriter &rewriter) const {
    auto srcLayout = helper.getSrcLayout();
    auto srcShape = helper.getSrcShape();
    auto order = getOrder(srcLayout);
    SmallVector<Value> multiDimWarpId;

    // 2x2 warps with slice dim = 0, warpId = 2 ends up writing at the same
    // address as warpId = 0 since the warpsPerCTA is [1, 2], need to figure out
    // a way to properly delinearize warpId in the slice case
    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
      auto parentLayout = sliceLayout.getParent();
      auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout);
      auto parentOrder = triton::gpu::getOrder(parentLayout);
      multiDimWarpId =
          delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder);
      multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim());
    } else {
      SmallVector<unsigned> warpsPerCTA =
          triton::gpu::getWarpsPerCTA(srcLayout);
      warpsPerCTA[helper.getAxis()] = triton::gpu::getWarpsPerCTAWithUniqueData(
          srcLayout, srcShape)[helper.getAxis()];
      multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order);
    }
    return multiDimWarpId;
  }

  /// Writes each warp's partial reduction to shared memory.
  ///
  /// Only lanes whose coordinate along the reduction axis is 0 store (the
  /// `laneZero` predicate). The write index is the accumulator's recorded
  /// index with the axis coordinate replaced by the warp's coordinate along
  /// the axis. It is linearized over `smemShape` in axis-first order.
  void storeWarpReduceToSharedMemory(
      ReduceOpHelper &helper,
      std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
      std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
      SmallVector<Value> &smemBases,
      ConversionPatternRewriter &rewriter) const {
    triton::ReduceOp op = helper.getOperation();
    Location loc = op.getLoc();
    Value threadId = getThreadId(rewriter, loc);
    auto srcLayout = helper.getSrcLayout();
    Value warpSize = i32_val(triton::gpu::getWarpSize(srcLayout));
    Value warpId = udiv(threadId, warpSize);
    Value laneId = urem(threadId, warpSize);
    auto srcShape = helper.getSrcShape();
    unsigned axis = op.getAxis();
    auto smemShape = helper.getScratchConfig();

    auto threadsPerWarp =
        triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
    auto order = getOrder(srcLayout);
    SmallVector<Value> multiDimLaneId =
        delinearize(rewriter, loc, laneId, threadsPerWarp, order);
    Value laneIdAxis = multiDimLaneId[axis];
    Value zero = i32_val(0);
    Value laneZero = icmp_eq(laneIdAxis, zero);

    SmallVector<Value> multiDimWarpId =
        getMultiDimWarpId(helper, warpId, loc, rewriter);
    Value warpIdAxis = multiDimWarpId[axis];

    auto smemOrder = helper.getOrderWithAxisAtBeginning();
    // Bind by const reference: the map entries are only read here.
    for (const auto &[key, acc] : accs) {
      // Every accumulator key has recorded indices (see reduceWithinThreads);
      // `find` avoids default-inserting an empty index vector.
      auto idxIt = indices.find(key);
      assert(idxIt != indices.end() && "missing indices for accumulator key");
      SmallVector<Value> writeIdx = idxIt->second;

      writeIdx[axis] = warpIdAxis;
      Value writeOffset =
          linearize(rewriter, loc, writeIdx, smemShape, smemOrder);
      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
        auto elemTy = getElementType(op, i);
        Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy,
                             smemBases[i], writeOffset);
        targetInfo.storeShared(rewriter, loc, writePtr, acc[i], laneZero);
      }
    }
  }

  // Load the reduction of each warp and accumulate them to a final value and
  // store back to shared memory.
  //
  // Each thread reads from offset `threadId + round * numThreads` for
  // `elemsPerThread` rounds. Threads with `threadId >= elems` load under a
  // false predicate. Loaded values are reduced across `sizeInterWarps`
  // contiguous lanes, and lanes with `laneId % sizeInterWarps == 0` write the
  // result back in place.
  // NOTE(review): `elemsPerThread` truncates `elems / numThreads`, and
  // `threadIsNeeded` only checks the first-round offset. This assumes
  // `elems` is a multiple of `numThreads` whenever it exceeds it (presumably
  // guaranteed by power-of-two shapes) — confirm.
  void accumulatePartialReductions(ReduceOpHelper &helper,
                                   SmallVector<Value> &smemBases,
                                   ConversionPatternRewriter &rewriter) const {
    triton::ReduceOp op = helper.getOperation();
    auto srcLayout = helper.getSrcLayout();
    auto smemShape = helper.getScratchConfig();
    unsigned elems = product<unsigned>(smemShape);
    unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
    Location loc = op.getLoc();

    Value threadId = getThreadId(rewriter, loc);
    Value warpSize = i32_val(triton::gpu::getWarpSize(srcLayout));
    Value laneId = urem(threadId, warpSize);
    Value zero = i32_val(0);

    auto mod = op.getOperation()->getParentOfType<ModuleOp>();
    unsigned numThreads =
        product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) *
        triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
    unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
    Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
    Value readOffset = threadId;
    for (unsigned round = 0; round < elemsPerThread; ++round) {
      SmallVector<Value> acc(op.getNumOperands());
      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
        auto elemTy = getElementType(op, i);
        Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy,
                            smemBases[i], readOffset);
        acc[i] = targetInfo.loadShared(rewriter, loc, getTypeConverter(),
                                       readPtr, elemTy, threadIsNeeded);
      }
      warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */);
      // only the first thread in each sizeInterWarps is writing
      Value writeOffset = readOffset;
      SmallVector<Value> writePtrs(op.getNumOperands());
      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
        auto elemTy = getElementType(op, i);
        writePtrs[i] = gep(ptr_ty(rewriter.getContext(), 3), elemTy,
                           smemBases[i], writeOffset);
      }

      Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
      Value laneIdModSizeInterWarpsIsZero =
          icmp_eq(laneIdModSizeInterWarps, zero);
      Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);

      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
        targetInfo.storeShared(rewriter, loc, writePtrs[i], acc[i], pred);
      }

      if (round != elemsPerThread - 1) {
        readOffset = add(readOffset, i32_val(numThreads));
      }
    }
  }

  // Load the final reduction from shared memory and replace the reduce result
  // with it.
  //
  // For ranked-tensor results, each result index gets a 0 coordinate
  // re-inserted at the reduction axis. That index is linearized over
  // `smemShape` (axis-first order) and loaded. For non-tensor results, the
  // value at each operand's shared-memory base is loaded.
  void loadReductionAndPackResult(ReduceOpHelper &helper,
                                  const SmallVector<unsigned> &smemShape,
                                  SmallVector<Value> &smemBases,
                                  ConversionPatternRewriter &rewriter) const {
    triton::ReduceOp op = helper.getOperation();
    Location loc = op.getLoc();
    auto axis = op.getAxis();
    auto smemOrder = helper.getOrderWithAxisAtBeginning();
    SmallVector<Value> results(op.getNumOperands());
    for (unsigned i = 0; i < op.getNumOperands(); ++i) {
      auto elemTy = getElementType(op, i);
      if (auto resultTy =
              dyn_cast<RankedTensorType>(op.getResult()[i].getType())) {
        // nd-tensor where n >= 1
        auto resultLayout = cast<SliceEncodingAttr>(resultTy.getEncoding());
        unsigned resultElems = getTotalElemsPerThread(resultTy);
        auto resultIndices = emitIndices(loc, rewriter, targetInfo,
                                         resultLayout, resultTy, true);
        auto resultShape = resultTy.getShape();
        auto resultCTATile = getShapePerCTATile(resultLayout, resultShape);
        assert(resultIndices.size() == resultElems);

        SmallVector<Value> resultVals(resultElems);
        for (size_t j = 0; j < resultElems; ++j) {
          SmallVector<Value> readIdx = resultIndices[j];
          readIdx.insert(readIdx.begin() + axis, i32_val(0));
          for (size_t resultIdx = 0, resultDim = resultShape.size();
               resultIdx < resultDim; ++resultIdx) {
            // Map the result dimension to the scratch dimension, skipping the
            // reduction axis.
            auto smemIdx = resultIdx < axis ? resultIdx : resultIdx + 1;
            if (resultCTATile[resultIdx] > smemShape[smemIdx] ||
                resultShape[resultIdx] > smemShape[smemIdx]) {
              // When srcShape is smaller than src sizePerThread, only srcShape
              // elements are accumulated in smem. Taking the index modulo
              // smemShape effectively replicates srcShape elements to src
              // sizePerThread.
              readIdx[smemIdx] =
                  urem(readIdx[smemIdx], i32_val(smemShape[smemIdx]));
            }
          }
          Value readOffset =
              linearize(rewriter, loc, readIdx, smemShape, smemOrder);
          Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy,
                              smemBases[i], readOffset);
          resultVals[j] = load(elemTy, readPtr);
        }

        results[i] = packLLElements(loc, getTypeConverter(), resultVals,
                                    rewriter, resultTy);
      } else {
        // 0d-tensor -> scalar
        results[i] = load(elemTy, smemBases[i]);
      }
    }
    rewriter.replaceOp(op, results);
  }
};
} // namespace
// Registers the ReduceOp -> LLVM lowering in `patterns`.
//
// The registered ReduceOpConversion stores `targetInfo` by reference, so the
// TargetInfoBase object must stay alive for as long as `patterns` is used.
void mlir::triton::populateReduceOpToLLVMPatterns(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    const TargetInfoBase &targetInfo, PatternBenefit benefit) {
  patterns.add<ReduceOpConversion>(/*typeConverter=*/typeConverter,
                                   /*targetInfo=*/targetInfo,
                                   /*benefit=*/benefit);
}