Skip to content

Commit

Permalink
[BACKEND][CPU] Convert tt.get_program_id and tt.print (Hello World) (#1)
Browse files Browse the repository at this point in the history
Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented `tl.device_print` (or `print`) for scalar types only, to allow a quick "Hello, World!" test.

Test Plan: Tested with a simple example:

```
@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)
```

The resulting .llir is valid:
```
@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}
```

Tried to compile with a fake main function:
```
> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5
```
  • Loading branch information
minjang authored and ienkovich committed Dec 6, 2024
1 parent c214238 commit 047a677
Show file tree
Hide file tree
Showing 9 changed files with 268 additions and 9 deletions.
22 changes: 22 additions & 0 deletions include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H
#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/MLIRTypes.h"

namespace mlir::triton::cpu {
// Target-specific hooks used while lowering Triton ops to the LLVM dialect for
// the CPU backend. The class holds no state.
class CPUTargetInfo {
public:
  // Note: we may revisit for different CPU ISAs like AVX and Neon.
  CPUTargetInfo() = default;
  ~CPUTargetInfo() = default;

  // Returns the program id along `axis` (0, 1 or 2). On CPU the ids are
  // passed to the kernel as its last three function arguments, so this
  // yields the matching argument of `funcOp`.
  Value programId(ConversionPatternRewriter &rewriter, Location loc,
                  LLVM::LLVMFuncOp funcOp, int axis) const;

  // Emits a call to the C `printf` with `formatStrStart` as the format
  // string, followed by `args`. `formatStrByteCount` is not used by the CPU
  // implementation.
  void printf(ConversionPatternRewriter &rewriter, Value formatStrStart,
              int formatStrByteCount, ValueRange args) const;
};
} // namespace mlir::triton::cpu
#endif // TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H
#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H

#include "CPUTargetInfo.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

using namespace mlir;
Expand All @@ -17,6 +19,11 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

// Adds to `patterns` the conversion of SPMD ops (tt.get_program_id) to the
// LLVM dialect. `targetInfo` supplies the program id values and is held by
// reference by the registered pattern.
void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const cpu::CPUTargetInfo &targetInfo,
PatternBenefit benefit);

void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);
Expand All @@ -27,6 +34,7 @@ void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter,

// Adds to `patterns` the conversion of tt.print to a `printf` call in the LLVM
// dialect. `targetInfo` emits the actual call and is held by reference by the
// registered pattern.
void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const CPUTargetInfo &targetInfo,
PatternBenefit benefit);

} // namespace cpu
Expand Down
13 changes: 4 additions & 9 deletions include/triton/Conversion/TritonCPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,10 @@
using namespace mlir;
using namespace mlir::triton;

namespace mlir {
namespace LLVM {
// TODO: Do better refactoring.
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

// Returns true when `funcOp` has public symbol visibility.
// TODO: Not sure we need this for CPU backends.
inline bool isKernel(FunctionOpInterface funcOp) {
return funcOp.getVisibility() == SymbolTable::Visibility::Public;
}

} // namespace LLVM
} // namespace mlir
#undef DEBUG_TYPE
#define DEBUG_TYPE "ttcpu_to_llvm"

#endif
3 changes: 3 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
add_triton_library(TritonCPUToLLVM
ControlFlowOpToLLVM.cpp
CPUTargetInfo.cpp
FuncOpToLLVM.cpp
PrintOpToLLVM.cpp
SPMDOpToLLVM.cpp
TypeConverter.cpp
TritonCPUToLLVM.cpp

Expand Down
49 changes: 49 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h"
#include "triton/Conversion/TritonCPUToLLVM/Utility.h"

namespace {
// Returns the module-level `printf` function, declaring it at the start of
// the enclosing module when no symbol of that name exists yet.
LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) {
  StringRef funcName("printf");
  auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();

  // Reuse an existing symbol rather than declaring it twice.
  if (Operation *existing = moduleOp.lookupSymbol(funcName))
    return cast<LLVM::LLVMFuncOp>(*existing);

  auto *context = rewriter.getContext();

  // int printf(char* format, ...)
  SmallVector<Type> argsType{ptr_ty(context)};
  auto funcType =
      LLVM::LLVMFunctionType::get(i32_ty, argsType, /*isVarArg=*/true);

  // The guard restores the caller's insertion point on return.
  ConversionPatternRewriter::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());
  auto declLoc = UnknownLoc::get(context);
  return rewriter.create<LLVM::LLVMFuncOp>(declLoc, funcName, funcType);
}
} // namespace

namespace mlir::triton::cpu {

// Returns the program id along `axis` (0..2) for the kernel `funcOp`.
// `rewriter` and `loc` are not used: no IR is created here.
Value CPUTargetInfo::programId(ConversionPatternRewriter &rewriter,
                               Location loc, LLVM::LLVMFuncOp funcOp,
                               int axis) const {
  assert(axis >= 0 && axis < 3);
  assert(funcOp && funcOp.getArguments().size() >= 3);

  // program_id for CPU is provided as function arguments. The last three
  // arguments are __grid0 to __grid2 of i32.
  const auto firstGridArg = funcOp.getArguments().size() - 3;
  return funcOp.getArgument(firstGridArg + axis);
}

// Emits `printf(formatStrStart, args...)`.
//
// `printf` is variadic, so every argument must already have the type the C
// default argument promotions would give it: the callee reads a `double` for
// `%f` and a full 32-bit integer for `%u`/`%i`. Values narrower than that are
// widened here before the call is built. `formatStrByteCount` is unused.
void CPUTargetInfo::printf(ConversionPatternRewriter &rewriter,
                           Value formatStrStart, int /*formatStrByteCount*/,
                           ValueRange args) const {
  // NOTE(review): `loc` is never named explicitly in the call below; `call`
  // is presumably a helper macro from Utility.h that expands to use `loc` and
  // `rewriter` -- keep both names as they are.
  auto loc = UnknownLoc::get(rewriter.getContext());

  SmallVector<Value> formatStrAndArgs{formatStrStart};
  formatStrAndArgs.reserve(args.size() + 1);
  for (Value arg : args) {
    Type type = arg.getType();
    if (type.isBF16() || type.isF16() || type.isF32()) {
      // float -> double, as required for floating-point variadic arguments.
      arg = rewriter.create<LLVM::FPExtOp>(loc, rewriter.getF64Type(), arg);
    } else if (auto intType = dyn_cast<IntegerType>(type)) {
      // Narrow integers are widened to 32 bits. Zero-extension is used
      // because the print lowering formats signless integers as unsigned.
      // NOTE(review): revisit if signed printing is added.
      if (intType.getWidth() < 32)
        arg = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getIntegerType(32),
                                            arg);
    }
    formatStrAndArgs.push_back(arg);
  }
  call(getPrintfDeclaration(rewriter), formatStrAndArgs);
}
} // namespace mlir::triton::cpu
131 changes: 131 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h"
#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h"
#include "triton/Conversion/TritonCPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace {

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

// Lowers tt.print to a `printf` call emitted through CPUTargetInfo.
//
// The emitted line has the form
//   pid (<pid0>, <pid1>, <pid2>)<prefix><operand0>, <operand1>, ...\n
// Only scalar operands are supported; an op with a tensor operand is rejected
// with a match failure so that the conversion fails cleanly.
struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
  explicit PrintOpConversion(LLVMTypeConverter &typeConverter,
                             const CPUTargetInfo &targetInfo,
                             PatternBenefit benefit)
      : mlir::ConvertOpToLLVMPattern<triton::PrintOp>(typeConverter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    // Validate everything before any IR is created: a pattern must not
    // return failure after it has started rewriting.
    for (Value operand : op->getOperands()) {
      if (isa<RankedTensorType>(operand.getType()))
        return rewriter.notifyMatchFailure(
            op, "Not implemented for tensor types");
    }

    // The program ids are read from the arguments of the enclosing function,
    // so that function must already be an llvm.func carrying them.
    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!funcOp || funcOp.getArguments().size() < 3)
      return rewriter.notifyMatchFailure(
          op, "expected an enclosing llvm.func with program id arguments");

    auto getPid = [&](int axis) {
      return targetInfo.programId(rewriter, loc, funcOp, axis);
    };
    SmallVector<Value> values = {getPid(0), getPid(1), getPid(2)};

    std::string formatStr;
    llvm::raw_string_ostream os(formatStr);
    os << "pid (" << getFormatSubstr(values[0]) << ", "
       << getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2])
       << ")" << op.getPrefix();

    for (size_t i = 0; i < op.getNumOperands(); i++) {
      auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter);

      // Only support scalars for now.
      assert(elems.size() == 1);
      if (i != 0) {
        os << ", ";
      }
      os << getFormatSubstr(elems[0]);
      values.push_back(elems[0]);
    }

    llPrintf(formatStr, values, rewriter);
    rewriter.eraseOp(op);
    return success();
  }

  // Returns the printf conversion specification for `value`, chosen from its
  // type: "%p" for LLVM pointers, "%f" for floats, "%i"/"%lli" for signed
  // integers and "%u"/"%llu" for unsigned or signless integers. `width`, when
  // given, is inserted as the field width. With `hex` set the result is a
  // zero-padded hexadecimal specification and `width` is ignored.
  //
  // TODO: This code is the same as the GPU-backend code. Consider refactoring.
  std::string getFormatSubstr(Value value, bool hex = false,
                              std::optional<int> width = std::nullopt) const {
    Type type = value.getType();
    if (isa<LLVM::LLVMPointerType>(type)) {
      return "%p";
    }
    // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the
    // type (so 4 for fp16, 8 for int32, 16 for int64).
    if (hex) {
      // Ignore `width` for `hex` values, pad to typeWidth.
      std::string ret =
          "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4);
      if (type.getIntOrFloatBitWidth() > 32) {
        ret += "ll";
      }
      ret += "x";
      return ret;
    }

    // `hex` is false from here on, so only an explicit width affects the
    // prefix.
    std::string prefix = "%";
    if (width.has_value()) {
      prefix += std::to_string(*width);
    }

    if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
      return prefix + "f";
    } else if (type.isSignedInteger()) {
      if (type.getIntOrFloatBitWidth() == 64)
        return prefix + "lli";
      else
        return prefix + "i";
    } else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
      if (type.getIntOrFloatBitWidth() == 64)
        return prefix + "llu";
      else
        return prefix + "u";
    }
    assert(false && "not supported type");
    return "";
  }

  // Stores `msg` plus a trailing newline and NUL terminator as a string in
  // the module, then emits the printf call with `args`. Returns the value
  // referring to the stored string; when `formatStrByteCount` is non-null it
  // receives the stored size in bytes, terminator included.
  Value llPrintf(StringRef msg, ValueRange args,
                 ConversionPatternRewriter &rewriter,
                 int *formatStrByteCount = nullptr) const {
    assert(!msg.empty() && "printf with empty string not supported");
    llvm::SmallString<64> msgNewline(msg);
    msgNewline.push_back('\n');
    msgNewline.push_back('\0');
    Value msgValue =
        LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()),
                                rewriter, "printfFormat_", msgNewline);
    targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args);
    if (formatStrByteCount)
      *formatStrByteCount = msgNewline.size_in_bytes();
    return msgValue;
  }

protected:
  // Not owned; must outlive the pattern.
  const CPUTargetInfo &targetInfo;
};

} // namespace

// Registers PrintOpConversion (tt.print -> printf call) in `patterns` with the
// given `benefit`. The pattern keeps a reference to `targetInfo`.
void mlir::triton::cpu::populatePrintOpToLLVMPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const CPUTargetInfo &targetInfo, PatternBenefit benefit) {
patterns.add<PrintOpConversion>(typeConverter, targetInfo, benefit);
}
39 changes: 39 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h"
#include "triton/Conversion/TritonCPUToLLVM/Utility.h"

namespace {

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

// Lowers tt.get_program_id by replacing it with the value CPUTargetInfo
// reports for the requested axis of the enclosing function.
struct GetProgramIdOpConversion
    : public ConvertOpToLLVMPattern<triton::GetProgramIdOp> {
  explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter,
                                    const CPUTargetInfo &targetInfo,
                                    PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<triton::GetProgramIdOp>(typeConverter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The program id is read from the arguments of the enclosing llvm.func.
    // Fail the match instead of relying on asserts inside programId(), which
    // are compiled out in release builds.
    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!funcOp || funcOp.getArguments().size() < 3)
      return rewriter.notifyMatchFailure(
          op, "expected an enclosing llvm.func with program id arguments");

    Value programId = targetInfo.programId(rewriter, op->getLoc(), funcOp,
                                           op.getAxisAsInt());
    rewriter.replaceOp(op, programId);
    return success();
  }

private:
  // Not owned; must outlive the pattern.
  const CPUTargetInfo &targetInfo;
};

} // namespace

// Registers GetProgramIdOpConversion (tt.get_program_id lowering) in
// `patterns` with the given `benefit`. The pattern keeps a reference to
// `targetInfo`.
void mlir::triton::cpu::populateSPMDOpToLLVMPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const CPUTargetInfo &targetInfo, PatternBenefit benefit) {
patterns.add<GetProgramIdOpConversion>(typeConverter, targetInfo, benefit);
}
5 changes: 5 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,15 @@ struct ConvertTritonCPUToLLVM
}

RewritePatternSet patterns(context);
mlir::triton::cpu::CPUTargetInfo targetInfo;
int benefit =
mlir::triton::cpu::patternBenefitPrioritizeOverLLVMConversions;
mlir::triton::cpu::populateControlFlowOpToLLVMPattern(typeConverter,
patterns, benefit);
mlir::triton::cpu::populatePrintOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
mlir::triton::cpu::populateSPMDOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);

if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
Expand Down
7 changes: 7 additions & 0 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,13 @@ def run(self, *args, grid, warmup, **kwargs):

configs = (backend.get_attrs_descriptor(self.params, bound_vals), )
constant_params = configs[0].get_constants()
# The CPU launcher will provide the grid ids directly to the kernel.
# Note that this design is interim and subject to change.
if target[0] == 'cpu':
signature["__grid0"] = 'i32'
signature["__grid1"] = 'i32'
signature["__grid2"] = 'i32'

constants = {
p.name: v
for (v, p) in zip(bound_vals, self.params)
Expand Down

0 comments on commit 047a677

Please sign in to comment.