forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BACKEND][CPU] Convert tt.get_program_id and tt.print (Hello World) (#1)
Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments. I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing. Test Plan: Tested with a simple example: ``` @triton.jit def add_kernel(...): pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. foo = pid + 42 tl.device_print("Hello, World!", foo, pid) ``` The resulting .llir is valid: ``` @printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00" declare !dbg !3 i32 @printf(ptr, ...) define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 { %8 = add i32 %4, 42, !dbg !8 %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4) ret void, !dbg !9 } ``` Tried to compile with a fake main function: ``` > % cat main.c extern void add_kernel(float*, float*, float*, int, int, int, int); int main() { add_kernel(0, 0, 0, 4, 5, 6, 7); } > % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c > % ./a.out pid (5, 6, 7) Hello, World!: 47, 5 ```
- Loading branch information
Showing
9 changed files
with
268 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H | ||
#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H | ||
|
||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "triton/Conversion/MLIRTypes.h" | ||
|
||
namespace mlir::triton::cpu { | ||
class CPUTargetInfo { | ||
public: | ||
// Note: we may revisit for different CPU ISAs like AVX and Neon. | ||
CPUTargetInfo() {} | ||
|
||
Value programId(ConversionPatternRewriter &rewriter, Location loc, | ||
LLVM::LLVMFuncOp funcOp, int axis) const; | ||
|
||
void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, | ||
int formatStrByteCount, ValueRange args) const; | ||
|
||
~CPUTargetInfo() {} | ||
}; | ||
} // namespace mlir::triton::cpu | ||
#endif // TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h" | ||
#include "triton/Conversion/TritonCPUToLLVM/Utility.h" | ||
|
||
namespace { | ||
LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) { | ||
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>(); | ||
StringRef funcName("printf"); | ||
Operation *funcOp = moduleOp.lookupSymbol(funcName); | ||
if (funcOp) | ||
return cast<LLVM::LLVMFuncOp>(*funcOp); | ||
|
||
auto *context = rewriter.getContext(); | ||
|
||
// int printf(char* format, ...) | ||
SmallVector<Type> argsType{ptr_ty(context)}; | ||
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType, true); | ||
|
||
ConversionPatternRewriter::InsertionGuard guard(rewriter); | ||
rewriter.setInsertionPointToStart(moduleOp.getBody()); | ||
|
||
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName, | ||
funcType); | ||
} | ||
} // namespace | ||
|
||
namespace mlir::triton::cpu { | ||
|
||
Value CPUTargetInfo::programId(ConversionPatternRewriter &rewriter, | ||
Location loc, LLVM::LLVMFuncOp funcOp, | ||
int axis) const { | ||
assert(axis >= 0 && axis < 3); | ||
|
||
// program_id for CPU is provided as function arguments. The last three | ||
// arguments are __grid0 to __grid2 of i32. | ||
assert(funcOp && funcOp.getArguments().size() >= 3); | ||
return funcOp.getArgument(funcOp.getArguments().size() - 3 + axis); | ||
} | ||
|
||
void CPUTargetInfo::printf(ConversionPatternRewriter &rewriter, | ||
Value formatStrStart, int /*formatStrByteCount*/, | ||
ValueRange args) const { | ||
auto loc = UnknownLoc::get(rewriter.getContext()); | ||
SmallVector<Value> formatStrAndArgs{formatStrStart}; | ||
for (auto arg : args) { | ||
formatStrAndArgs.push_back(arg); | ||
} | ||
call(getPrintfDeclaration(rewriter), formatStrAndArgs); | ||
} | ||
} // namespace mlir::triton::cpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
#include "mlir/Conversion/LLVMCommon/Pattern.h" | ||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h" | ||
#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" | ||
#include "triton/Conversion/TritonCPUToLLVM/Utility.h" | ||
#include "triton/Dialect/Triton/IR/Dialect.h" | ||
|
||
namespace { | ||
|
||
using namespace mlir; | ||
using namespace mlir::triton; | ||
using namespace mlir::triton::cpu; | ||
|
||
struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> { | ||
explicit PrintOpConversion(LLVMTypeConverter &typeConverter, | ||
const CPUTargetInfo &targetInfo, | ||
PatternBenefit benefit) | ||
: mlir::ConvertOpToLLVMPattern<triton::PrintOp>(typeConverter, benefit), | ||
targetInfo(targetInfo) {} | ||
|
||
LogicalResult | ||
matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
auto loc = op->getLoc(); | ||
|
||
auto getPid = [&](int axis) { | ||
return targetInfo.programId( | ||
rewriter, loc, op->getParentOfType<LLVM::LLVMFuncOp>(), axis); | ||
}; | ||
SmallVector<Value> values = {getPid(0), getPid(1), getPid(2)}; | ||
|
||
std::string formatStr; | ||
llvm::raw_string_ostream os(formatStr); | ||
os << "pid (" << getFormatSubstr(values[0]) << ", " | ||
<< getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2]) | ||
<< ")" << op.getPrefix(); | ||
|
||
for (size_t i = 0; i < op.getNumOperands(); i++) { | ||
auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); | ||
if (op.getOperand(i).getType().dyn_cast<RankedTensorType>()) { | ||
llvm_unreachable("Not implemented for tensor types"); | ||
} | ||
|
||
// Only support scalars for now. | ||
assert(elems.size() == 1); | ||
if (i != 0) { | ||
os << ", "; | ||
} | ||
os << getFormatSubstr(elems[0]); | ||
values.push_back(elems[0]); | ||
} | ||
|
||
llPrintf(formatStr, values, rewriter); | ||
rewriter.eraseOp(op); | ||
return success(); | ||
} | ||
|
||
// TODO: This code is the same as the GPU-backend code. Consider refactoring. | ||
std::string getFormatSubstr(Value value, bool hex = false, | ||
std::optional<int> width = std::nullopt) const { | ||
Type type = value.getType(); | ||
if (type.isa<LLVM::LLVMPointerType>()) { | ||
return "%p"; | ||
} | ||
// Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the | ||
// type (so 4 for fp16, 8 for int32, 16 for int64). | ||
if (hex) { | ||
// Ignore `width` for `hex` values, pad to typeWidth. | ||
std::string ret = | ||
"0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); | ||
if (type.getIntOrFloatBitWidth() > 32) { | ||
ret += "ll"; | ||
} | ||
ret += "x"; | ||
return ret; | ||
} | ||
|
||
std::string prefix = "%"; | ||
if (width.has_value()) { | ||
prefix += std::to_string(*width); | ||
} else if (hex) { | ||
prefix += "0"; | ||
prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); | ||
} | ||
|
||
if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { | ||
return prefix + "f"; | ||
} else if (type.isSignedInteger()) { | ||
if (type.getIntOrFloatBitWidth() == 64) | ||
return prefix + "lli"; | ||
else | ||
return prefix + "i"; | ||
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) { | ||
if (type.getIntOrFloatBitWidth() == 64) | ||
return prefix + "llu"; | ||
else | ||
return prefix + "u"; | ||
} | ||
assert(false && "not supported type"); | ||
return ""; | ||
} | ||
|
||
Value llPrintf(StringRef msg, ValueRange args, | ||
ConversionPatternRewriter &rewriter, | ||
int *formatStrByteCount = nullptr) const { | ||
assert(!msg.empty() && "printf with empty string not supported"); | ||
llvm::SmallString<64> msgNewline(msg); | ||
msgNewline.push_back('\n'); | ||
msgNewline.push_back('\0'); | ||
Value msgValue = | ||
LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), | ||
rewriter, "printfFormat_", msgNewline); | ||
targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); | ||
if (formatStrByteCount) | ||
*formatStrByteCount = msgNewline.size_in_bytes(); | ||
return msgValue; | ||
} | ||
|
||
protected: | ||
const CPUTargetInfo &targetInfo; | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::triton::cpu::populatePrintOpToLLVMPattern( | ||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, | ||
const CPUTargetInfo &targetInfo, PatternBenefit benefit) { | ||
patterns.add<PrintOpConversion>(typeConverter, targetInfo, benefit); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" | ||
#include "triton/Conversion/TritonCPUToLLVM/Utility.h" | ||
|
||
namespace { | ||
|
||
using namespace mlir; | ||
using namespace mlir::triton; | ||
using namespace mlir::triton::cpu; | ||
|
||
struct GetProgramIdOpConversion | ||
: public ConvertOpToLLVMPattern<triton::GetProgramIdOp> { | ||
explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, | ||
const CPUTargetInfo &targetInfo, | ||
PatternBenefit benefit = 1) | ||
: ConvertOpToLLVMPattern<triton::GetProgramIdOp>(typeConverter, benefit), | ||
targetInfo(targetInfo) {} | ||
|
||
LogicalResult | ||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Value programId = targetInfo.programId( | ||
rewriter, op->getLoc(), op->getParentOfType<LLVM::LLVMFuncOp>(), | ||
op.getAxisAsInt()); | ||
rewriter.replaceOp(op, programId); | ||
return success(); | ||
} | ||
|
||
private: | ||
const CPUTargetInfo &targetInfo; | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::triton::cpu::populateSPMDOpToLLVMPattern( | ||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, | ||
const CPUTargetInfo &targetInfo, PatternBenefit benefit) { | ||
patterns.add<GetProgramIdOpConversion>(typeConverter, targetInfo, benefit); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters