Skip to content

Commit

Permalink
Fix test_gather (#3010)
Browse files Browse the repository at this point in the history
Make `getStackPointer` as interface of the `TargetInfo` to generalize `getSharedMemoryBase` in gather op.
  • Loading branch information
ESI-SYD authored Dec 19, 2024
1 parent c83c0ed commit fdab3bb
Show file tree
Hide file tree
Showing 17 changed files with 67 additions and 65 deletions.
3 changes: 3 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class TargetInfoBase {

virtual bool supportVectorizedAtomics() const = 0;

virtual Value getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) const = 0;

virtual ~TargetInfoBase() {}
};
} // namespace mlir::triton
Expand Down
16 changes: 2 additions & 14 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,19 +381,6 @@ inline bool isKernel(FunctionOpInterface funcOp) {
return funcOp.getVisibility() == SymbolTable::Visibility::Public;
}

inline Value getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) {
// See NOTE: [Additional Function Arguments]
if (!isKernel(funcOp)) {
return funcOp.getArgument(funcOp.getNumArguments() - 2);
}

auto mod = funcOp->getParentOfType<ModuleOp>();
auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
assert(globalBase);
return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
}

inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
FunctionOpInterface funcOp,
Value allocOffset = {}) {
Expand Down Expand Up @@ -457,7 +444,8 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
.getValue()
.getZExtValue();
Value offVal = i32_val(offset);
Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
Value base =
gep(ptrTy, i8_ty, target.getStackPointer(rewriter, func), offVal);
return base;
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
adaptor.getOperands(), rewriter);
if (!caller->hasAttr("allocation.offset")) {
auto base = LLVM::getStackPointer(rewriter, caller);
auto base = targetInfo.getStackPointer(rewriter, caller);
promotedOperands.push_back(base);
} else {
auto base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp);
Expand Down
2 changes: 0 additions & 2 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6304,8 +6304,6 @@ def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0:
([128, 64], [128, 128], 1),
])
def test_gather(src_shape, indices_shape, axis, device):
if is_xpu():
pytest.skip("Fail on XPU")

def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)
Expand Down
13 changes: 13 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,19 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc,

int TargetInfo::getSharedAddressSpace() const { return 3; }

Value TargetInfo::getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) const {
// See NOTE: [Additional Function Arguments]
if (!LLVM::isKernel(funcOp)) {
return funcOp.getArgument(funcOp.getNumArguments() - 2);
}

auto mod = funcOp->getParentOfType<ModuleOp>();
auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
assert(globalBase);
return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
}

bool TargetInfo::supportVectorizedAtomics() const {
// Note: not currently tested or used, but AMD generally supports vectorized
// atomics.
Expand Down
3 changes: 3 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase {

bool supportVectorizedAtomics() const override;

Value getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) const override;

private:
void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args,
RewriterBase &rewriter, bool useStdErr) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
adaptor.getOperands(), rewriter);
if (!caller->hasAttr("allocation.offset")) {
auto base = LLVM::intel::getStackPointer(rewriter, caller);
auto base = targetInfo.getStackPointer(rewriter, caller);
promotedOperands.push_back(base);
return promotedOperands;
}
promotedOperands.push_back(LLVM::intel::getSharedMemoryBase(
promotedOperands.push_back(LLVM::getSharedMemoryBase(
callOp->getLoc(), rewriter, targetInfo, callOp));
return promotedOperands;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ struct ConvertLayoutOpConversion
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

Value smemBase = LLVM::intel::getSharedMemoryBase(loc, rewriter, targetInfo,
op.getOperation());
Value smemBase =
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto shape = dstTy.getShape();
Expand Down Expand Up @@ -819,8 +819,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
Type elementType = inVals.front().getType();
auto mod = rewriter.getInsertionPoint()->getParentOfType<ModuleOp>();

Value smemBase = LLVM::intel::getSharedMemoryBase(
loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo,
&*rewriter.getInsertionPoint());
Type ptrType = smemBase.getType();

int numRows = inVals.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ struct HistogramOpConversion
// TODO: we could skip this for cases with num_warps=1 as long as we can
// generate the right layout. Currently the warp level histogram generates
// data in the default blocked layout.
Value baseSharedMemPtr = LLVM::intel::getSharedMemoryBase(
loc, rewriter, targetInfo, op.getOperation());
Value baseSharedMemPtr =
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
auto dstType = op.getType();
Attribute dstEncoding = dstType.getEncoding();
auto indices = ::intel::emitIndices(op.getLoc(), rewriter, targetInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1500,8 +1500,8 @@ struct AtomicCASOpConversion
rewriter.eraseOp(op);
return success();
}
Value atomPtr = LLVM::intel::getSharedMemoryBase(
loc, rewriter, targetInfo, op.getOperation());
Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo,
op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3));
targetInfo.storeShared(rewriter, loc, atomPtr, ret, mask);
createBarrier(rewriter, loc, numCTAs);
Expand Down Expand Up @@ -1681,8 +1681,8 @@ struct AtomicRMWOpConversion
rewriter.eraseOp(op);
return success();
}
Value atomPtr = LLVM::intel::getSharedMemoryBase(
loc, rewriter, targetInfo, op.getOperation());
Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo,
op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3));
// Only threads with rmwMask = True store the result
targetInfo.storeShared(rewriter, loc, atomPtr, ret, rmwMask);
Expand Down
4 changes: 2 additions & 2 deletions third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ struct LocalAllocOpConversion
if (!op.isSharedMemoryAlloc())
return failure();
Location loc = op->getLoc();
Value smemBase = LLVM::intel::getSharedMemoryBase(loc, rewriter, targetInfo,
op.getOperation());
Value smemBase =
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
auto resultTy = cast<MemDescType>(op.getType());
auto typeConverter = getTypeConverter();
auto sharedLayout =
Expand Down
4 changes: 2 additions & 2 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ReduceScanCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ class ConvertTritonIntelGPUReduceScanToLLVMPattern
});
// Assign base index to each operand in their order in indices
std::map<unsigned, Value> indexToBase;
auto basePtr = LLVM::intel::getSharedMemoryBase(loc, rewriter, targetInfo,
op.getOperation());
auto basePtr =
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
indexToBase[indices[0]] = basePtr;
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
indexToBase[indices[i]] =
Expand Down
10 changes: 10 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,14 @@ bool TargetInfo::supportVectorizedAtomics() const {
return true;
}

