diff --git a/include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h b/include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h new file mode 100644 index 000000000000..66f6b57b1c57 --- /dev/null +++ b/include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h @@ -0,0 +1,22 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" + +namespace mlir::triton::cpu { +class CPUTargetInfo { +public: + // Note: we may revisit for different CPU ISAs like AVX and Neon. + CPUTargetInfo() {} + + Value programId(ConversionPatternRewriter &rewriter, Location loc, + LLVM::LLVMFuncOp funcOp, int axis) const; + + void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const; + + ~CPUTargetInfo() {} +}; +} // namespace mlir::triton::cpu +#endif // TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H diff --git a/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h b/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h index d2212eb34009..f5cd3612dac5 100644 --- a/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h +++ b/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h @@ -1,7 +1,9 @@ #ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H #define TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H +#include "CPUTargetInfo.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" using namespace mlir; @@ -17,6 +19,11 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; constexpr int patternBenefitClampOptimizedPattern = 20; constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const cpu::CPUTargetInfo &targetInfo, + PatternBenefit benefit); + void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); @@ -27,6 +34,7 @@ void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const CPUTargetInfo &targetInfo, PatternBenefit benefit); } // namespace cpu diff --git a/include/triton/Conversion/TritonCPUToLLVM/Utility.h b/include/triton/Conversion/TritonCPUToLLVM/Utility.h index 08d3b5e061a8..8562271340a1 100644 --- a/include/triton/Conversion/TritonCPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonCPUToLLVM/Utility.h @@ -12,15 +12,10 @@ using namespace mlir; using namespace mlir::triton; -namespace mlir { -namespace LLVM { +// TODO: Do better refactoring. +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -// TODO: Not sure we need this for CPU backends. -inline bool isKernel(FunctionOpInterface funcOp) { - return funcOp.getVisibility() == SymbolTable::Visibility::Public; -} - -} // namespace LLVM -} // namespace mlir +#undef DEBUG_TYPE +#define DEBUG_TYPE "ttcpu_to_llvm" #endif diff --git a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt index 175115628597..db507557fb22 100644 --- a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt @@ -1,6 +1,9 @@ add_triton_library(TritonCPUToLLVM ControlFlowOpToLLVM.cpp + CPUTargetInfo.cpp FuncOpToLLVM.cpp + PrintOpToLLVM.cpp + SPMDOpToLLVM.cpp TypeConverter.cpp TritonCPUToLLVM.cpp diff --git a/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp b/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp new file mode 100644 index 000000000000..8dd050b80bbf --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp @@ -0,0 +1,49 @@ +#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h" +#include "triton/Conversion/TritonCPUToLLVM/Utility.h" + +namespace { +LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("printf"); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *context = rewriter.getContext(); + + // int printf(char* format, ...) + SmallVector argsType{ptr_ty(context)}; + auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType, true); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(context), funcName, + funcType); +} +} // namespace + +namespace mlir::triton::cpu { + +Value CPUTargetInfo::programId(ConversionPatternRewriter &rewriter, + Location loc, LLVM::LLVMFuncOp funcOp, + int axis) const { + assert(axis >= 0 && axis < 3); + + // program_id for CPU is provided as function arguments. The last three + // arguments are __grid0 to __grid2 of i32. + assert(funcOp && funcOp.getArguments().size() >= 3); + return funcOp.getArgument(funcOp.getArguments().size() - 3 + axis); +} + +void CPUTargetInfo::printf(ConversionPatternRewriter &rewriter, + Value formatStrStart, int /*formatStrByteCount*/, + ValueRange args) const { + auto loc = UnknownLoc::get(rewriter.getContext()); + SmallVector formatStrAndArgs{formatStrStart}; + for (auto arg : args) { + formatStrAndArgs.push_back(arg); + } + call(getPrintfDeclaration(rewriter), formatStrAndArgs); +} +} // namespace mlir::triton::cpu diff --git a/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp new file mode 100644 index 000000000000..96a1c5d1619f --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp @@ -0,0 +1,131 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h" +#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" +#include "triton/Conversion/TritonCPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +struct PrintOpConversion : public ConvertOpToLLVMPattern { + explicit PrintOpConversion(LLVMTypeConverter &typeConverter, + const CPUTargetInfo &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + auto getPid = [&](int axis) { + return targetInfo.programId( + rewriter, loc, op->getParentOfType(), axis); + }; + SmallVector values = {getPid(0), getPid(1), getPid(2)}; + + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "pid (" << getFormatSubstr(values[0]) << ", " + << getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2]) + << ")" << op.getPrefix(); + + for (size_t i = 0; i < op.getNumOperands(); i++) { + auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); + if (op.getOperand(i).getType().dyn_cast()) { + llvm_unreachable("Not implemented for tensor types"); + } + + // Only support scalars for now. + assert(elems.size() == 1); + if (i != 0) { + os << ", "; + } + os << getFormatSubstr(elems[0]); + values.push_back(elems[0]); + } + + llPrintf(formatStr, values, rewriter); + rewriter.eraseOp(op); + return success(); + } + + // TODO: This code is the same as the GPU-backend code. Consider refactoring. + std::string getFormatSubstr(Value value, bool hex = false, + std::optional width = std::nullopt) const { + Type type = value.getType(); + if (type.isa()) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. + std::string ret = + "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } else if (hex) { + prefix += "0"; + prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isSignedInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "lli"; + else + return prefix + "i"; + } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "llu"; + else + return prefix + "u"; + } + assert(false && "not supported type"); + return ""; + } + + Value llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter, + int *formatStrByteCount = nullptr) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), + rewriter, "printfFormat_", msgNewline); + targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); + if (formatStrByteCount) + *formatStrByteCount = msgNewline.size_in_bytes(); + return msgValue; + } + +protected: + const CPUTargetInfo &targetInfo; +}; + +} // namespace + +void mlir::triton::cpu::populatePrintOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const CPUTargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000000..65fef7a7d0d5 --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,39 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" +#include "triton/Conversion/TritonCPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +struct GetProgramIdOpConversion + : public ConvertOpToLLVMPattern { + explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, + const CPUTargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value programId = targetInfo.programId( + rewriter, op->getLoc(), op->getParentOfType(), + op.getAxisAsInt()); + rewriter.replaceOp(op, programId); + return success(); + } + +private: + const CPUTargetInfo &targetInfo; +}; + +} // namespace + +void mlir::triton::cpu::populateSPMDOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const CPUTargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp index 28d320df32d3..cb15f87ee206 100644 --- a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp +++ b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp @@ -89,10 +89,15 @@ struct ConvertTritonCPUToLLVM } RewritePatternSet patterns(context); + mlir::triton::cpu::CPUTargetInfo targetInfo; int benefit = mlir::triton::cpu::patternBenefitPrioritizeOverLLVMConversions; mlir::triton::cpu::populateControlFlowOpToLLVMPattern(typeConverter, patterns, benefit); + mlir::triton::cpu::populatePrintOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::cpu::populateSPMDOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index d04f516e8152..82aec76aa2c1 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -607,6 +607,13 @@ def run(self, *args, grid, warmup, **kwargs): configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) constant_params = configs[0].get_constants() + # The CPU launcher will provide the grid ids directly to the kernel. + # Note that this design is interim and subject to change. + if target[0] == 'cpu': + signature["__grid0"] = 'i32' + signature["__grid1"] = 'i32' + signature["__grid2"] = 'i32' + constants = { p.name: v for (v, p) in zip(bound_vals, self.params)