Value TargetInfo::getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) const {
auto mod = funcOp->getParentOfType<ModuleOp>();
LLVM::LLVMPointerType ptrTy = ptr_ty(
rewriter.getContext(), TritonGEN::TritonGENMemorySpace::kWorkgroup);
if (mod->getAttrOfType<IntegerAttr>("ttg.shared").getInt() == 0)
return rewriter.create<LLVM::PoisonOp>(funcOp.getLoc(), ptrTy);
return funcOp.getArgument(funcOp.getNumArguments() - 1);
}

} // namespace mlir::triton::intel
3 changes: 3 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase {

bool supportVectorizedAtomics() const override;

Value getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) const override;

private:
};
} // namespace mlir::triton::intel
Expand Down
32 changes: 0 additions & 32 deletions third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,38 +83,6 @@ Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,

LLVM::LLVMFuncOp getSpirvPrintfDeclaration(RewriterBase &rewriter);

static Value getStackPointer(PatternRewriter &rewriter,
FunctionOpInterface funcOp) {
auto mod = funcOp->getParentOfType<ModuleOp>();
LLVM::LLVMPointerType ptrTy = ptr_ty(
rewriter.getContext(), TritonGEN::TritonGENMemorySpace::kWorkgroup);
if (mod->getAttrOfType<IntegerAttr>("ttg.shared").getInt() == 0)
return rewriter.create<LLVM::PoisonOp>(funcOp.getLoc(), ptrTy);
return funcOp.getArgument(funcOp.getNumArguments() - 1);
}

static Value getSharedMemoryBase(Location loc,
ConversionPatternRewriter &rewriter,
const TargetInfoBase &target, Operation *op) {
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
target.getSharedAddressSpace());
FunctionOpInterface func = op->getParentOfType<FunctionOpInterface>();
// CI debugging usage here
if (!op->hasAttr("allocation.offset")) {
auto mod = op->getParentOfType<ModuleOp>();
llvm::errs() << "op: " << *op << "\n";
llvm::errs() << "mod:" << mod << "\n";
llvm_unreachable("missing allocation.offset");
}
size_t offset = cast<IntegerAttr>(op->getAttr("allocation.offset"))
.getValue()
.getZExtValue();
Value offVal = i32_val(offset);
Value base =
gep(ptrTy, i8_ty, LLVM::intel::getStackPointer(rewriter, func), offVal);
return base;
}

static Value getModuleWarpSize(RewriterBase &rewriter, Location loc) {
auto mod = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
return i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod));
Expand Down
13 changes: 13 additions & 0 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,19 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc,

int TargetInfo::getSharedAddressSpace() const { return 3; }

Value TargetInfo::getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) const {
// See NOTE: [Additional Function Arguments]
if (!LLVM::isKernel(funcOp)) {
return funcOp.getArgument(funcOp.getNumArguments() - 2);
}

auto mod = funcOp->getParentOfType<ModuleOp>();
auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
assert(globalBase);
return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
}

bool TargetInfo::supportVectorizedAtomics() const {
return computeCapability >= 90 && ptxVersion >= 81;
}
Expand Down
3 changes: 3 additions & 0 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
StringRef file, StringRef func, int line) const override;
int getSharedAddressSpace() const override;

Value getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) const override;

bool supportVectorizedAtomics() const override;

int getPtxVersion() const { return ptxVersion; }
Expand Down

0 comments on commit fdab3bb

Please sign in to comment.